www.pudn.com > TextClassify.rar > Classifier.cpp


// Classifier.cpp: implementation of the CClassifier class. 
// 
////////////////////////////////////////////////////////////////////// 
 
#include "stdafx.h" 
#include "TextClassify.h" 
#include "Classifier.h" 
#include "WordSegment.h" 
#include "Message.h" 
#include  
#include  
 
#ifdef _DEBUG 
#undef THIS_FILE 
static char THIS_FILE[]=__FILE__; 
#define new DEBUG_NEW 
#endif 
 
CClassifier theClassifier; 
const DWORD CClassifier::dwModelFileID=0xFFEFFFFF; 
////////////////////////////////////////////////////////////////////// 
// Construction/Destruction 
////////////////////////////////////////////////////////////////////// 
 
// Construct a classifier in its untrained state: no classifier type chosen
// (n_Type == -1), no working buffers allocated, and zero documents/classes.
CClassifier::CClassifier()
{
	// Counters describing the loaded training set.
	m_nClassNum = 0;
	m_lDocNum = 0;

	// Working buffers; they stay NULL until they are allocated elsewhere.
	m_pProbability = NULL;
	m_pSimilarityRatio = NULL;
	m_pDocs = NULL;

	// -1 marks "classifier type not selected yet".
	n_Type = -1;
}
 
// Destructor. Intentionally empty in the current code: nothing is released here.
// NOTE(review): later in this file m_pSimilarityRatio and m_pProbability are
// allocated with new[] and m_pDocs with malloc(); confirm they are released
// by some other routine, otherwise they leak when the object is destroyed.
CClassifier::~CClassifier() 
{ 
 
} 
 
// Train a classification model from the configured training documents.
//
// nType - classifier kind. Per the original notes 0 means KNN and 1 means SVM;
//         the value is stored in n_Type, forwarded to WriteModel(), and
//         nType == 1 additionally runs TrainSVM().
// bFlag - true:  reset the training state and rescan the training documents
//                to rebuild the candidate feature list (InitTrain + GenDic).
//         false: reuse the feature/catalog lists that are already loaded
//                (the original note says this is used for hierarchical
//                classification).
//
// Returns true when the model was built, saved and prepared for
// classification. Returns false when the document scan fails, when there are
// no candidate features or no categories, or when the model cannot be saved.
bool CClassifier::Train(int nType, bool bFlag)
{
	this->n_Type=nType;
	CTime startTime;
	CTimeSpan totalTime;

	bool bDicReady=true;
	if(bFlag)
	{
		InitTrain();
		// Collect every candidate feature into m_lstWordList.
		bDicReady=GenDic();
	}
	CMessage::PrintStatusInfo("");

	// GenDic() has already reported the reason for its failure.
	if(!bDicReady)
		return false;
	if(m_lstWordList.GetCount()==0)
		return false;
	if(m_lstTrainCatalogList.GetCataNum()==0)
		return false;

	// Start from an empty selected-feature list.
	m_lstTrainWordList.InitWordList();

	// Score every candidate feature in m_lstWordList.
	CMessage::PrintInfo(_T("开始计算候选特征集中每个特征的类别区分度,请稍候..."));
	startTime=CTime::GetCurrentTime();
	FeatherWeight(m_lstWordList);
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("特征区分度计算结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	// Pick the best features out of m_lstWordList and give each one an ID.
	CMessage::PrintInfo(_T("开始进行特征选择,请稍候..."));
	startTime=CTime::GetCurrentTime();
	FeatherSelection(m_lstTrainWordList);
	m_lstTrainWordList.IndexWord();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("特征选择结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	// The candidate list is no longer needed; release the space it occupies.
	m_lstWordList.InitWordList();

	CMessage::PrintInfo("开始生成文档向量,请稍候...");
	startTime=CTime::GetCurrentTime();
	GenModel();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("文档向量生成结束,耗时")+totalTime.Format("%H:%M:%S"));
	CMessage::PrintStatusInfo("");

	CMessage::PrintInfo("开始保存分类模型,请稍候...");
	startTime=CTime::GetCurrentTime();
	const bool bSaved=WriteModel(m_paramClassifier.m_txtResultDir+"\\model.prj",nType);
	totalTime=CTime::GetCurrentTime()-startTime;
	if(!bSaved)
	{
		// WriteModel() has already reported the error. Do not go on to train
		// an SVM from files that were never written, and do not claim success.
		CMessage::PrintStatusInfo("");
		return false;
	}
	CMessage::PrintInfo(_T("保存分类模型结束,耗时")+totalTime.Format("%H:%M:%S"));

	// SVM training must run after the training document vectors were saved.
	if(nType == 1)
	{
		CMessage::PrintInfo("开始训练SVM,请稍候...");
		// Original note: frees the space occupied by the document vectors.
		m_lstTrainCatalogList.InitCatalogList(2);
		startTime=CTime::GetCurrentTime();
		TrainSVM();
		totalTime=CTime::GetCurrentTime()-startTime;
		CMessage::PrintInfo(_T("SVM分类器训练结束,耗时")+totalTime.Format("%H:%M:%S"));
		CMessage::PrintStatusInfo("");
	}

	// Classification is not possible until Prepare() has run.
	Prepare();
	CMessage::PrintStatusInfo("");
	return true;
}
 
void CClassifier::TrainSVM() 
{ 
	CString str; 
	CTime tmStart; 
	CTimeSpan tmSpan; 
 
	m_paramClassifier.m_strModelFile="model"; 
	for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++) 
	{ 
		tmStart=CTime::GetCurrentTime(); 
		str.Format("正在训练第%d个SVM分类器,请稍侯...",i); 
		CMessage::PrintInfo(str); 
		m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt"; 
		m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i); 
		m_theSVM.svm_learn_main(i); 
		tmSpan=CTime::GetCurrentTime()-tmStart; 
		str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S")); 
		CMessage::PrintInfo(str); 
	}	 
} 
 
// Placeholder for a Bayes training step. The function currently does nothing:
// its whole body is a commented-out copy of the per-category SVM training
// loop from TrainSVM().
void CClassifier::TrainBAYES() 
{ 
	/* 
	CString str; 
	CTime tmStart; 
	CTimeSpan tmSpan; 
 
	m_paramClassifier.m_strModelFile="model"; 
	for(int i=1;i<=m_lstTrainCatalogList.GetCataNum();i++) 
	{ 
		tmStart=CTime::GetCurrentTime(); 
		str.Format("正在训练第%d个SVM分类器,请稍侯...",i); 
		CMessage::PrintInfo(str); 
		m_theSVM.com_param.trainfile=m_paramClassifier.m_txtResultDir+"\\train.txt"; 
		m_theSVM.com_param.modelfile.Format("%s\\%s%d.mdl",m_paramClassifier.m_txtResultDir,m_paramClassifier.m_strModelFile,i); 
		m_theSVM.svm_learn_main(i); 
		tmSpan=CTime::GetCurrentTime()-tmStart; 
		str.Format("第%d个SVM分类器训练完成,耗时%s!",i,tmSpan.Format("%H:%M:%S")); 
		CMessage::PrintInfo(str); 
	} 
	*/ 
} 
 
