// Source: www.pudn.com > firev0.01.rar > distancefileclassify.cpp


/*
This file is part of the FIRE -- Flexible Image Retrieval System

FIRE is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

FIRE is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with FIRE; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/
// Standard library headers.
// NOTE(review): the header names were missing from the original #include
// lines; this list is restored from the names the file uses -- confirm
// against the upstream FIRE sources.
#include <algorithm>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "getpot.hpp"
#include "filelist.hpp"
#include "gzstream.hpp"
#include "diag.hpp"

using namespace std;

// One search state: an error rate (first) together with the weight vector
// (second) that classify() was run with to obtain it.
typedef pair<double, vector<double> > ER_Weights;
// Sequence of search states; insert() keeps it ordered by error rate.
typedef deque< ER_Weights > WeightVector;

// Distance matrix indexed as D[i][j][k]; main() sizes it as
// nOfImages x nOfImages x nOfDistances.
typedef  vector< vector< vector<double> > > DistMatrix;

// Set from the -showconfmat command line switch in main().
bool showConfMat;

void print(const WeightVector &v) {
  WeightVector::const_iterator i;
  for(i=v.begin();ifirst << " ";
    for(vector::const_iterator j=i->second.begin();jsecond.end();++j) {
      cout << *j << " ";
    }
    cout << endl;
  }
  cout << endl << endl;
}

bool there(const WeightVector &v, const ER_Weights &e) {
  bool result=false;
  WeightVector::const_iterator i;
  for(i=v.begin();i &e) {
  bool result=false;
  WeightVector::const_iterator i;
  for(i=v.begin();isecond == e) result=true;
  }
  return result;
}


/**
 * Insert e into v in front of the first entry whose error rate is strictly
 * greater than e's.
 *
 * The scan starts at the front and stops at the first such entry, so when v
 * is already ordered by ascending error rate it stays ordered, and e is
 * placed after existing entries with the same error rate.
 *
 * @param v  container to insert into.
 * @param e  entry to insert (copied).
 */
void insert(WeightVector &v, const ER_Weights &e) {
  WeightVector::iterator pos=v.begin();
  while(pos!=v.end() && pos->first<=e.first) {
    ++pos;
  }
  v.insert(pos,e);
}


