// www.pudn.com > 基于VC的神经网络开发程序包(源码).rar > TrainingSet.cpp
#include "../include/TrainingSet.h"
#include "../include/Exception.h"
#include "../include/File.h"
#include "../include/Neuron.h"
#include "../include/defines.h"

// NOTE(review): the names of the standard headers were stripped from the
// scraped copy of this file ("#include#include #include #include"). The set
// below covers every standard-library name this translation unit uses;
// confirm against the upstream source.
#include <cstdlib>   // atof
#include <fstream>   // ifstream, ofstream
#include <iostream>  // istream, ostream, endl
#include <string>    // string
#include <vector>    // vector

using namespace std;

namespace annie
{

namespace
{
	/**
	 * Reads `count` raw `real` values from a binary stream into `dest`.
	 *
	 * `dest` is cleared first. Reading stops at the first value that cannot
	 * be read completely.
	 *
	 * @param in     stream opened in binary mode
	 * @param count  number of values to read; values <= 0 read nothing
	 * @param dest   receives the values that were read
	 * @return true if all `count` values were read, false otherwise
	 */
	bool readReals(istream &in, int count, VECTOR &dest)
	{
		dest.clear();
		real value;
		for (int i = 0; i < count; ++i)
		{
			in.read(reinterpret_cast<char*>(&value), sizeof(value));
			if (!in)
				return false;
			dest.push_back(value);
		}
		return true;
	}
}

/**
 * Creates an empty training set.
 *
 * @param in   number of values in each input vector
 * @param out  number of values in each output vector
 */
TrainingSet::TrainingSet(int in, int out)
{
	_nInputs = in;
	_nOutputs = out;
}

TrainingSet::~TrainingSet()
{
}

/**
 * Loads input/output pairs from a text file.
 *
 * After the file is opened, the first word must equal getClassName().
 * The remaining words are scanned for the keywords "INPUTS" and "OUTPUTS"
 * (each followed by an integer) and "IO_PAIRS" (followed by the values of
 * every pair, inputs first). Pairs are appended to whatever the set already
 * contains.
 *
 * @param filename  path of the file to read
 * @throws Exception if the file cannot be opened or does not start with the
 *         class name
 */
void TrainingSet::load_text(const char *filename)
{
	File file;
	try
	{
		file.open(filename);
	}
	catch (Exception &e)
	{
		// Re-throw with this class' name prepended to the original message.
		string error(getClassName());
		error = error + "::" + getClassName() + "() - " + e.what();
		throw Exception(error);
	}

	string word = file.readWord();
	if (word != getClassName())
	{
		string error(getClassName());
		error = error + "::" + getClassName() + "() - File provided isn't a TrainingSet TEXT_FILE.";
		throw Exception(error);
	}

	while (!file.eof())
	{
		word = file.readWord();
		if (word == "INPUTS")
			_nInputs = file.readInt();
		else if (word == "OUTPUTS")
			_nOutputs = file.readInt();
		else if (word == "IO_PAIRS")
		{
			// NOTE(review): whether File::eof() becomes true before or after
			// a read past the last value is not visible from this file; if it
			// is the latter, the final iteration stores one bogus pair.
			VECTOR input, output;
			while (!file.eof())
			{
				input.clear();
				output.clear();
				for (int j = 0; j < _nInputs; ++j)
					input.push_back(file.readDouble());
				for (int j = 0; j < _nOutputs; ++j)
					output.push_back(file.readDouble());
				_inputs.push_back(input);
				_outputs.push_back(output);
			}
		}
	}
}

/**
 * Loads input/output pairs from a binary file, replacing the current
 * contents of the set.
 *
 * Layout read: a double holding the library version, the number of inputs,
 * the number of outputs, then for every pair the raw `real` input values
 * followed by the raw `real` output values.
 *
 * @param filename  path of the file to read
 * @throws Exception if the file cannot be opened, the version does not
 *         match atof(ANNIE_VERSION), or the header is incomplete
 */
void TrainingSet::load_binary(const char *filename)
{
	ifstream file(filename, ios::in | ios::binary);
	if (!file)
		throw Exception("TrainingSet::load_binary() - Couldn't open file for reading");

	// Initialised so that a short read cannot leave an indeterminate value.
	double version = 0.0;
	file.read(reinterpret_cast<char*>(&version), sizeof(version));
	if (!file || version != atof(ANNIE_VERSION))
		throw Exception("TrainingSet::load_binary() - Invalid training set file encoutered (invalid version)");

	file.read(reinterpret_cast<char*>(&_nInputs), sizeof(_nInputs));
	file.read(reinterpret_cast<char*>(&_nOutputs), sizeof(_nOutputs));
	if (!file)
		throw Exception("TrainingSet::load_binary() - Invalid training set file encoutered (incomplete header)");

	_inputs.clear();
	_outputs.clear();

	// A record holding no values would never advance the stream, so the
	// read loop below could not terminate.
	if (_nInputs <= 0 && _nOutputs <= 0)
		return;

	// The stream only reports end-of-file after a read has run past the end,
	// so the result of each read is checked instead of eof() up front. A pair
	// is stored only when both halves were read completely, which drops a
	// truncated trailing record instead of storing garbage.
	VECTOR input, output;
	while (readReals(file, _nInputs, input) && readReals(file, _nOutputs, output))
	{
		_inputs.push_back(input);
		_outputs.push_back(output);
	}
}

/**
 * Creates a training set from a file.
 *
 * @param filename   path of the file to read
 * @param file_type  annie::TEXT_FILE or annie::BINARY_FILE
 * @throws Exception if file_type is neither of the two, or if loading fails
 */
TrainingSet::TrainingSet(const char *filename, int file_type)
{
	// The original read "_nInputs=_nOutputs==0;", which compared an
	// uninitialised _nOutputs against 0 instead of zeroing both members.
	_nInputs = 0;
	_nOutputs = 0;

	if (file_type == annie::TEXT_FILE)
		load_text(filename);
	else if (file_type == annie::BINARY_FILE)
		load_binary(filename);
	else
	{
		// Same handling as save() for an unknown file type.
		string error(getClassName());
		error = error + "::" + getClassName() + "() - Invalid file type specified.";
		throw Exception(error);
	}
}

/**
 * Appends one input/output pair given as raw arrays.
 *
 * @param input   array from which the first getInputSize() values are copied
 * @param output  array from which the first getOutputSize() values are copied
 */
void TrainingSet::addIOpair(real *input, real *output)
{
	VECTOR in, out;
	for (int i = 0; i < _nInputs; ++i)
		in.push_back(input[i]);
	for (int i = 0; i < _nOutputs; ++i)
		out.push_back(output[i]);
	addIOpair(in, out);
}

/**
 * Appends one input/output pair. The vector sizes are not checked against
 * getInputSize()/getOutputSize().
 */
void TrainingSet::addIOpair(VECTOR input, VECTOR output)
{
	_inputs.push_back(input);
	_outputs.push_back(output);
}

/**
 * @return true once both internal iterators have reached the end of the
 *         stored pairs
 */
bool TrainingSet::epochOver()
{
	return _inputIter == _inputs.end() && _outputIter == _outputs.end();
}

/** Resets the internal iterators to the first stored pair. */
void TrainingSet::initialize()
{
	_inputIter = _inputs.begin();
	_outputIter = _outputs.begin();
}

/**
 * Copies the pair at the current iterator position into the arguments and
 * advances to the next pair.
 *
 * @param input    receives the input vector
 * @param desired  receives the matching output vector
 * @throws Exception if the input iterator is already at the end
 */
void TrainingSet::getNextPair(VECTOR &input, VECTOR &desired)
{
	if (_inputIter == _inputs.end())
	{
		string error(getClassName());
		error = error + "::getNextPair() - Passed the last I/O pair already. No more left.";
		throw Exception(error);
	}
	input = *_inputIter;
	desired = *_outputIter;
	++_inputIter;
	++_outputIter;
}

/**
 * Writes the training set to a stream as text, one pair per line with the
 * input values followed by the output values.
 *
 * NOTE(review): the statements that wrote the header, and the separators
 * between values, were stripped from the scraped copy of this file. They are
 * reconstructed here from the keywords load_text() parses (class name,
 * "INPUTS", "OUTPUTS", "IO_PAIRS"). If File::open() expects a file-level
 * header ahead of the class name, it must be emitted here as well; confirm
 * against the upstream source.
 */
ostream& operator << (std::ostream& s, TrainingSet &T)
{
	VECTOR::iterator it;
	s << T.getClassName() << endl;
	s << "INPUTS " << T._nInputs << endl;
	s << "OUTPUTS " << T._nOutputs << endl;
	s << "IO_PAIRS" << endl;

	vector< VECTOR >::iterator ioIn, ioOut;
	for (ioIn = T._inputs.begin(), ioOut = T._outputs.begin(); ioIn != T._inputs.end(); ++ioIn, ++ioOut)
	{
		for (it = ioIn->begin(); it != ioIn->end(); ++it)
			s << (*it) << " ";
		for (it = ioOut->begin(); it != ioOut->end(); ++it)
			s << (*it) << " ";
		s << endl;
	}
	return s;
}

/**
 * Saves the training set as text.
 *
 * NOTE(review): this whole function was stripped from the scraped copy of
 * this file; only the call in save() survives. It is reconstructed as
 * "open the file and stream *this into it"; confirm against the upstream
 * source.
 *
 * @param filename  path of the file to write
 * @throws Exception if the file cannot be opened
 */
void TrainingSet::save_text(const char *filename)
{
	ofstream file(filename);
	if (!file)
	{
		string error(getClassName());
		error = error + "::save_text() - Couldn't open file for writing";
		throw Exception(error);
	}
	file << (*this);
	file.close();
}

/**
 * Saves the training set in the binary layout that load_binary() reads:
 * version as a double, number of inputs, number of outputs, then the raw
 * values of every pair (inputs first).
 *
 * NOTE(review): everything before the declaration of the pair iterators was
 * stripped from the scraped copy of this file and is reconstructed as the
 * mirror image of load_binary(); confirm against the upstream source.
 *
 * @param filename  path of the file to write
 * @throws Exception if the file cannot be opened
 */
void TrainingSet::save_binary(const char *filename)
{
	ofstream file(filename, ios::out | ios::binary);
	if (!file)
		throw Exception("TrainingSet::save_binary() - Couldn't open file for writing");

	double version = atof(ANNIE_VERSION);
	file.write(reinterpret_cast<const char*>(&version), sizeof(version));
	file.write(reinterpret_cast<const char*>(&_nInputs), sizeof(_nInputs));
	file.write(reinterpret_cast<const char*>(&_nOutputs), sizeof(_nOutputs));

	vector< VECTOR >::iterator ioIn, ioOut;
	VECTOR::iterator it;
	for (ioIn = _inputs.begin(), ioOut = _outputs.begin(); ioIn != _inputs.end(); ++ioIn, ++ioOut)
	{
		for (it = ioIn->begin(); it != ioIn->end(); ++it)
			file.write(reinterpret_cast<const char*>(&(*it)), sizeof(*it));
		for (it = ioOut->begin(); it != ioOut->end(); ++it)
			file.write(reinterpret_cast<const char*>(&(*it)), sizeof(*it));
	}
	file.close();
}

/**
 * Saves the training set to a file.
 *
 * @param filename   path of the file to write
 * @param file_type  TEXT_FILE or BINARY_FILE
 * @throws Exception if file_type is neither of the two
 */
void TrainingSet::save(const char *filename, int file_type)
{
	if (file_type == TEXT_FILE)
		save_text(filename);
	else if (file_type == BINARY_FILE)
		save_binary(filename);
	else
	{
		string error(getClassName());
		error = error + "::save() - Invalid file type specified.";
		throw Exception(error);
	}
}

/** @return the number of input vectors stored */
int TrainingSet::getSize()
{
	return static_cast<int>(_inputs.size());
}

/** @return the number of values in each input vector */
int TrainingSet::getInputSize()
{
	return _nInputs;
}

/** @return the number of values in each output vector */
int TrainingSet::getOutputSize()
{
	return _nOutputs;
}

/** @return the class identifier, which is also the first word of a text file */
const char * TrainingSet::getClassName()
{
	return "TrainingSet";
}

// Disabled in the original source; kept for reference.
//void
//TrainingSet::shuffle()
//{
//	int size = getSize()-1;
//	vector< VECTOR >::iterator inIt,outIt;
//
//	int chosen;
//	while(size>=0)
//	{
//		chosen = (int)(fabs(random())*size);
//		inIt = &_inputs[chosen];
//		outIt = &_outputs[chosen];
//
//		_inputs.push_back(*inIt);
//		_outputs.push_back(*outIt);
//
//		inIt = _inputs.erase(inIt);
//		outIt = _outputs.erase(outIt);
//		size--;
//	}
//}

}; //namespace annie