// fill an array of CTrain::sSortType (train word length) 
// nCatalog mean the value of element of the array is the weight 
// of nCatalog(as an index of catalog) for each individual word 
// if nCatalog==-1 then sum weight for all catalog 
void CClassifier::GenSortBuf(CWordList& wordList,sSortType *psSortBuf,int nCatalog) 
{ 
	int nTotalCata=m_lstTrainCatalogList.GetCataNum(); 
	int i; 
	ASSERT(nCatalogn_Type==2) 
			for(i=0;im_pCataWeightPro[i]=wordnode.m_pCataWeightPro[i]; 
	//			CString strtemp; 
	//			strtemp.Format("123 %f",psSortBuf[lWordCount].pclsWordNode->m_pCataWeightPro[i]); 
	//			CMessage::PrintInfo(strtemp); 
			} 
			lWordCount++; 
	}	 
} 
 
 
//从m_lstWordList选出最优特征子集,存到dstWordList中 
void CClassifier::FeatherSelection(CWordList& dstWordList) 
{ 
	if(m_lstWordList.GetCount()<=0) return; 
	dstWordList.InitWordList(); 
	m_lstWordList.IndexWord(); 
 
	sSortType	*psSortBuf; 
	int			nDistinctWordNum = m_lstWordList.GetCount(); 
	psSortBuf = new sSortType[nDistinctWordNum ];  // the distinct number of the word  
	ASSERT(psSortBuf!=NULL); 
	long lDocNum=m_lstTrainCatalogList.GetDocNum(); 
	for(int i=0;iComputeWeight(lDocNum); 
				else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF) 
					psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true); 
				wordNode.m_dWeight=psSortBuf[j].pclsWordNode->m_dWeight; 
				wordNode.m_lDocFreq=psSortBuf[j].pclsWordNode->m_lDocFreq; 
				wordNode.m_lWordFreq=psSortBuf[j].pclsWordNode->m_lWordFreq; 
 
				dstWordList.SetAt(psSortBuf[j].word,wordNode); 
				nSelectWordNum++; 
			} 
		} 
	} 
	// total selecting 
	else //if(m_paramClassifier.m_nSelMode==CClassifierParam::nFSM_GolbalMode) 
	{ 
		int iWord=0; 
		GenSortBuf(m_lstWordList,psSortBuf,-1);//-1 mean sum all catalog 
		Sort(psSortBuf,nDistinctWordNum-1); 
		 
		int nSelectWordNum=m_paramClassifier.m_nWordSize; 
		if (nSelectWordNum>nDistinctWordNum) 
			nSelectWordNum=nDistinctWordNum; 
 
		for(i=0;iComputeWeight(lDocNum); 
			else if(m_paramClassifier.m_nWeightMode==CClassifierParam::nWM_TF_IDF_DIFF) 
				psSortBuf[i].pclsWordNode->ComputeWeight(lDocNum,true); 
			wordNode.m_dWeight=psSortBuf[i].pclsWordNode->m_dWeight; 
			wordNode.m_lDocFreq=psSortBuf[i].pclsWordNode->m_lDocFreq; 
			wordNode.m_lWordFreq=psSortBuf[i].pclsWordNode->m_lWordFreq; 
 
			if(this->n_Type==2) 
	//			拷贝词在不同类中的概率 
				for(int k=0;km_pCataWeightPro[k]; 
					wordNode.m_pCataWeightPro[k] = psSortBuf[i].pclsWordNode->m_pCataWeightPro[k]; 
	//				CString strtemp; 
	//				strtemp.Format("123 %f",wordNode.m_pCataWeightPro[k]); 
	//				CMessage::PrintInfo(strtemp); 
				} 
 
			dstWordList.SetAt(psSortBuf[i].word,wordNode); 
		}	 
	} 
	delete [] psSortBuf; 
} 
 
