// Source listing: www.pudn.com > TextClassify.rar > Classifier.cpp
// Classifier.cpp: implementation of the CClassifier class. // ////////////////////////////////////////////////////////////////////// #include "stdafx.h" #include "TextClassify.h" #include "Classifier.h" #include "WordSegment.h" #include "Message.h" #include#include #ifdef _DEBUG #undef THIS_FILE static char THIS_FILE[]=__FILE__; #define new DEBUG_NEW #endif CClassifier theClassifier; const DWORD CClassifier::dwModelFileID=0xFFEFFFFF; ////////////////////////////////////////////////////////////////////// // Construction/Destruction ////////////////////////////////////////////////////////////////////// CClassifier::CClassifier() { n_Type=-1; m_pDocs=NULL; m_pSimilarityRatio=NULL; m_pProbability=NULL; m_lDocNum=0; m_nClassNum=0; } CClassifier::~CClassifier() { } //参数bGenDic=false代表无需重新扫描文档得到训练文档集中所有特征,一般在层次分类时使用 //参数nType用来决定分类模型的类别,nType=0代表KNN分类器,nType=1代表SVM分类器 bool CClassifier::Train(int nType, bool bFlag) { this->n_Type=nType; CTime startTime; CTimeSpan totalTime; if(bFlag) { InitTrain(); //生成所有候选特征项,将其保存在m_lstWordList中 GenDic(); } CMessage::PrintStatusInfo(""); if(m_lstWordList.GetCount()==0) return false; if(m_lstTrainCatalogList.GetCataNum()==0) return false; //清空特征项列表m_lstTrainWordList m_lstTrainWordList.InitWordList(); //为特征项列表m_lstWordList中的每个特征加权 CMessage::PrintInfo(_T("开始计算候选特征集中每个特征的类别区分度,请稍候...")); startTime=CTime::GetCurrentTime(); FeatherWeight(m_lstWordList); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("特征区分度计算结束,耗时")+totalTime.Format("%H:%M:%S")); CMessage::PrintStatusInfo(""); //从特征项列表m_lstWordList中选出最优特征 CMessage::PrintInfo(_T("开始进行特征选择,请稍候...")); startTime=CTime::GetCurrentTime(); FeatherSelection(m_lstTrainWordList); //为最优特征集m_lstTrainWordList中的每个特征建立一个ID m_lstTrainWordList.IndexWord(); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("特征选择结束,耗时")+totalTime.Format("%H:%M:%S")); CMessage::PrintStatusInfo(""); // 清空m_lstWordList,释放它占用的空间 m_lstWordList.InitWordList(); 
CMessage::PrintInfo("开始生成文档向量,请稍候..."); startTime=CTime::GetCurrentTime(); GenModel(); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("文档向量生成结束,耗时")+totalTime.Format("%H:%M:%S")); CMessage::PrintStatusInfo(""); CMessage::PrintInfo("开始保存分类模型,请稍候..."); startTime=CTime::GetCurrentTime(); WriteModel(m_paramClassifier.m_txtResultDir+"\\model.prj",nType); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("保存分类模型结束,耗时")+totalTime.Format("%H:%M:%S")); //训练SVM分类器必须在保存训练文档的文档向量后进行 if(nType == 1) { CMessage::PrintInfo("开始训练SVM,请稍候..."); m_lstTrainCatalogList.InitCatalogList(2); //删除文档向量所占用的空间 startTime=CTime::GetCurrentTime(); TrainSVM(); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("SVM分类器训练结束,耗时")+totalTime.Format("%H:%M:%S")); CMessage::PrintStatusInfo(""); } //为分类做好准备,否则不能进行分类 Prepare(); CMessage::PrintStatusInfo(""); return TRUE; } void CClassifier::TrainSVM() { CString str; CTime tmStart; CTimeSpan tmSpan; m_paramClassifier.m_strModelFile="model"; for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++) { tmStart=CTime::GetCurrentTime(); str.Format("正在训练第%d个SVM分类器,请稍侯...",i); CMessage::PrintInfo(str); m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt"; m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i); m_theSVM.svm_learn_main(i); tmSpan=CTime::GetCurrentTime()-tmStart; str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S")); CMessage::PrintInfo(str); } } void CClassifier::TrainBAYES() { /* CString str; CTime tmStart; CTimeSpan tmSpan; m_paramClassifier.m_strModelFile="model"; for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++) { tmStart=CTime::GetCurrentTime(); str.Format("正在训练第%d个SVM分类器,请稍侯...",i); CMessage::PrintInfo(str); m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt"; 
m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i); m_theSVM.svm_learn_main(i); tmSpan=CTime::GetCurrentTime()-tmStart; str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S")); CMessage::PrintInfo(str); } */ } // fill an array of CTrain::sSortType (train word length) // nCatalog mean the value of element of the array is the weight // of nCatalog(as an index of catalog) for each individual word // if nCatalog==-1 then sum weight for all catalog void CClassifier::GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog) { int nTotalCata=m_lstTrainCatalogList.GetCataNum(); int i; ASSERT(nCatalog n_Type==2) for(i=0;i m_pCataWeightPro[i]=wordnode.m_pCataWeightPro[i]; // CString strtemp; // strtemp.Format("123 %f",psSortBuf[lWordCount].pclsWordNode->m_pCataWeightPro[i]); // CMessage::PrintInfo(strtemp); } lWordCount++; } } //从m_lstWordList选出最优特征子集,存到dstWordList中 void CClassifier::FeatherSelection(CWordList& dstWordList) { if(m_lstWordList.GetCount()<=0) return; dstWordList.InitWordList(); m_lstWordList.IndexWord(); sSortType *psSortBuf; int nDistinctWordNum = m_lstWordList.GetCount(); psSortBuf = new sSortType[nDistinctWordNum ]; // the distinct number of the word ASSERT(psSortBuf!=NULL); long lDocNum=m_lstTrainCatalogList.GetDocNum(); for(int i=0;i ComputeWeight(lDocNum); else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF) psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true); wordNode.m_dWeight=psSortBuf[j].pclsWordNode->m_dWeight; wordNode.m_lDocFreq=psSortBuf[j].pclsWordNode->m_lDocFreq; wordNode.m_lWordFreq=psSortBuf[j].pclsWordNode->m_lWordFreq; dstWordList.SetAt(psSortBuf[j].word,wordNode); nSelectWordNum++; } } } // total selecting else //if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_GolbalMode) { int iWord=0; GenSortBuf(m_lstWordList,psSortBuf,-1);//-1 mean sum all catalog Sort(psSortBuf,nDistinctWordNum-1); int 
nSelectWordNum=m_paramClassifier.m_nWordSize; if (nSelectWordNum>nDistinctWordNum) nSelectWordNum=nDistinctWordNum; for(i=0;i ComputeWeight(lDocNum); else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF) psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true); wordNode.m_dWeight=psSortBuf[i].pclsWordNode->m_dWeight; wordNode.m_lDocFreq=psSortBuf[i].pclsWordNode->m_lDocFreq; wordNode.m_lWordFreq=psSortBuf[i].pclsWordNode->m_lWordFreq; if(this->n_Type==2) // 拷贝词在不同类中的概率 for(int k=0;k m_pCataWeightPro[k]; wordNode.m_pCataWeightPro[k] = psSortBuf[i].pclsWordNode->m_pCataWeightPro[k]; // CString strtemp; // strtemp.Format("123 %f",wordNode.m_pCataWeightPro[k]); // CMessage::PrintInfo(strtemp); } dstWordList.SetAt(psSortBuf[i].word,wordNode); } } delete [] psSortBuf; } void CClassifier::FeatherWeight(CWordList& wordList) { // ------------------------------------------------------------------------------ // based on document number model int N; //总的文档数; int N_c; //C类的文档数 int N_ft; //含有ft的文档数 int N_c_ft; //C类中含有ft的文档数 // ------------------------------------------------------------------------------ // based on word number model long N_W; //总的词数 m_lWordNum; long N_W_C; //C类词数 CCatalogNode.m_lTotalWordNum; long N_W_f_t; //f_t出现的总次数 long N_W_C_f_t;//C类中f_t出现的次数 // ------------------------------------------------------------------------------ double P_c_ft,P_c_n_ft,P_n_c_ft,P_n_c_n_ft; double P_c,P_n_c; double P_ft,P_n_ft; // ------------------------------------------------------------------------------ POSITION pos_cata,pos_word; CString strWord; // calculate the weight of each word to all catalog N = m_lstTrainCatalogList.GetDocNum(); N_W = wordList.GetWordNum(); int nTotalCata=m_lstTrainCatalogList.GetCataNum(); pos_word = wordList.GetFirstPosition(); while(pos_word!= NULL) // for each word { CWordNode& wordnode = wordList.GetNext(pos_word,strWord); wordnode.m_dWeight=0; ASSERT(wordnode.m_pCataWeight); ASSERT(wordnode.m_pCataWeightPro); 
CMessage::PrintStatusInfo("特征:"+strWord+"..."); N_ft = wordnode.GetDocNum(); N_W_f_t = wordnode.GetWordNum(); int nCataCount=0; pos_cata = m_lstTrainCatalogList.GetFirstPosition(); while(pos_cata!=NULL) // for each catalog { CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata); N_c = catanode.GetDocNum(); N_W_C = catanode.m_lTotalWordNum; N_c_ft = wordnode.GetCataDocNum(catanode.m_idxCata); N_W_C_f_t =wordnode.GetCataWordNum(catanode.m_idxCata); // calculation model if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpWordMode) { P_c = 1.0 * N_W_C /N_W; P_ft = 1.0 * N_W_f_t/N_W; P_c_ft = 1.0 * N_W_C_f_t/N_W; } else //if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpDocMode) { P_c = 1.0 * N_c /N; //C类出现的概率 P_ft = 1.0 * N_ft/N; //含有ft的文档出现的概率 P_c_ft = 1.0 * N_c_ft/N; //C类中含有ft的文档的概率 } P_n_c = 1 - P_c; P_n_ft = 1 - P_ft; P_n_c_ft = P_ft - P_c_ft; P_c_n_ft = P_c - P_c_ft; P_n_c_n_ft = P_n_ft - P_c_n_ft; wordnode.m_pCataWeight[nCataCount]=0; // feature selection model if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_XXMode) { // Right half of IG if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) ); } } else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_MIMode) { // Mutual Informaiton feature selection if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+= P_c * log( P_c_ft/(P_c * P_ft) ); } } else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_CEMode) { // Cross Entropy for text feature selection if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) ); } } else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_X2Mode) { // X^2 Statistics feature selection if ( (fabs(P_n_c * P_ft * P_n_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+= (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) * (P_c_ft * P_n_c_n_ft - P_n_c_ft * 
P_c_n_ft) / ( P_ft * P_n_c * P_n_ft); } } else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_WEMode) { // Weight of Evielence for text feature selection double odds_c_ft; double odds_c; double P_c_inv_ft=P_c_ft/P_ft; if( fabs(P_c_inv_ft) < dZero ) odds_c_ft = 1.0 / ( N * N -1); else if ( fabs(P_c_inv_ft-1) < dZero ) odds_c_ft = N * N -1; else odds_c_ft = P_c_inv_ft / (1.0 - P_c_inv_ft); if( fabs(P_c) < dZero ) odds_c = 1.0 / ( N * N -1); else if ( fabs(P_c-1) < dZero ) odds_c = N * N -1; else odds_c = P_c / (1.0 - P_c); if( fabs(odds_c) > dZero && fabs(odds_c_ft) > dZero ) { wordnode.m_pCataWeight[nCataCount]+= P_c * P_ft * fabs( log(odds_c_ft / odds_c) ); } } else //if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_IGMode) { // Information gain feature selection if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) ); } if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) ) { wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) ); } } wordnode.m_dWeight+=wordnode.m_pCataWeight[nCataCount]; wordnode.m_pCataWeightPro[nCataCount] = 1.0 * (N_c_ft+1)/(2+N);//词str属于类别nCataCount的概率 /* CString strtemp; strtemp.Format("第%d个类中,词的权重是%lf",nCataCount,wordnode.m_pCataWeight[nCataCount]); CMessage::PrintInfo(strtemp); */ nCataCount++; } ASSERT(nCataCount==nTotalCata); } CMessage::PrintStatusInfo(""); } //计算每一篇训练文档向量的每一维的权重 void CClassifier::ComputeWeight(bool bMult) { long lWordNum=m_lstTrainWordList.GetCount(); if(m_lstTrainWordList.GetCount()<=0) return; long lDocNum=m_lstTrainCatalogList.GetDocNum(); if(lDocNum<=0) return; m_lstTrainWordList.ComputeWeight(lDocNum,bMult); double sum=0.0; int i=0; POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); while(pos_cata != NULL) // for each catalog { //取类列表中的每一个类 CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata); POSITION pos_doc = cataNode.GetFirstPosition(); while(pos_doc!=NULL) { 
CDocNode& docNode=cataNode.GetNext(pos_doc); sum=0.0; for(i=0;i Mid) Lo++; while(psData[Hi].dWeight < Mid) Hi--; if(Lo <= Hi) { t = psData[Lo]; psData[Lo]=psData[Hi]; psData[Hi]=t; Lo++; Hi--; } }while(Hi>Lo); if(Hi > iLo) QuickSort(psData, iLo, Hi); if(Lo < iHi) QuickSort(psData, Lo, iHi); } void CClassifier::Sort(sSortType *psData,int nSize) { QuickSort(psData,0,nSize); } // Give m_lstWordList & m_lstTrainCatalogList // Output the present vector of each document; // bFlag=false 层次分类的时候使用 void CClassifier::GenModel() { CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount()); POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); while(pos_cata != NULL) // for each catalog { //取类列表中的每一个类 CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata); POSITION pos_doc = cataNode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docNode=cataNode.GetNext(pos_doc); if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) docNode.ScanChineseWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList); else docNode.ScanEnglishWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList,m_paramClassifier.m_bStem); docNode.GenDocVector(); CMessage::PrintStatusInfo("生成文档"+docNode.m_strDocName+"的文档向量"); } } CDocNode::DeallocTempBuffer(); } // generate original dictionary (the largest one) // form train files bool CClassifier::GenDic() { m_lstWordList.InitWordList(); CTime startTime; CTimeSpan totalTime; startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("分词程序初始化,请稍候...")); if(!g_wordSeg.InitWorgSegment(theApp.m_strPath.GetBuffer(0),m_paramClassifier.m_nLanguageType)) { CMessage::PrintError(_T("分词程序初始化失败!")); return false; } if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) g_wordSeg.SetSegSetting(CWordSegment::uPlace); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("分词程序初始化结束,耗时")+totalTime.Format("%H:%M:%S")); startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("开始扫描训练文档,请稍候...")); 
if(m_lstTrainCatalogList.BuildLib(m_paramClassifier.m_txtTrainDir)<=0) { CMessage::PrintError("训练文档的总数为0!"); return false; } CString strFileName; POSITION pos_cata=m_lstTrainCatalogList.GetFirstPosition(); int nCount,nCataNum; nCataNum=m_lstTrainCatalogList.GetCataNum(); while(pos_cata!=NULL) { CCatalogNode& catalognode=m_lstTrainCatalogList.GetNext(pos_cata); POSITION pos_doc=catalognode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docnode=catalognode.GetNext(pos_doc); CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName); if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) nCount=docnode.ScanChinese(catalognode.m_strDirName.GetBuffer(0), m_lstWordList,nCataNum,catalognode.m_idxCata); else nCount=docnode.ScanEnglish(catalognode.m_strDirName.GetBuffer(0), m_lstWordList,nCataNum,catalognode.m_idxCata, m_paramClassifier.m_bStem); if(nCount==0) { CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无内容!"); continue; } else if(nCount<0) { CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无法打开!"); continue; } catalognode.m_lTotalWordNum+=nCount;// information collection point } } g_wordSeg.FreeWordSegment(); totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("扫描训练文档结束,耗时")+totalTime.Format("%H:%M:%S")); return true; } void CClassifier::InitTrain() { m_lstTrainWordList.InitWordList(); m_lstTrainCatalogList.InitCatalogList(); m_lstWordList.InitWordList(); } //参数nType用来决定分类模型的类别,nType=0代表KNN分类器,nType=1代表SVM分类器 bool CClassifier::WriteModel(CString strFileName, int nType) { CFile fOut; if( !fOut.Open(strFileName,CFile::modeCreate | CFile::modeWrite) ) { CMessage::PrintError("无法创建文件"+strFileName+"!"); return false; } CArchive ar(&fOut,CArchive::store); if(nType==0) { m_lstTrainWordList.DumpToFile(m_paramClassifier.m_txtResultDir+"\\features.dat"); m_lstTrainWordList.DumpWordList(m_paramClassifier.m_txtResultDir+"\\features.txt"); 
m_lstTrainCatalogList.DumpToFile(m_paramClassifier.m_txtResultDir+"\\train.dat"); m_lstTrainCatalogList.DumpDocList(m_paramClassifier.m_txtResultDir+"\\train.txt"); m_paramClassifier.DumpToFile(m_paramClassifier.m_txtResultDir+"\\params.dat"); ar< >dwFileID; if(dwFileID!=dwModelFileID) { ar.Close(); fIn.Close(); CMessage::PrintError("分类模型文件的格式不正确!"); return false; } ar>>str; if(!m_paramClassifier.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开训练参数文件"+str+"!")); return false; } m_paramClassifier.m_txtResultDir=strPath; if(m_paramClassifier.m_nClassifierType==0) { ar>>str; m_lstTrainWordList.InitWordList(); if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); return false; } ar>>str; m_lstTrainCatalogList.InitCatalogList(); if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); return false; } } else if(m_paramClassifier.m_nClassifierType==1) { ar>>str; m_lstTrainWordList.InitWordList(); if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); return false; } //对于SVM分类起来说m_lstTrainCatalogList其实没用 //此处读入它只是为了在CLeftViw中显示某些统计信息时使用 ar>>str; m_lstTrainCatalogList.InitCatalogList(); if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); return false; } ar>>str; if(!m_theSVM.com_param.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开SVM训练参数文件"+str+"!")); return false; } m_theSVM.com_param.trainfile=strPath+"\\train.txt"; m_theSVM.com_param.resultpath=strPath; } else if(m_paramClassifier.m_nClassifierType==2) { ar>>str; m_lstTrainWordList.InitWordList(); if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); return false; } ar>>str; m_lstTrainCatalogList.InitCatalogList(); if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); return false; } 
ar>>str; if(!m_lstTrainWordList.GetProFromFile(strPath+"\\"+str)) { CMessage::PrintError(_T("无法打开特征词类属概率文件"+str+"!")); return false; } } ar.Close(); fIn.Close(); Prepare(); CTimeSpan totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("分类模型文件已经打开,耗时")+totalTime.Format("%H:%M:%S")+"\r\n"); str.Empty(); m_paramClassifier.GetParamString(str); CMessage::PrintInfo(str); return true; } bool CClassifier::Classify() { m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt"); CTime startTime; CTimeSpan totalTime; startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("正在扫描测试文档,请稍候...")); if(m_paramClassifier.m_bEvaluation) { //vBuildLib方法中已经清空了g_lstTestCatalogList,所以此处无需再对其初始化 m_lstTestCatalogList.BuildLib(m_paramClassifier.m_strTestDir); if(!m_lstTestCatalogList.BuildCatalogID(m_lstTrainCatalogList)) { CMessage::PrintError("测试文件中包含有无法识别的类别!"); return false; } } else { m_lstTestCatalogList.InitCatalogList(); CCatalogNode catalognode; catalognode.m_strDirName=m_paramClassifier.m_strTestDir; catalognode.m_strCatalogName="测试文档"; catalognode.m_idxCata=-1; POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode); CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp); cataTemp.SetStartDocID(0); cataTemp.ScanDirectory(m_paramClassifier.m_strTestDir); } if(m_lstTestCatalogList.GetDocNum()<=0) { CMessage::PrintError("测试文件总数为0!\r\n如果不需要对分类结果进行评价时,分类文档必须在\"分类文档目录\"下,而不是它的子目录下!"); return false; } totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("扫描测试文档结束,耗时")+totalTime.Format("%H:%M:%S")); startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候...")); long lCorrect=0,lUnknown=0; lUnknown=Classify(m_lstTestCatalogList); lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\\results.txt"); long lTotalNum=m_lstTestCatalogList.GetDocNum()-lUnknown; CString str; totalTime=CTime::GetCurrentTime()-startTime; CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S")); if 
(lUnknown>0) { str.Format("无法分类的文档数%d:",lUnknown); CMessage::PrintInfo(str); } if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0) str.Format("测试文档总数:%d,准确率:%f",m_lstTestCatalogList.GetDocNum(),(float)(lCorrect)/(float)(lTotalNum)); else str.Format("测试文档总数:%d",m_lstTestCatalogList.GetDocNum()); CMessage::PrintInfo(str); return true; } //对Smart格式的文档进行分类 bool CClassifier::ClassifySmart() { m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt"); m_lstTestCatalogList.InitCatalogList(); CCatalogNode catalognode; catalognode.m_strDirName=m_paramClassifier.m_strTestDir; catalognode.m_strCatalogName="测试文档"; catalognode.m_idxCata=-1; POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode); CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp); FILE *stream1,*stream2; if( (stream1 = fopen( m_paramClassifier.m_strTestDir, "r" )) == NULL ) { CMessage::PrintError("无法打开文件"+m_paramClassifier.m_strTestDir+"!"); return false; } //如果是SVM分类器,则需要先将所有测试文档转换为向量,保存到文件test.dat if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) { m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat"; if((stream2=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL) { CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!"); return false; } } CTime startTime; CTimeSpan totalTime; startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候...")); char fname[10],type[1024],line[4096],content[100*1024]; //falg=1 下一行的内容是文档的类别 //flag=2 下一行的内容是文档的标题 //flag=3 下一行的内容是文档的内容 int flag=0,nCount,len,i; long lUnknown=0,lDocNum=0; CStringArray typeArray; CString strFileName,strCopyFile; bool bTitle=false; //是否已经读出标题 double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0; int nWordNum=m_lstTrainWordList.GetCount(); while(!feof(stream1)) { if(fgets(line,4096,stream1)==NULL) continue; if(line[0]=='.') { if(flag==3) { CDocNode doc; posTemp=cataTemp.AddDoc(doc); CDocNode& docnode=cataTemp.GetAt(posTemp); 
docnode.m_strDocName=fname; if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) { nCount=KNNCategory(content,docnode,false); } else { if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) nCount=docnode.ScanChineseStringWithDict(content,m_lstTrainWordList)-1; else nCount=docnode.ScanEnglishStringWithDict(content,m_lstTrainWordList, m_paramClassifier.m_bStem)-1; fprintf(stream2,"%d",1); for(i=0;i 0) { str.Format("无法分类的文档数%d:",lUnknown); CMessage::PrintInfo(str); } if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0) str.Format("测试文档总数:%d,准确率:%f",m_lstTestCatalogList.GetDocNum(),(float)(lCorrect)/(float)(lTotalNum)); else str.Format("测试文档总数:%d",m_lstTestCatalogList.GetDocNum()); CMessage::PrintInfo(str); return true; } //对文档进行分类,计算文档和每个类别的相似度,返回值为类别无法识别的文档总数 long CClassifier::SVMClassify(CCatalogList &cataList) { long lUnknown=0; FILE *stream; m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat"; if((stream=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL) { CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!"); return 0; } CTime startTime; CTimeSpan totalTime; CString str; int nCount=0; long lWordNum=m_lstTrainWordList.GetCount(); startTime=CTime::GetCurrentTime(); CMessage::PrintInfo(_T("正在生成测试文档的向量形式,请稍候...")); POSITION pos_cata=cataList.GetFirstPosition(); while(pos_cata!=NULL) { CCatalogNode& catalognode=cataList.GetNext(pos_cata); char *path=catalognode.m_strDirName.GetBuffer(0); POSITION pos_doc=catalognode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docnode=catalognode.GetNext(pos_doc); if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) nCount=docnode.ScanChineseWithDict(path,m_lstTrainWordList); else nCount=docnode.ScanEnglishWithDict(path,m_lstTrainWordList, m_paramClassifier.m_bStem); fprintf(stream,"%d",catalognode.m_idxCata+1); for(int i=0;i 0)&&(nCount>0)) { ComputeSimRatio(docNode,nCmpType); return true; } else return false; } 
//得到与文档docNode的相似度大于阈值dThreshold的所有类别 //如果没有大于值阈值的类别,则返回相似度最大的类别 bool CClassifier::MultiCategory(CDocNode &docNode, CArray & aryResult, double dThreshold) { double *pSimRatio=docNode.m_pResults; if(pSimRatio==NULL) return false; double dMax=pSimRatio[0]; short nMax=0; bool bFound=false; aryResult.RemoveAll(); for(short i=1;i dMax) { dMax=pSimRatio[i]; nMax=i; } if(pSimRatio[i]>dThreshold) { aryResult.Add(i); bFound=true; } } if(!bFound) aryResult.Add(nMax); return true; } //计算文档docNode和每一个类别的相似度 //nCmpType代表相似度的不同计算方式 void CClassifier::ComputeSimRatio(CDocNode &docNode,int nCmpType) { //计算文档与训练集中每一篇文档的相似度 int i; long k; for(i=0;i ComputeSimilarityRatio(); m_pSimilarityRatio[i].lDocID=i; } //将测试文档与训练文档集中文档的相似度进行降序排序 Sort(m_pSimilarityRatio,m_lDocNum-1); docNode.AllocResultsBuffer(m_nClassNum); double *pSimRatio=docNode.m_pResults; for(i=0;i 0)&&(nCount>0)) { DOC doc; CString str; docNode.GenDocVector(doc); docNode.AllocResultsBuffer(m_nClassNum); for(int i=0;i m_lDocNum) m_paramClassifier.m_nKNN=m_lDocNum; m_pSimilarityRatio=new DocWeight[m_lDocNum]; m_pProbability=new DocWeight[m_nClassNum]; m_pDocs=(DocCatalog*)malloc(sizeof(DocCatalog)*m_lDocNum); long num=0; POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); while(pos_cata != NULL) // for each catalog { CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata); short idxCata=catanode.m_idxCata; POSITION pos_doc = catanode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docnode=catanode.GetNext(pos_doc); m_pDocs[num].pDocNode=&docnode; m_pDocs[num].nCataID=idxCata; num++; } } CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount()); } void CClassifier::Sort(DocWeight *pData,int nSize) { QuickSort(pData,0,nSize); } void CClassifier::QuickSort(DocWeight *psData, int iLo,int iHi) { int Lo, Hi; double Mid; DocWeight t; Lo = iLo; Hi = iHi; Mid = psData[(Lo + Hi)/2].dWeight; do { while(psData[Lo].dWeight > Mid) Lo++; while(psData[Hi].dWeight < Mid) Hi--; if(Lo <= Hi) { t = psData[Lo]; 
psData[Lo]=psData[Hi]; psData[Hi]=t; Lo++; Hi--; } }while(Hi>Lo); if(Hi > iLo) QuickSort(psData, iLo, Hi); if(Lo < iHi) QuickSort(psData, Lo, iHi); } //将分类结果保存到文件strFileName中,返回正确分类的文档总数 //如果分类参数中要求拷贝文件到结果类别目录,则执行拷贝操作 //参数typeArray只有在多类分类,且需要进行评价的时候才会用到 long CClassifier::SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType) { FILE *stream; if( (stream = fopen(strFileName, "w+" )) == NULL ) { CMessage::PrintError("无法创建分类结果文件"+strFileName+"!"); return -1; } CString str1,str2; long lCorrect=0; long docID=0; int i; char path[MAX_PATH]; CArray aryResult; CArray aryAnswer; double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0; POSITION pos_cata=cataList.GetFirstPosition(); while(pos_cata!=NULL) { CCatalogNode& cataNode=cataList.GetNext(pos_cata); short id=cataNode.m_idxCata; strcpy(path,cataNode.m_strDirName.GetBuffer(0)); POSITION pos_doc=cataNode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docNode=cataNode.GetNext(pos_doc); if(docNode.m_nCataID<0) continue; str1.Empty(); str2.Empty(); //如果是多类分类 if(m_paramClassifier.m_nClassifyType==CClassifierParam::nFT_Multi) { MultiCategory(docNode,aryResult,dThreshold); //如果需要将分类结果拷贝到分类结果目录 if(m_paramClassifier.m_bCopyFiles) { for(i=0;i GetAt(docID).GetBuffer(0),aryAnswer); //得到答案字符串 for(i=0;i dMaxNum) { dMaxNum=pSimRatio[i]; nCataID=i; } } docNode.m_nCataID=nCataID; return nCataID; } short CClassifier::SVMCategory(char *pPath, CDocNode &docNode, bool bFile) { short nCataID=-1; if(SVMClassify(pPath,docNode,bFile)) nCataID=SingleCategory(docNode); return nCataID; } short CClassifier::SVMCategory(char *file, bool bFile) { CDocNode docNode; short id=-1; if(bFile) { char *fname=strrchr(file,'\\'); if(fname==NULL) return -1; docNode.m_strDocName=(fname+1); char path[MAX_PATH]; strncpy(path,file,fname-file); path[fname-file]=0; id=SVMCategory(path,docNode,bFile); } else id=SVMCategory(file,docNode,bFile); return id; } void CClassifier::SVMClassifyVectorFile(CString strFileName) { 
//为了计算分类结果,用来保存每个分类器分类结果的数组 CTime startTime; CTimeSpan totalTime; CString str; long num=m_lstTestCatalogList.GetDocNum(),lDocNum=0; double *fpWeight=new double[num]; POSITION pos_doc, pos_cata; m_theSVM.com_param.classifyfile=strFileName; for(int i=1;i<=m_nClassNum;i++) { memset(fpWeight,0,sizeof(double)*num); startTime=CTime::GetCurrentTime(); str.Format("正在使用第%d个SVM分类器对文档进行分类,请稍候...",i); CMessage::PrintInfo(str); str.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i); m_theSVM.com_param.modelfile=str; m_theSVM.svm_classify(i,fpWeight); //将文档和当前类别的相似度赋给m_pResults[i-1] lDocNum=0; pos_cata=m_lstTestCatalogList.GetFirstPosition(); while(pos_cata!=NULL) { CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata); pos_doc=catalognode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docnode=catalognode.GetNext(pos_doc); docnode.AllocResultsBuffer(m_nClassNum); docnode.m_pResults[i-1]=fpWeight[lDocNum]; lDocNum++; } } totalTime=CTime::GetCurrentTime()-startTime; str.Format("第%d个SVM分类器分类结束,耗时",i); CMessage::PrintInfo(str+totalTime.Format("%H:%M:%S")); } delete[] fpWeight; //计算和文档的相似度最大的类别 pos_cata=m_lstTestCatalogList.GetFirstPosition(); while(pos_cata!=NULL) { CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata); pos_doc=catalognode.GetFirstPosition(); while(pos_doc!=NULL) { CDocNode& docnode=catalognode.GetNext(pos_doc); docnode.m_nCataID=SingleCategory(docnode); } } } //计算文档docNode属于每一个类别的概率 void CClassifier::ComputePro(CDocNode &docNode) { //计算文档与训练集中每一类文档的概率 int i,j,l,FeaNum=0; long k; int N; //总的文档数; int N_c; //C类的文档数 int N_Cata; //总类数 N = m_lstTrainCatalogList.GetDocNum(); POSITION pos_cata; CString strWord; // calculate the weight of each word to all catalog N = m_lstTrainCatalogList.GetDocNum(); N_Cata=m_lstTrainCatalogList.GetCataNum(); docNode.AllocResultsBuffer(m_nClassNum); int nCataCount=0; double ClassPro=0.0; pos_cata = m_lstTrainCatalogList.GetFirstPosition(); i=m_lstTrainWordList.GetCount(); 
POSITION pos=m_lstTrainWordList.GetFirstPosition(); for(l=0;l 0)&&(nCount>0)) { ComputePro(docNode); return true; } else return false; } //计算文档和每个类别的相似度,返回与文档相似度最大的类别ID short CClassifier::BAYESCategory(char *pPath, CDocNode &docNode, bool bFile) { BAYESClassify(pPath,docNode,bFile); return docNode.m_nCataID; } long CClassifier::BAYESClassify(CCatalogList &cataList) { long docID=0,lUnknown=0; CString str; POSITION pos_cata=cataList.GetFirstPosition(); while(pos_cata!=NULL) { CCatalogNode& cataNode=cataList.GetNext(pos_cata); POSITION pos_doc=cataNode.GetFirstPosition(); char *path=cataNode.m_strDirName.GetBuffer(0); while(pos_doc!=NULL) { CDocNode& docNode=cataNode.GetNext(pos_doc); short id=BAYESCategory(path, docNode, true); if(id==-1) { str="无法识别文档"; str+=cataNode.m_strDirName; str+="\\"+docNode.m_strDocName+"的类别!"; CMessage::PrintError(str); lUnknown++; } CMessage::PrintStatusInfo(_T("扫描文档")+docNode.m_strDocName); } } return lUnknown; } short CClassifier::GetCategory(char *file, bool bFile) { short result=-1; if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) result=KNNCategory(file,bFile); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) result=SVMCategory(file,bFile); return result; } short CClassifier::GetCategory(char *path, CDocNode &docNode, bool bFile) { short result=-1; if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) result=KNNCategory(path,docNode,bFile); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) result=SVMCategory(path,docNode,bFile); return result; } bool CClassifier::Classify(char *path, CDocNode &docNode, bool bFile) { bool result=false; if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) result=KNNClassify(path,docNode,bFile); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) result=SVMClassify(path,docNode,bFile); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES) 
result=BAYESClassify(path,docNode,bFile); return result; } long CClassifier::Classify(CCatalogList &cataList) { long lUnknown=0; if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) lUnknown=KNNClassify(m_lstTestCatalogList); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) lUnknown=SVMClassify(m_lstTestCatalogList); else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES) lUnknown=BAYESClassify(m_lstTestCatalogList); else CMessage::PrintError("无法确定分类器的类型!"); return lUnknown; }