// classify: for every image i of fileList, checks whether the image chosen
// as bestIndex agrees with i, and prints the error rate
// ER = error / classified together with the individual counters.
//
// NOTE(review): this block is damaged. Text between a '<' and the following
// '>' appears to have been stripped from the file, so template arguments,
// the loop conditions, the code that selects bestIndex (presumably from W
// and D), the end of classify() including its return statement, and the
// opening of USAGE() are not present. Restore them from the upstream FIRE
// sources before building.
//
// What the remaining code shows:
//  - If fileList.withCls(): a decision counts as correct when
//    cls(i) == cls(bestIndex), and confusionMatrix[cls(i)][cls(bestIndex)]
//    is incremented.
//  - Else if fileList.withDescription(): a decision counts as correct when
//    the decomposed descriptions of i and bestIndex have at least one
//    element in common.
//  - If neither holds nothing is counted; classified then stays 0 and ER is
//    computed as 0.0/0.0.
double classify(const vector &W, const DistMatrix &D, const int anzClasses,const FileList& fileList, const int nOfDistances) {
    //classify
    
  int nOfImages=fileList.size();
  int correct=0, error=0, classified=0;
  int bestIndex;
  double aktDist, bestDist;
  // anzClasses x anzClasses table of counts, indexed [class of i][class of bestIndex].
  vector > confusionMatrix(anzClasses, vector(anzClasses,0));
  
  // NOTE(review): the two loop headers below are truncated; the code that
  // computes aktDist/bestDist and sets bestIndex is missing.
  for(int i=0;i::max();

    for(int j=0;j de1, de2, inter;


    if(fileList.withCls()) {
      // Class labels available: compare the label of i with that of bestIndex.
      if(fileList.cls(i) == fileList.cls(bestIndex)) { ++correct; } 
      else { ++error; }
      ++classified;
      ++confusionMatrix[fileList.cls(i)][fileList.cls(bestIndex)];
    } else if (fileList.withDescription()) {
      // No class labels: a non-empty intersection of the two decomposed
      // descriptions counts as a correct decision.
      de1=fileList.decomposedDescription(i);
      de2=fileList.decomposedDescription(bestIndex);
      inter.clear();
      set_intersection(de1.begin(), de1.end(), 
                       de2.begin(), de2.end(),
                       inserter(inter, inter.begin()));
      
      if(inter.size() > 0) {
        ++correct;
      } else {
        ++error;
      }
      ++classified;
    }
    
    DBG(DBG_MESSAGE) << "Classified " << fileList[i] << " correct=" << correct << endl;

  }
  
  // Fraction of counted decisions that were wrong.
  double ER=double(error)/double(classified);
  
  cout << "RESULT: error=" << error << " correct=" << correct << " classified=" << classified << " -> ER: " << ER << endl;
  // NOTE(review): the next statement is truncated. The option list that
  // follows presumably belongs to USAGE(), which main() calls for -h.
  cout << "W[0.."<" << endl
       << "   -distanceFile" << endl
       << "   -weightFile" << endl
       << "   -steepestDescent" << endl
       << "   -weightX " <" << endl
       << "   -filelist to override filelist given in distfile" << endl
       << "   -showconfmat to show confusion matrices" << endl
       << "   -stepSize (0.1)" << endl
       << "   -maxWeight (1.0)" << endl
       << "   -indicesTrain, -indicesText (for crosseval)" << endl
       << endl;
}


int main(int argc, char **argv) {
  GetPot cl(argc, argv);
  
  double stepSize=cl.follow(0.1,"-stepSize");
  double maxWeight=cl.follow(1.0,"-maxWeight");
  
  showConfMat=cl.search("-showconfmat");

  if(cl.search("-h")) {
    USAGE();
    exit(10);
  }
 
  vector W;

  DistMatrix D;
  string keyword, fileListFileName;
  int nOfDistances,i,j;
  FileList fileList;
  vector  namesOfDistances;
  int anzClasses=1, nOfImages;
  string filename=cl.follow("dists.dat","-distanceFile");

  igzstream ifs(filename.c_str());

  //reading matrix file
  if(!ifs) {
    ERR << "Unable to open distfile '"<< filename << "'." << endl;
  } else {
    while(!ifs.eof()) {
      ifs >> keyword;
      if(!ifs.eof()) {
        if(keyword=="#") {
          char devnull[1024];
          ifs.getline(devnull,1024);
        } else if( keyword=="dim") {
          ifs >> nOfDistances;
          namesOfDistances=vector(nOfDistances);
          W=vector(nOfDistances,1.0);
        } else if( keyword=="filename") {
          ifs >> fileListFileName;
          if(cl.search("-filelist")) fileListFileName=cl.follow("nofilelistnamegiven","-filelist");
          fileList.load(fileListFileName);
          nOfImages=fileList.size();
          D=DistMatrix(nOfImages,vector >(nOfImages, vector(nOfDistances,0.0)));
        } else if( keyword=="distnames") {
          for(int i=0;i> namesOfDistances[i];
          }
        } else if( keyword=="dists") {
          ifs >> i >> j;
          for(int k=0;k> D[i][j][k];
          }
        } else {
          ERR << "Unknown keyword :'" << keyword << "'." << endl;
        }
      }
    }
    ifs.close();
  }
  for(unsigned int i=0;ianzClasses) anzClasses=fileList.cls(i)+1;}
  //normalize matrix 
  nOfImages=fileList.size();
  double sum;
  
  for(int i=0;i aktW; double aktER;
      ER_Weights tmpERWeights;
      
      aktER=classify(W,D,anzClasses, fileList, nOfDistances);
      WeightVector toDo(0);
      WeightVector done(0);
      
      toDo.push_back(ER_Weights(aktER,W));
      
      while(!toDo.empty()) {
        //           cout << "TODO:" << endl;
        //           print(toDo);
        //           cout << "DONE:" << endl;
        //           print(done);
        
        cout << "STARTING FROM ER=" << toDo[0].first << endl;
        W=toDo[0].second;
        insert(done,toDo[0]);
        toDo.pop_front();
        
        vector bestW;
        for(int i=0;i=0) { //no weights < 0 allowed
            aktW[i]-=stepSize;
            if(!there(toDo,aktW) && !there(done,aktW)) {
              aktER=classify(aktW,D,anzClasses,fileList,nOfDistances);
              tmpERWeights=ER_Weights(aktER,aktW);
              insert(toDo,tmpERWeights); 
            }
            if(writingToFile) {
              ofs << aktER ;
              for(int w=0;w> W[i];
        }
        aktER=classify(W,D,anzClasses,fileList,nOfDistances);
        if(writingToFile) {
          ofs << aktER ;
          for(int w=0;w