// Score every candidate feature in wordList against every training category.
//
// For each word this fills, in category-list iteration order (index k):
//   m_pCataWeight[k]    - the feature-selection score of the word for that
//                         category, using the measure selected by
//                         m_paramClassifier.m_nFSMode (information gain is
//                         the fallback);
//   m_dWeight           - the sum of the per-category scores;
//   m_pCataWeightPro[k] - (N_c_ft + 1) / (N + 2).
//
// Probabilities are estimated from word counts when m_nOpMode is nOpWordMode,
// otherwise from document counts. Terms whose log argument would be (near)
// zero, as judged against dZero, are skipped.
void CClassifier::FeatherWeight(CWordList& wordList)
{
// ------------------------------------------------------------------------------
//  based on document number model
	int			N;		// total number of training documents
	int			N_c;	// number of documents in category C
	int			N_ft;	// number of documents containing feature ft
	int			N_c_ft;	// number of documents in category C containing ft
// ------------------------------------------------------------------------------
//  based on word number model
	long		N_W;    // total number of words
	long		N_W_C;  // number of words in category C (CCatalogNode::m_lTotalWordNum)
	long		N_W_f_t; // total number of occurrences of ft
	long		N_W_C_f_t;// number of occurrences of ft inside category C
// ------------------------------------------------------------------------------
	double		P_c_ft,P_c_n_ft,P_n_c_ft,P_n_c_n_ft;
	double		P_c,P_n_c;
	double		P_ft,P_n_ft;

// ------------------------------------------------------------------------------
	POSITION	pos_cata,pos_word;
	CString     strWord;

	// calculate the weight of each word to all catalog
	N = m_lstTrainCatalogList.GetDocNum();
	N_W = wordList.GetWordNum();

	// Clamp value used for degenerate odds in the weight-of-evidence measure.
	// Computed in double: N * N overflows int once N exceeds 46340 documents.
	const double dMaxOdds = static_cast<double>(N) * static_cast<double>(N) - 1.0;

	int nTotalCata=m_lstTrainCatalogList.GetCataNum();
	pos_word = wordList.GetFirstPosition();
	while(pos_word!= NULL)  // for each word
	{
		CWordNode& wordnode = wordList.GetNext(pos_word,strWord);

		wordnode.m_dWeight=0;

		ASSERT(wordnode.m_pCataWeight);
		ASSERT(wordnode.m_pCataWeightPro);

		CMessage::PrintStatusInfo("特征:"+strWord+"...");

		N_ft = wordnode.GetDocNum();
		N_W_f_t = wordnode.GetWordNum();
		int nCataCount=0;
		pos_cata = m_lstTrainCatalogList.GetFirstPosition();

		while(pos_cata!=NULL)  // for each catalog
		{
			CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata);
			N_c  = catanode.GetDocNum();
			N_W_C  = catanode.m_lTotalWordNum;
			N_c_ft = wordnode.GetCataDocNum(catanode.m_idxCata);
			N_W_C_f_t =wordnode.GetCataWordNum(catanode.m_idxCata);
			// calculation model
			if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpWordMode)
			{
				P_c	    = 1.0 * N_W_C /N_W;
				P_ft	= 1.0 * N_W_f_t/N_W;
				P_c_ft  = 1.0 * N_W_C_f_t/N_W;
			}
			else //if(m_paramClassifier.m_nOpMode==CClassifierParam::nOpDocMode)
			{
				P_c			= 1.0 * N_c /N;		// probability of category C
				P_ft		= 1.0 * N_ft/N;		// probability of a document containing ft
				P_c_ft		= 1.0 * N_c_ft/N;	// probability of a category-C document containing ft
			}

			// Remaining cells of the 2x2 (category, feature) contingency table.
			P_n_c		= 1 - P_c;
			P_n_ft		= 1 - P_ft;
			P_n_c_ft	= P_ft - P_c_ft;
			P_c_n_ft	= P_c - P_c_ft;
			P_n_c_n_ft  = P_n_ft - P_c_n_ft;

			wordnode.m_pCataWeight[nCataCount]=0;
			// feature selection model
			if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_XXMode)
			{
				// Right half of IG
				if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_MIMode)
			{
				// Mutual information feature selection
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c * log( P_c_ft/(P_c * P_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_CEMode)
			{
				// Cross entropy for text feature selection
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_X2Mode)
			{
				// X^2 statistics feature selection
				if ( (fabs(P_n_c * P_ft * P_n_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+= (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) * (P_c_ft * P_n_c_n_ft - P_n_c_ft * P_c_n_ft) / ( P_ft * P_n_c * P_n_ft);
				}
			}
			else if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_WEMode)
			{
				// Weight of evidence for text feature selection
				double		odds_c_ft;
				double		odds_c;
				double		P_c_inv_ft=P_c_ft/P_ft;

				// Odds of C given ft, clamped when the probability is ~0 or ~1.
				if( fabs(P_c_inv_ft) < dZero )
					odds_c_ft = 1.0 / dMaxOdds;
				else if ( fabs(P_c_inv_ft-1) < dZero )
					odds_c_ft = dMaxOdds;
				else
					odds_c_ft = P_c_inv_ft / (1.0 - P_c_inv_ft);

				// Prior odds of C, clamped the same way.
				if( fabs(P_c) < dZero )
					odds_c = 1.0 / dMaxOdds;
				else if ( fabs(P_c-1) < dZero )
					odds_c = dMaxOdds;
				else
					odds_c = P_c / (1.0 - P_c);
				if( fabs(odds_c) > dZero && fabs(odds_c_ft) > dZero )
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c * P_ft * fabs( log(odds_c_ft / odds_c) );
				}
			}
			else //if(m_paramClassifier.m_nFSMode==CClassifierParam::nFS_IGMode)
			{
				// Information gain feature selection
				if ( (fabs(P_c * P_n_ft) > dZero) && ( fabs(P_c_n_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+=P_c_n_ft * log( P_c_n_ft/(P_c * P_n_ft) );
				}
				if ( (fabs(P_c * P_ft) > dZero) && (fabs(P_c_ft) > dZero) )
				{
					wordnode.m_pCataWeight[nCataCount]+= P_c_ft * log( P_c_ft/(P_c * P_ft) );
				}
			}
			wordnode.m_dWeight+=wordnode.m_pCataWeight[nCataCount];
			// Original note: probability that the word belongs to category
			// nCataCount.
			// NOTE(review): the denominator uses the total document count N,
			// not the category size N_c - confirm this is the intended estimate.
			wordnode.m_pCataWeightPro[nCataCount] = 1.0 * (N_c_ft+1)/(2+N);
			nCataCount++;
		}
		ASSERT(nCataCount==nTotalCata);
	}
	CMessage::PrintStatusInfo("");
}
 
//计算每一篇训练文档向量的每一维的权重 
void CClassifier::ComputeWeight(bool bMult) 
{ 
	long lWordNum=m_lstTrainWordList.GetCount(); 
	if(m_lstTrainWordList.GetCount()<=0) return; 
	long lDocNum=m_lstTrainCatalogList.GetDocNum(); 
	if(lDocNum<=0) return; 
	m_lstTrainWordList.ComputeWeight(lDocNum,bMult); 
 
	double sum=0.0; 
	int i=0; 
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); 
	while(pos_cata != NULL)  // for each catalog  
	{ 
		//取类列表中的每一个类 
		CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata); 
		POSITION pos_doc  = cataNode.GetFirstPosition(); 
		while(pos_doc!=NULL) 
		{ 
			CDocNode& docNode=cataNode.GetNext(pos_doc); 
 
			sum=0.0; 
			for(i=0;i Mid) Lo++; 
		while(psData[Hi].dWeight < Mid) Hi--; 
		if(Lo <= Hi) 
		{ 
			t = psData[Lo]; 
			psData[Lo]=psData[Hi]; 
			psData[Hi]=t; 
			Lo++; 
			Hi--; 
		} 
	}while(Hi>Lo); 
    if(Hi > iLo) QuickSort(psData, iLo, Hi); 
    if(Lo < iHi) QuickSort(psData, Lo, iHi); 
} 
 
