www.pudn.com > smo348.zip > smo.h
#ifndef _SMO_H #define _SMO_H #include "math.h" #include "stdlib.h" #include#include #include #include #include using namespace std; #define INFO struct sparse_binary_vector { vector id; }; struct sparse_vector { vector id; vector val; }; typedef vector dense_vector; /*全局变量*/ extern int N; extern int d; extern double C; extern double tolerance; extern double eps; extern double two_sigma_squared; extern bool is_sparse_data; extern bool is_binary; extern bool is_test_only; extern bool is_linear_kernel; extern double delta_b; extern int end_support_i; extern int first_test_i; extern double (*dot_product_func)(int,int); extern double (*learned_func)(int); extern double (*kernel_func)(int,int); extern vector target; extern vector sparse_binary_points; extern vector sparse_points; extern vector dense_points; extern vector alph; extern double b; extern vector w; extern vector error_cache; extern vector precomputed_self_dot_product; extern double (*dot_product_func)(int,int); extern double (*learned_func)(int); extern double (*kernel_func)(int,int); /************************************************************* * 该函数检查第一个变量是否违背KKT条件,如果不满足KKT条件,则 * 寻找第二个权向量,调用takeStep函数,更新两个权值 * 返回值:如果有权向量被更新,则返回1否则返回0 ************************************************************/ int examineExample(int i1); /************************************************************* * 优化两个lagrange因子 * 成功时返回1,否则返回0 ************************************************************/ int takeStep(int i1,int i2); /************************************************************* * 几个学习函数用来计算f(x)=w*x-b * 用来计算第k个样本的输出 ************************************************************/ double learned_func_linear_sparse_binary(int k);//对应于稀疏二进制数据,线性分类器 double learned_func_linear_sparse_nobinary(int k);//对应于稀疏非二进制数据,线性分类器 double learned_func_linear_dense(int k);//对应于非稀疏数据,线性分类器 double learned_func_nonlinear(int k);//对应于非线性分类器 /*************************************************************** * 计算点积的核函数 * 用来计算两个样本之间的点积 **************************************************************/ double dot_product_sparse_binary(int i1,int i2); double dot_product_sparse_nonbinary(int i1,int i2); double dot_product_dense(int i1,int i2); double rbf_kernel(int i1,int i2); /************************************************************* * 读入数据文件 * 数据文件的格式为 * 每行对应一个样本,格式为: * 对于非稀疏数据:d个属性值(空格分开) 类标签 * 对于稀疏数据: id1 val1 id2 val2 ...idm valm target_value * 对于二进制稀疏数据:id1,id2...idm target_value * 返回读入样本的个数 *************************************************************/ int read_data(istream& is); /************************************************************** * 输出模型参数 * 输出顺序为: * 维数d * 稀疏数据的标志 * 二进制数据的标志 * 线性核函数的标志 * 临街值b * 如果使用的是线性的核函数,则输出权向量 * 如果是非线性的核函数 * 输出核函数的参数 * 支持向量的个数 * 支持向量的权重 * 支持向量,每行一个 *********************************************************************/ void write_svm(ostream& os); /****************************************************** * 按照输出格式,输入模型,返回支持向量的个数 ******************************************************/ int read_svm(istream& is); /******************************************************* * 预测的错误率 * 返回模型的经验风险 *****************************************************/ double error_rate(); /**************************************************************** * smo算法主程序 * 给定训练文件data_file_name(格式请参见read_data函数) * 使用smo算法训练模型,并输出到svm_file_name中 * 返回值: * 0:成功 * 1:无法打开输入文件 * 2:读入训练样本时出错 ***************************************************************/ int smo(string data_file_name,string svm_file_name); /**************************************************************** * 加载svm模型svm_file_name * 返回值: * 0:成功 * 1:无法打开输入文件 * 2:读入模型时出错 ****************************************************************/ int load_svm(string svm_file_name); /************************************************************* * 学习函数用来计算f(x)=w*x-b * 用来计算向量x的输出 ************************************************************/ double predict_func(const vector & vx); /************************************************************** * 内积函数 * 计算两个向量之间的内积 **************************************************************/ double kernel(const vector & vx1,const vector & vx2); #endif