// Sort psData[0..nSize] into descending dWeight order.
//
// psData - array of sort records.
// nSize  - index of the LAST element to sort (element count minus one); this
//          is how the callers in this file use it, and it is passed straight
//          through as QuickSort's upper bound.
//
// Ranges with fewer than two elements (nSize <= 0) or a NULL array are left
// untouched, so an empty input can no longer make QuickSort read outside the
// array.
void CClassifier::Sort(sSortType *psData,int nSize)
{
	if(psData==NULL || nSize<=0)
		return;
	QuickSort(psData,0,nSize);
}
 
 
// Give m_lstWordList & m_lstTrainCatalogList 
// Output the present vector of each document; 
// bFlag=false 层次分类的时候使用 
void CClassifier::GenModel() 
{ 
	CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount()); 
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); 
	while(pos_cata != NULL)  // for each catalog  
	{ 
		//取类列表中的每一个类 
		CCatalogNode& cataNode = m_lstTrainCatalogList.GetNext(pos_cata); 
		POSITION pos_doc  = cataNode.GetFirstPosition(); 
		while(pos_doc!=NULL) 
		{ 
			CDocNode& docNode=cataNode.GetNext(pos_doc); 
			if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) 
				docNode.ScanChineseWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList); 
			else 
				docNode.ScanEnglishWithDict(cataNode.m_strDirName.GetBuffer(0),m_lstTrainWordList,m_paramClassifier.m_bStem); 
			docNode.GenDocVector(); 
			CMessage::PrintStatusInfo("生成文档"+docNode.m_strDocName+"的文档向量"); 
		} 
	} 
	CDocNode::DeallocTempBuffer(); 
} 
 
 
// Generate the original (largest) dictionary from the training files.
//
// Steps:
//   1. Initialise the word segmenter for the configured language.
//   2. Build the training catalog from m_paramClassifier.m_txtTrainDir.
//   3. Scan every training document, adding its words to m_lstWordList and
//      accumulating the word count of each category in m_lTotalWordNum.
//   4. Release the word segmenter.
//
// Returns true on success. Returns false when the segmenter cannot be
// initialised or when no training documents are found. Documents that are
// empty or cannot be opened are reported and skipped.
bool CClassifier::GenDic()
{
	m_lstWordList.InitWordList();
	CTime startTime;
	CTimeSpan totalTime;

	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("分词程序初始化,请稍候..."));
	if(!g_wordSeg.InitWorgSegment(theApp.m_strPath.GetBuffer(0),m_paramClassifier.m_nLanguageType))
	{
		CMessage::PrintError(_T("分词程序初始化失败!"));
		return false;
	}
	if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
		g_wordSeg.SetSegSetting(CWordSegment::uPlace);
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("分词程序初始化结束,耗时")+totalTime.Format("%H:%M:%S"));


	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("开始扫描训练文档,请稍候..."));
	if(m_lstTrainCatalogList.BuildLib(m_paramClassifier.m_txtTrainDir)<=0)
	{
		CMessage::PrintError("训练文档的总数为0!");
		// The segmenter was initialised above; release it on this exit path
		// too, exactly as the success path does.
		g_wordSeg.FreeWordSegment();
		return false;
	}

	POSITION pos_cata=m_lstTrainCatalogList.GetFirstPosition();
	int nCount,nCataNum;
	nCataNum=m_lstTrainCatalogList.GetCataNum();
	while(pos_cata!=NULL)
	{
		CCatalogNode& catalognode=m_lstTrainCatalogList.GetNext(pos_cata);
		POSITION pos_doc=catalognode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docnode=catalognode.GetNext(pos_doc);
			CMessage::PrintStatusInfo(_T("扫描文档")+docnode.m_strDocName);

			if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese)
				nCount=docnode.ScanChinese(catalognode.m_strDirName.GetBuffer(0),
							m_lstWordList,nCataNum,catalognode.m_idxCata);
			else
				nCount=docnode.ScanEnglish(catalognode.m_strDirName.GetBuffer(0),
							m_lstWordList,nCataNum,catalognode.m_idxCata,
							m_paramClassifier.m_bStem);

			if(nCount>0)
			{
				// information collection point
				catalognode.m_lTotalWordNum+=nCount;
			}
			else if(nCount==0)
			{
				// The scan yielded no words.
				CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无内容!");
			}
			else
			{
				// A negative count means the file could not be opened.
				CMessage::PrintError("文件"+catalognode.m_strDirName+"\\"+docnode.m_strDocName+"无法打开!");
			}
		}
	}
	g_wordSeg.FreeWordSegment();
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("扫描训练文档结束,耗时")+totalTime.Format("%H:%M:%S"));
	return true;
}
 
void CClassifier::InitTrain() 
{ 
	m_lstTrainWordList.InitWordList(); 
	m_lstTrainCatalogList.InitCatalogList(); 
	m_lstWordList.InitWordList(); 
} 
 
//参数nType用来决定分类模型的类别,nType=0代表KNN分类器,nType=1代表SVM分类器 
bool CClassifier::WriteModel(CString strFileName, int nType) 
{ 
	CFile fOut; 
	if( !fOut.Open(strFileName,CFile::modeCreate | CFile::modeWrite) ) 
	{ 
		CMessage::PrintError("无法创建文件"+strFileName+"!"); 
		return false; 
	} 
 
	CArchive ar(&fOut,CArchive::store);	 
	if(nType==0) 
	{ 
		m_lstTrainWordList.DumpToFile(m_paramClassifier.m_txtResultDir+"\\features.dat"); 
		m_lstTrainWordList.DumpWordList(m_paramClassifier.m_txtResultDir+"\\features.txt"); 
		m_lstTrainCatalogList.DumpToFile(m_paramClassifier.m_txtResultDir+"\\train.dat"); 
		m_lstTrainCatalogList.DumpDocList(m_paramClassifier.m_txtResultDir+"\\train.txt"); 
		m_paramClassifier.DumpToFile(m_paramClassifier.m_txtResultDir+"\\params.dat"); 
 
		ar<>dwFileID; 
	if(dwFileID!=dwModelFileID) 
	{ 
		ar.Close(); 
		fIn.Close(); 
		CMessage::PrintError("分类模型文件的格式不正确!"); 
		return false; 
	} 
 
	ar>>str; 
 
	if(!m_paramClassifier.GetFromFile(strPath+"\\"+str)) 
	{ 
		CMessage::PrintError(_T("无法打开训练参数文件"+str+"!")); 
		return false; 
	} 
	m_paramClassifier.m_txtResultDir=strPath; 
 
	if(m_paramClassifier.m_nClassifierType==0) 
	{ 
		ar>>str; 
		m_lstTrainWordList.InitWordList(); 
		if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); 
			return false; 
		} 
		ar>>str; 
		m_lstTrainCatalogList.InitCatalogList(); 
		if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); 
			return false; 
		} 
	} 
	else if(m_paramClassifier.m_nClassifierType==1) 
	{ 
		ar>>str; 
		m_lstTrainWordList.InitWordList(); 
		if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); 
			return false; 
		} 
		//对于SVM分类起来说m_lstTrainCatalogList其实没用 
		//此处读入它只是为了在CLeftViw中显示某些统计信息时使用 
		ar>>str; 
		m_lstTrainCatalogList.InitCatalogList(); 
		if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); 
			return false; 
		} 
		ar>>str; 
		if(!m_theSVM.com_param.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开SVM训练参数文件"+str+"!")); 
			return false; 
		} 
		m_theSVM.com_param.trainfile=strPath+"\\train.txt"; 
		m_theSVM.com_param.resultpath=strPath; 
	} 
	else if(m_paramClassifier.m_nClassifierType==2) 
	{ 
		ar>>str; 
		m_lstTrainWordList.InitWordList(); 
		if(!m_lstTrainWordList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开特征类表文件"+str+"!")); 
			return false; 
		} 
		ar>>str; 
		m_lstTrainCatalogList.InitCatalogList(); 
		if(!m_lstTrainCatalogList.GetFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开训练文档列表文件"+str+"!")); 
			return false; 
		} 
		ar>>str; 
		if(!m_lstTrainWordList.GetProFromFile(strPath+"\\"+str)) 
		{ 
			CMessage::PrintError(_T("无法打开特征词类属概率文件"+str+"!")); 
			return false; 
		} 
	} 
	ar.Close(); 
	fIn.Close(); 
	Prepare(); 
	CTimeSpan totalTime=CTime::GetCurrentTime()-startTime; 
	CMessage::PrintInfo(_T("分类模型文件已经打开,耗时")+totalTime.Format("%H:%M:%S")+"\r\n"); 
	 
	str.Empty(); 
	m_paramClassifier.GetParamString(str); 
	CMessage::PrintInfo(str); 
	return true;	 
} 
 
// Classify every document under the configured test directory.
//
// When evaluation is enabled the test directory is scanned as a category tree
// and its categories are mapped onto the training categories; otherwise all
// files directly inside the test directory are placed into one pseudo
// category with index -1. The documents are then classified, the results are
// written to "<result dir>\results.txt", and a summary is printed (including
// the accuracy when evaluation is enabled and the results could be saved).
//
// Returns false when the test set contains an unknown category or no
// documents at all; true otherwise.
bool CClassifier::Classify()
{
	m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt");
	CTime startTime;
	CTimeSpan totalTime;
	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("正在扫描测试文档,请稍候..."));
	if(m_paramClassifier.m_bEvaluation)
	{
		// Original note: BuildLib already clears the test catalog list, so
		// it does not need to be initialised again here.
		m_lstTestCatalogList.BuildLib(m_paramClassifier.m_strTestDir);
		if(!m_lstTestCatalogList.BuildCatalogID(m_lstTrainCatalogList))
		{
			CMessage::PrintError("测试文件中包含有无法识别的类别!");
			return false;
		}
	}
	else
	{
		// No evaluation: collect the files of the test directory into a
		// single pseudo category.
		m_lstTestCatalogList.InitCatalogList();
		CCatalogNode catalognode;
		catalognode.m_strDirName=m_paramClassifier.m_strTestDir;
		catalognode.m_strCatalogName="测试文档";
		catalognode.m_idxCata=-1;
		POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode);
		CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp);
		cataTemp.SetStartDocID(0);
		cataTemp.ScanDirectory(m_paramClassifier.m_strTestDir);
	}
	if(m_lstTestCatalogList.GetDocNum()<=0)
	{
		CMessage::PrintError("测试文件总数为0!\r\n如果不需要对分类结果进行评价时,分类文档必须在\"分类文档目录\"下,而不是它的子目录下!");
		return false;
	}
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("扫描测试文档结束,耗时")+totalTime.Format("%H:%M:%S"));

	startTime=CTime::GetCurrentTime();
	CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候..."));
	const long lUnknown=Classify(m_lstTestCatalogList);

	// SaveResults returns the number of correctly classified documents, or
	// -1 when the results file cannot be created.
	const long lCorrect=SaveResults(m_lstTestCatalogList,m_paramClassifier.m_strResultDir+"\\results.txt");
	const long lTotalNum=m_lstTestCatalogList.GetDocNum()-lUnknown;
	CString str;
	totalTime=CTime::GetCurrentTime()-startTime;
	CMessage::PrintInfo(_T("测试文档分类结束,耗时")+totalTime.Format("%H:%M:%S"));
	if (lUnknown>0)
	{
		// The argument is converted to int so that it matches "%d".
		str.Format("无法分类的文档数%d:",static_cast<int>(lUnknown));
		CMessage::PrintInfo(str);
	}

	const int nDocTotal=static_cast<int>(m_lstTestCatalogList.GetDocNum());
	// lCorrect == 0 is a legitimate result (accuracy 0) and is reported; only
	// the -1 failure value of SaveResults suppresses the accuracy figure.
	if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>=0)
		str.Format("测试文档总数:%d,准确率:%f",nDocTotal,static_cast<float>(lCorrect)/static_cast<float>(lTotalNum));
	else
		str.Format("测试文档总数:%d",nDocTotal);
	CMessage::PrintInfo(str);
	return true;
}
 
//对Smart格式的文档进行分类 
bool CClassifier::ClassifySmart() 
{ 
	m_lstTrainCatalogList.DumpCataList(m_paramClassifier.m_strResultDir+"\\classes.txt"); 
	m_lstTestCatalogList.InitCatalogList(); 
	CCatalogNode catalognode; 
	catalognode.m_strDirName=m_paramClassifier.m_strTestDir; 
	catalognode.m_strCatalogName="测试文档"; 
	catalognode.m_idxCata=-1; 
	POSITION posTemp=m_lstTestCatalogList.AddCata(catalognode); 
	CCatalogNode& cataTemp=m_lstTestCatalogList.GetAt(posTemp); 
 
	FILE *stream1,*stream2; 
	if( (stream1 = fopen( m_paramClassifier.m_strTestDir, "r" )) == NULL ) 
	{ 
		CMessage::PrintError("无法打开文件"+m_paramClassifier.m_strTestDir+"!"); 
		return false; 
	} 
 
	//如果是SVM分类器,则需要先将所有测试文档转换为向量,保存到文件test.dat 
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) 
	{ 
		m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat"; 
		if((stream2=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL) 
		{ 
			CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!"); 
			return false; 
		} 
	} 
 
	CTime startTime; 
	CTimeSpan totalTime; 
	startTime=CTime::GetCurrentTime(); 
	CMessage::PrintInfo(_T("正在对测试文档进行分类,请稍候...")); 
 
	char fname[10],type[1024],line[4096],content[100*1024]; 
	//falg=1 下一行的内容是文档的类别 
	//flag=2 下一行的内容是文档的标题 
	//flag=3 下一行的内容是文档的内容 
	int flag=0,nCount,len,i; 
	long lUnknown=0,lDocNum=0; 
	CStringArray typeArray; 
	CString strFileName,strCopyFile; 
	bool bTitle=false; //是否已经读出标题 
	double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0; 
	int nWordNum=m_lstTrainWordList.GetCount(); 
	while(!feof(stream1)) 
	{ 
		if(fgets(line,4096,stream1)==NULL) continue; 
		if(line[0]=='.') 
		{ 
			if(flag==3) 
			{ 
				CDocNode doc; 
				posTemp=cataTemp.AddDoc(doc); 
				CDocNode& docnode=cataTemp.GetAt(posTemp); 
				docnode.m_strDocName=fname; 
				if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) 
				{ 
					nCount=KNNCategory(content,docnode,false); 
				} 
				else 
				{ 
					if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) 
						nCount=docnode.ScanChineseStringWithDict(content,m_lstTrainWordList)-1; 
					else 
						nCount=docnode.ScanEnglishStringWithDict(content,m_lstTrainWordList, 
											m_paramClassifier.m_bStem)-1; 
 
					fprintf(stream2,"%d",1); 
					for(i=0;i0)  
	{ 
		str.Format("无法分类的文档数%d:",lUnknown); 
		CMessage::PrintInfo(str); 
	} 
	if(m_paramClassifier.m_bEvaluation&&lTotalNum>0&&lCorrect>0) 
		str.Format("测试文档总数:%d,准确率:%f",m_lstTestCatalogList.GetDocNum(),(float)(lCorrect)/(float)(lTotalNum)); 
	else 
		str.Format("测试文档总数:%d",m_lstTestCatalogList.GetDocNum()); 
	CMessage::PrintInfo(str); 
	return true; 
} 
 
//对文档进行分类,计算文档和每个类别的相似度,返回值为类别无法识别的文档总数 
long CClassifier::SVMClassify(CCatalogList &cataList) 
{ 
	long lUnknown=0; 
	FILE *stream; 
	m_theSVM.com_param.classifyfile=m_paramClassifier.m_strResultDir+"\\test.dat"; 
	if((stream=fopen(m_theSVM.com_param.classifyfile,"w"))==NULL) 
	{ 
		CMessage::PrintError("无法创建测试文档向量文件"+m_theSVM.com_param.classifyfile+"!"); 
		return 0; 
	} 
	 
	CTime startTime; 
	CTimeSpan totalTime; 
	CString str; 
	int nCount=0; 
	long lWordNum=m_lstTrainWordList.GetCount(); 
	startTime=CTime::GetCurrentTime(); 
	CMessage::PrintInfo(_T("正在生成测试文档的向量形式,请稍候...")); 
	POSITION pos_cata=cataList.GetFirstPosition(); 
	while(pos_cata!=NULL) 
	{ 
		CCatalogNode& catalognode=cataList.GetNext(pos_cata); 
		char *path=catalognode.m_strDirName.GetBuffer(0); 
		POSITION pos_doc=catalognode.GetFirstPosition(); 
		while(pos_doc!=NULL) 
		{ 
			CDocNode& docnode=catalognode.GetNext(pos_doc); 
			if(m_paramClassifier.m_nLanguageType==CClassifierParam::nLT_Chinese) 
				nCount=docnode.ScanChineseWithDict(path,m_lstTrainWordList); 
			else 
				nCount=docnode.ScanEnglishWithDict(path,m_lstTrainWordList, 
											m_paramClassifier.m_bStem); 
 
			fprintf(stream,"%d",catalognode.m_idxCata+1); 
			for(int i=0;i0)&&(nCount>0)) 
	{ 
		ComputeSimRatio(docNode,nCmpType); 
		return true; 
	} 
	else 
		return false; 
} 
 
//得到与文档docNode的相似度大于阈值dThreshold的所有类别 
//如果没有大于值阈值的类别,则返回相似度最大的类别 
bool CClassifier::MultiCategory(CDocNode &docNode, CArray& aryResult, double dThreshold) 
{ 
	double *pSimRatio=docNode.m_pResults; 
	if(pSimRatio==NULL) return false; 
	 
	double dMax=pSimRatio[0]; 
	short nMax=0; 
	bool bFound=false; 
	aryResult.RemoveAll(); 
	for(short i=1;idMax) 
		{ 
			dMax=pSimRatio[i]; 
			nMax=i; 
		} 
		if(pSimRatio[i]>dThreshold)  
		{ 
			aryResult.Add(i); 
			bFound=true; 
		} 
	} 
	if(!bFound) aryResult.Add(nMax); 
	return true; 
} 
 
//计算文档docNode和每一个类别的相似度 
//nCmpType代表相似度的不同计算方式 
void CClassifier::ComputeSimRatio(CDocNode &docNode,int nCmpType) 
{ 
	//计算文档与训练集中每一篇文档的相似度 
	int i; 
	long k; 
	for(i=0;iComputeSimilarityRatio(); 
		m_pSimilarityRatio[i].lDocID=i; 
	} 
	//将测试文档与训练文档集中文档的相似度进行降序排序 
	Sort(m_pSimilarityRatio,m_lDocNum-1); 
	docNode.AllocResultsBuffer(m_nClassNum); 
	double *pSimRatio=docNode.m_pResults; 
	for(i=0;i0)&&(nCount>0)) 
	{ 
		DOC doc; 
		CString str; 
		docNode.GenDocVector(doc); 
		docNode.AllocResultsBuffer(m_nClassNum); 
		for(int i=0;im_lDocNum) m_paramClassifier.m_nKNN=m_lDocNum; 
	m_pSimilarityRatio=new DocWeight[m_lDocNum]; 
	m_pProbability=new DocWeight[m_nClassNum]; 
	m_pDocs=(DocCatalog*)malloc(sizeof(DocCatalog)*m_lDocNum); 
	long num=0; 
	POSITION pos_cata = m_lstTrainCatalogList.GetFirstPosition(); 
	while(pos_cata != NULL)  // for each catalog  
	{ 
		CCatalogNode& catanode = m_lstTrainCatalogList.GetNext(pos_cata); 
		short idxCata=catanode.m_idxCata; 
		POSITION pos_doc  = catanode.GetFirstPosition(); 
		while(pos_doc!=NULL) 
		{ 
			CDocNode& docnode=catanode.GetNext(pos_doc); 
			m_pDocs[num].pDocNode=&docnode; 
			m_pDocs[num].nCataID=idxCata; 
			num++; 
		} 
	} 
	CDocNode::AllocTempBuffer(m_lstTrainWordList.GetCount()); 
} 
 
// Sort pData[0..nSize] into descending dWeight order.
//
// pData - array of document/weight records.
// nSize - index of the LAST element to sort (element count minus one); it is
//         passed straight through as QuickSort's upper bound.
//
// Ranges with fewer than two elements (nSize <= 0) or a NULL array are left
// untouched. This matters when the training set is empty: the caller then
// passes -1, which previously made QuickSort read outside the array.
void CClassifier::Sort(DocWeight *pData,int nSize)
{
	if(pData==NULL || nSize<=0)
		return;
	QuickSort(pData,0,nSize);
}
 
// Recursive quicksort of psData[iLo..iHi] (both bounds inclusive) into
// descending dWeight order. The pivot is the weight of the middle element.
void CClassifier::QuickSort(DocWeight *psData, int iLo,int iHi)
{
	const double dPivot = psData[(iLo + iHi)/2].dWeight;
	int nLeft = iLo;
	int nRight = iHi;

	do
	{
		// Skip elements that are already on the correct side of the pivot.
		for( ; psData[nLeft].dWeight > dPivot; ++nLeft) {}
		for( ; psData[nRight].dWeight < dPivot; --nRight) {}

		if(nLeft <= nRight)
		{
			// Exchange the misplaced pair and move both cursors inwards.
			DocWeight held = psData[nLeft];
			psData[nLeft] = psData[nRight];
			psData[nRight] = held;
			++nLeft;
			--nRight;
		}
	}
	while(nRight > nLeft);

	// Sort the two partitions left behind by the cursors.
	if(nRight > iLo)
		QuickSort(psData, iLo, nRight);
	if(nLeft < iHi)
		QuickSort(psData, nLeft, iHi);
}
 
//将分类结果保存到文件strFileName中,返回正确分类的文档总数 
//如果分类参数中要求拷贝文件到结果类别目录,则执行拷贝操作 
//参数typeArray只有在多类分类,且需要进行评价的时候才会用到 
long CClassifier::SaveResults(CCatalogList &cataList, CString strFileName, CStringArray *aryType) 
{ 
	FILE *stream; 
	if( (stream = fopen(strFileName, "w+" )) == NULL ) 
	{ 
		CMessage::PrintError("无法创建分类结果文件"+strFileName+"!"); 
		return -1; 
	} 
 
	CString str1,str2; 
	long lCorrect=0; 
	long docID=0; 
	int i; 
	char path[MAX_PATH]; 
	CArray aryResult; 
	CArray aryAnswer; 
	double dThreshold=(double)m_paramClassifier.m_dThreshold/100.0; 
 
	POSITION pos_cata=cataList.GetFirstPosition(); 
	while(pos_cata!=NULL) 
	{ 
		CCatalogNode& cataNode=cataList.GetNext(pos_cata); 
		short id=cataNode.m_idxCata; 
		strcpy(path,cataNode.m_strDirName.GetBuffer(0)); 
		POSITION pos_doc=cataNode.GetFirstPosition(); 
		while(pos_doc!=NULL) 
		{ 
			CDocNode& docNode=cataNode.GetNext(pos_doc); 
			if(docNode.m_nCataID<0) continue; 
			str1.Empty(); 
			str2.Empty(); 
			//如果是多类分类 
			if(m_paramClassifier.m_nClassifyType==CClassifierParam::nFT_Multi) 
			{ 
				MultiCategory(docNode,aryResult,dThreshold); 
				//如果需要将分类结果拷贝到分类结果目录 
				if(m_paramClassifier.m_bCopyFiles) 
				{ 
					for(i=0;iGetAt(docID).GetBuffer(0),aryAnswer); 
					//得到答案字符串 
					for(i=0;idMaxNum) 
		{ 
			dMaxNum=pSimRatio[i]; 
			nCataID=i; 
		} 
	} 
	docNode.m_nCataID=nCataID; 
	return nCataID; 
} 
 
// Classify one document with the SVM classifier and return its category ID.
// Returns -1 when SVMClassify() reports failure; otherwise the ID selected by
// SingleCategory().
short CClassifier::SVMCategory(char *pPath, CDocNode &docNode, bool bFile)
{
	if(!SVMClassify(pPath,docNode,bFile))
	{
		return -1;
	}
	const short nBestCata=SingleCategory(docNode);
	return nBestCata;
}
 
// Classify a single file or a block of text with the SVM classifier.
//
// file  - when bFile is true, the full path of a document; it must contain a
//         '\' separating the directory from the file name. When bFile is
//         false, the argument is handed to the three-argument overload
//         unchanged.
// bFile - selects between the two interpretations of 'file'.
//
// Returns the category ID, or -1 when 'file' is NULL, the path has no '\',
// the directory part does not fit into MAX_PATH characters, or the
// classification itself fails.
short CClassifier::SVMCategory(char *file, bool bFile)
{
	if(file==NULL)
		return -1;

	CDocNode docNode;
	if(!bFile)
		return SVMCategory(file,docNode,bFile);

	// Split "<directory>\<name>" at the last backslash.
	char *fname=strrchr(file,'\\');
	if(fname==NULL)
		return -1;

	// Reject directories that would overflow the local buffer (the
	// terminating zero needs one character as well).
	const size_t nDirLen=static_cast<size_t>(fname-file);
	if(nDirLen>=MAX_PATH)
		return -1;

	docNode.m_strDocName=(fname+1);

	char path[MAX_PATH];
	strncpy(path,file,nDirLen);
	path[nDirLen]=0;
	return SVMCategory(path,docNode,bFile);
}
 
// Run every per-category SVM model over the document vectors stored in
// strFileName and assign each test document its best category.
//
// For classifier i = 1..m_nClassNum the model "<result dir>\<model name><i>.mdl"
// is applied; svm_classify() fills one score per test document, which is
// stored in that document's m_pResults[i-1]. The scores are matched to
// documents by walking m_lstTestCatalogList in list order. Finally
// SingleCategory() picks the category of every document.
void CClassifier::SVMClassifyVectorFile(CString strFileName)
{
	CTime startTime;
	CTimeSpan totalTime;
	CString str;
	long num=m_lstTestCatalogList.GetDocNum(),lDocNum=0;
	// Scratch array receiving the scores of one classifier for all documents.
	double *fpWeight=new double[num];
	POSITION pos_doc, pos_cata;

	m_theSVM.com_param.classifyfile=strFileName;
	for(int i=1;i<=m_nClassNum;i++)
	{
		memset(fpWeight,0,sizeof(double)*num);
		startTime=CTime::GetCurrentTime();
		str.Format("正在使用第%d个SVM分类器对文档进行分类,请稍候...",i);
		CMessage::PrintInfo(str);
		// CString objects must be converted explicitly before they travel
		// through a variable-argument list; "%s" expects a character pointer.
		str.Format("%s\\%s%d.mdl",
			static_cast<LPCTSTR>(m_paramClassifier.m_txtResultDir),
			static_cast<LPCTSTR>(m_paramClassifier.m_strModelFile),
			i);
		m_theSVM.com_param.modelfile=str;
		m_theSVM.svm_classify(i,fpWeight);

		// Copy the score of the current classifier into m_pResults[i-1] of
		// every document.
		lDocNum=0;
		pos_cata=m_lstTestCatalogList.GetFirstPosition();
		while(pos_cata!=NULL)
		{
			CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata);
			pos_doc=catalognode.GetFirstPosition();
			while(pos_doc!=NULL)
			{
				CDocNode& docnode=catalognode.GetNext(pos_doc);
				// NOTE(review): called on every pass; this relies on
				// AllocResultsBuffer keeping scores written by earlier
				// classifiers - confirm against CDocNode.
				docnode.AllocResultsBuffer(m_nClassNum);
				docnode.m_pResults[i-1]=fpWeight[lDocNum];
				lDocNum++;
			}
		}
		totalTime=CTime::GetCurrentTime()-startTime;
		str.Format("第%d个SVM分类器分类结束,耗时",i);
		CMessage::PrintInfo(str+totalTime.Format("%H:%M:%S"));
	}
	delete[] fpWeight;

	// Pick, for every document, the category with the largest score.
	pos_cata=m_lstTestCatalogList.GetFirstPosition();
	while(pos_cata!=NULL)
	{
		CCatalogNode& catalognode=m_lstTestCatalogList.GetNext(pos_cata);
		pos_doc=catalognode.GetFirstPosition();
		while(pos_doc!=NULL)
		{
			CDocNode& docnode=catalognode.GetNext(pos_doc);
			docnode.m_nCataID=SingleCategory(docnode);
		}
	}
}
 
//计算文档docNode属于每一个类别的概率 
void CClassifier::ComputePro(CDocNode &docNode) 
{ 
	//计算文档与训练集中每一类文档的概率 
	int i,j,l,FeaNum=0; 
	long k; 
	 
	int			N;				//总的文档数; 
	int			N_c;			//C类的文档数 
	int			N_Cata;			//总类数 
	N = m_lstTrainCatalogList.GetDocNum(); 
	 
	POSITION	pos_cata; 
	CString     strWord; 
 
	// calculate the weight of each word to all catalog 
	N = m_lstTrainCatalogList.GetDocNum(); 
	N_Cata=m_lstTrainCatalogList.GetCataNum(); 
	 
	docNode.AllocResultsBuffer(m_nClassNum); 
 
	int nCataCount=0; 
	double ClassPro=0.0; 
	pos_cata = m_lstTrainCatalogList.GetFirstPosition(); 
 
	i=m_lstTrainWordList.GetCount(); 
	POSITION pos=m_lstTrainWordList.GetFirstPosition(); 
 
 
	for(l=0;l0)&&(nCount>0)) 
	{ 
		ComputePro(docNode); 
		return true; 
	} 
	else 
		return false; 
} 
 
 
// Classify one document with the Bayes classifier and return its category ID.
//
// Returns docNode.m_nCataID after a successful BAYESClassify(), or -1 when
// BAYESClassify() reports failure, so a category ID left in docNode by an
// earlier run is never mistaken for a fresh result (same convention as
// SVMCategory()).
short CClassifier::BAYESCategory(char *pPath, CDocNode &docNode, bool bFile)
{
	if(!BAYESClassify(pPath,docNode,bFile))
		return -1;
	return docNode.m_nCataID;
}
 
// Classify every document of cataList with the Bayes classifier.
// Each document whose category comes back as -1 is reported as an error.
// Returns the number of such unrecognised documents.
long CClassifier::BAYESClassify(CCatalogList &cataList)
{
	long lFailed=0;

	for(POSITION posCata=cataList.GetFirstPosition(); posCata!=NULL; )
	{
		CCatalogNode& curCata=cataList.GetNext(posCata);
		POSITION posDoc=curCata.GetFirstPosition();
		char *pszDir=curCata.m_strDirName.GetBuffer(0);

		while(posDoc!=NULL)
		{
			CDocNode& curDoc=curCata.GetNext(posDoc);
			const short nCata=BAYESCategory(pszDir, curDoc, true);
			if(nCata==-1)
			{
				// Assemble "<prefix><dir>\<name><suffix>" piece by piece.
				CString strMsg("无法识别文档");
				strMsg+=curCata.m_strDirName;
				strMsg+="\\";
				strMsg+=curDoc.m_strDocName;
				strMsg+="的类别!";
				CMessage::PrintError(strMsg);
				++lFailed;
			}
			CMessage::PrintStatusInfo(_T("扫描文档")+curDoc.m_strDocName);
		}
	}
	return lFailed;
}
 
 
 
 
 
// Return the category ID of 'file' using the configured classifier.
// Only the KNN and SVM classifier types are dispatched here; any other type
// yields -1.
short CClassifier::GetCategory(char *file, bool bFile)
{
	const bool bUseKNN=(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN);
	if(bUseKNN)
	{
		return KNNCategory(file,bFile);
	}

	const bool bUseSVM=(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM);
	if(bUseSVM)
	{
		return SVMCategory(file,bFile);
	}

	return -1;
}
 
short CClassifier::GetCategory(char *path, CDocNode &docNode, bool bFile) 
{ 
	short result=-1; 
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN) 
		result=KNNCategory(path,docNode,bFile); 
	else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM) 
		result=SVMCategory(path,docNode,bFile); 
	return result; 
} 
 
// Classify one document with the configured classifier (KNN, SVM or Bayes),
// leaving the per-category results in docNode. Returns the success flag of
// the chosen classifier, or false when the classifier type is not recognised.
bool CClassifier::Classify(char *path, CDocNode &docNode, bool bFile)
{
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
	{
		return KNNClassify(path,docNode,bFile);
	}
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
	{
		return SVMClassify(path,docNode,bFile);
	}
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES)
	{
		return BAYESClassify(path,docNode,bFile);
	}
	return false;
}
 
// Classify every document of cataList with the configured classifier.
//
// cataList - the catalog list whose documents are classified. (Earlier code
//            ignored this argument and always used m_lstTestCatalogList; the
//            caller in this file passes exactly that list, so its behaviour
//            is unchanged.)
//
// Returns the number of documents whose category could not be determined, as
// reported by the chosen classifier; 0 (after printing an error) when the
// classifier type is not recognised.
long CClassifier::Classify(CCatalogList &cataList)
{
	long lUnknown=0;
	if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_KNN)
		lUnknown=KNNClassify(cataList);
	else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_SVM)
		lUnknown=SVMClassify(cataList);
	else if(m_paramClassifier.m_nClassifierType==CClassifierParam::nCT_BAYES)
		lUnknown=BAYESClassify(cataList);
	else
		CMessage::PrintError("无法确定分类器的类型!");
	return lUnknown;
}