

function net = svmtrain(net, X, Y, alpha0, dodisplay)
% SVMTRAIN - Train a Support Vector Machine classifier
%
%   NET = SVMTRAIN(NET, X, Y)
%   Train the SVM given by NET using the training data X with target values
%   Y. X is a matrix of size (N,NET.nin) with N training examples (one per
%   row). Y is a column vector containing the target values (classes) for
%   each example in X. Each element of Y that is >=0 is treated as class
%   +1, each element <0 is treated as class -1.
%   SVMTRAIN normally uses the L1-norm of the training set errors in the
%   objective function. If NET.use2norm==1, the L2-norm is used instead.
%
%   All training parameters are given in the structure NET. Relevant
%   parameters are mainly NET.c, for fine-tuning also NET.qpsize,
%   NET.alphatol and NET.kkttol. See function SVM for a description of
%   these fields.
%
%   NET.c is a weight for misclassifying a particular example. NET.c may
%   either be a scalar (where all errors have the same weights), or it may
%   be a column vector (size [N 1]) where entry NET.c(i) corresponds to the
%   error weight for example X(i,:). If NET.c is a vector of length 2,
%   NET.c(1) specifies the error weight for all positive examples, NET.c(2)
%   the error weight for all negative examples. Specifying a different
%   weight for each example is useful for imbalanced data sets.
%
%   NET = SVMTRAIN(NET, X, Y, ALPHA0) uses the column vector ALPHA0 as
%   the initial values for the coefficients NET.alpha. ALPHA0 may result
%   from a previous training with different parameters.
%   NET = SVMTRAIN(NET, X, Y, ALPHA0, 1) displays information on the
%   training progress (number of Support Vectors, KKT violations, etc.)
%   SVMTRAIN uses either the function LOQO (a Matlab interface to Smola's
%   LOQO code) or the routines QP/QUADPROG from the Matlab Optimization
%   Toolbox to solve the quadratic programming problem.
%
%   See also:
%   SVM, SVMKERNEL, SVMFWD
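%
%   Example (a sketch only; the SVM constructor arguments and the SVMFWD
%   call below are assumptions, see those functions for details):
%     X = [randn(20,2)+1; randn(20,2)-1];  % two point clouds in 2D
%     Y = [ones(20,1); -ones(20,1)];       % target classes +1 and -1
%     net = svm(2, 'rbf', 1, 10);          % assumed constructor signature
%     net = svmtrain(net, X, Y);           % train with default settings
%     Ypred = svmfwd(net, X);              % classify the training data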
%

% 
% Copyright (c) Anton Schwaighofer (2001)
% $Revision: 1.19 $ $Date: 2002/01/09 12:11:41 $
% mailto:anton.schwaighofer@gmx.net
% 
% This program is released under the GNU General Public License.
% 

% Training a SVM involves solving a quadratic programming problem that
% scales quadratically with the number of examples. SVMTRAIN uses the
% decomposed training algorithm proposed by Osuna, Freund and Girosi, where
% the maximum size of a quadratic program is constant.
% (ftp://ftp.ai.mit.edu/pub/cbcl/nnsp97-svm.ps)
% For selecting the working set, the approximation proposed by Joachims
% (http://www-ai.cs.uni-dortmund.de/DOKUMENTE/joachims_99a.ps.gz) is used.
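%
% In dual form, training the 1norm SVM solves the quadratic program
% (a sketch):
%   max_a   sum_i a_i - 1/2 sum_{i,j} a_i a_j y_i y_j K(x_i, x_j)
%   s.t.    0 <= a_i <= C_i for all i,   sum_i a_i y_i = 0
% The decomposition algorithm below repeatedly optimises a small working
% set of the coefficients a_i while keeping all other coefficients fixed.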

Y = double(Y);
warning off;

% Check arguments for consistency
errstring = consist(net, 'svm', X, Y);
if ~isempty(errstring),
  error(errstring);
end
[N, d] = size(X);
if N==0,
  error('No training examples given');
end
net.nbexamples = N;
if nargin<5,
  dodisplay = 0;
end
if nargin<4,
  alpha0 = [];
elseif (~isempty(alpha0)) & (~all(size(alpha0)==[N 1])),
  error(['Initial values ALPHA0 must be a column vector with the same length' ...
	 ' as X']); 
end

% Find the indices of examples from class +1 and -1
class1 = logical(uint8(Y>=0));
class0 = logical(uint8(Y<0));

if length(net.c(:))==1,
  C = repmat(net.c, [N 1]);
  % The same upper bound for all examples
elseif length(net.c(:))==2,
  C = zeros([N 1]);
  C(class1) = net.c(1);
  C(class0) = net.c(2);
  % Different upper bounds C for the positive and negative examples
else
  C = net.c;
  if ~all(size(C)==[N 1]),
    error(['Upper bound C must be a column vector with the same length' ...
	   ' as X']); 
  end
end
if min(C)<net.alphatol,
  error('NET.C must be positive and larger than NET.alphatol');
end

if ~isfield(net, 'use2norm'),
  net.use2norm = 0;
end

if ~isfield(net, 'qpsolver'),
  net.qpsolver = '';
end
qpsolver = net.qpsolver;
if isempty(qpsolver),
  % QUADPROG is the fastest solver for both 1norm and 2norm SVMs, if
  % qpsize is around 10-70 (loqo is best for large 1norm SVMs)
  checkseq = {'quadprog', 'loqo', 'qp'};
  i = 1;
  while (i <= length(checkseq)),
    e = exist(checkseq{i});
    if (e==2) | (e==3),
      qpsolver = checkseq{i};
      break;
    end
    i = i+1;
  end
  if isempty(qpsolver),
    error('No quadratic programming solver (QUADPROG,LOQO,QP) found.');
  end
end
% Note that problems may occur with the QUADPROG solver: at least in early
% versions of Matlab 5.3 there are severe numerical problems somewhere deep
% in QUADPROG.

% Turn off all messages coming from quadprog, increase the maximum number
% of iterations from 200 to 500 - good for low-dimensional problems
if strcmp(qpsolver, 'quadprog') & (dodisplay==0),
  quadprogopt = optimset('Display', 'off', 'MaxIter', 500);
else
  quadprogopt = [];
end

% Actual size of quadratic program during training may not be larger than
% the number of examples
QPsize = min(N, net.qpsize);
chsize = net.chunksize;

% SVMout contains the output of the SVM decision function for each
% example. This is updated iteratively during training.
SVMout = zeros(N, 1);

% Make sure there are no other values in Y than +1 and -1
Y(class1) = 1;
Y(class0) = -1;
if dodisplay>0,
  fprintf('Training set: %i examples (%i positive, %i negative)\n', ...
	  length(Y), length(find(class1)), length(find(class0)));
end

% Start with a vector of zeros for the coefficients alpha, or the
% parameter ALPHA0, if it is given. Those values will be used to perform
% an initial working set selection, by assuming they are the true weights
% for the training set at hand.
if ~any(alpha0),
  net.alpha = zeros([N 1]);
  % If starting with a zero vector: randomize the first working set search
  randomWS = 1;
else
  randomWS = 0;
  % for 1norm SVM: make the initial values conform to the upper bounds
  if ~net.use2norm,
    net.alpha = min(C, alpha0);
  else
    % 2norm SVM: no upper bound on alpha, use the initial values directly
    net.alpha = alpha0;
  end
end
alphaOld = net.alpha;

if length(find(Y>0))==N,
  % only positive examples
  net.bias = 1;
  net.svcoeff = [];
  net.sv = [];
  net.svind = [];
  net.alpha = zeros([N 1]);
  return;
elseif length(find(Y<0))==N,
  % only negative examples: use a negative bias so that every input is
  % classified as class -1
  net.bias = -1;
  net.svcoeff = [];
  net.sv = [];
  net.svind = [];
  net.alpha = zeros([N 1]);
  return;
end

iteration = 0;
workset = logical(uint8(zeros(N, 1)));
sameWS = 0;
net.bias = 0;

while 1,

  if dodisplay>0,
    fprintf('\nIteration %i: ', iteration+1);
  end

  % Step 1: Determine the Support Vectors.
  [net, SVthresh, SV, SVbound, SVnonbound] = findSV(net, C);
  if dodisplay>0,
    fprintf(['Working set of size %i: %i Support Vectors, %i of them at' ...
	     ' bound C\n'], length(find(workset)), length(find(workset & SV)), ...
	    length(find(workset & SVbound))); 
    fprintf(['Whole training set: %i Support Vectors, %i of them at upper' ...
	     ' bound C.\n'], length(net.svind), length(find(SVbound)));
    if dodisplay>1,
      fprintf('The Support Vectors (threshold %g) are the examples\n', ...
	      SVthresh);
      fprintf(' %i', net.svind);
      fprintf('\n');
    end
  end

  
  % Step 2: Find the output of the SVM for all training examples
  if (iteration==0) | (mod(iteration, net.recompute)==0),
    % Every NET.recompute iterations the SVM output is built from
    % scratch. Use all Support Vectors for determining the output.
    changedSV = net.svind;
    changedAlpha = net.alpha(changedSV);
    SVMout = zeros(N, 1);
    if strcmp(net.kernel, 'linear'),
      net.normalw = zeros([1 d]);
    end
  else
    % A normal iteration: Find the coefficients that changed and adjust
    % the SVM output only by the difference of old and new alpha
    changedSV = find(net.alpha~=alphaOld);
    changedAlpha = net.alpha(changedSV)-alphaOld(changedSV);
  end
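  % In both cases the chunked loops below apply the update (a sketch):
  %   SVMout(i) <- SVMout(i) + sum_k changedAlpha(k) * Y(changedSV(k))
  %                            * K(X(i,:), X(changedSV(k),:))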
  
  if strcmp(net.kernel, 'linear'),
    chunks = ceil(length(changedSV)/chsize);
    % Linear kernel: Build the normal vector of the separating
    % hyperplane by computing the weighted sum of all Support Vectors
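    % i.e. (a sketch) normalw accumulates sum_k alpha_k*y_k*X(k,:) over
    % the changed coefficients; this is the primal weight vector w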
    for ch = 1:chunks,
      ind = (1+(ch-1)*chsize):min(length(changedSV), ch*chsize);
      temp = changedAlpha(ind).*Y(changedSV(ind));
      net.normalw = net.normalw+temp'*X(changedSV(ind), :);
    end
    % Find the output of the SVM by multiplying the examples with the
    % normal vector
    SVMout = zeros(N, 1);
    chunks = ceil(N/chsize);
    for ch = 1:chunks,
      ind = (1+(ch-1)*chsize):min(N, ch*chsize);
      SVMout(ind) = X(ind,:)*(net.normalw');
    end
  else
    % A normal kernel function: Split both the examples and the Support
    % Vectors into small chunks
    chunks1 = ceil(N/chsize);
    chunks2 = ceil(length(changedSV)/chsize);
    for ch1 = 1:chunks1,
      ind1 = (1+(ch1-1)*chsize):min(N, ch1*chsize);
      for ch2 = 1:chunks2,
	% Compute the kernel function for a chunk of Support Vectors and
        % a chunk of examples
	ind2 = (1+(ch2-1)*chsize):min(length(changedSV), ch2*chsize);
	K12 = svmkernel(net, X(ind1, :), X(changedSV(ind2), :));
	% Add the weighted kernel matrix to the SVM output. In update
        % cycles, the kernel matrix is weighted by the difference of
        % alphas, in other cycles it is weighted by the value alpha alone.
	coeff = changedAlpha(ind2).*Y(changedSV(ind2));
	SVMout(ind1) = SVMout(ind1)+K12*coeff;
      end
      if dodisplay>2,
	K1all = svmkernel(net, X(ind1,:), X(net.svind,:));
	coeff2 = net.alpha(net.svind).*Y(net.svind);
	fprintf('Maximum error due to matrix partitioning: %g\n', ...
		max((SVMout(ind1)-K1all*coeff2)'));
      end
    end
  end

  
  % Step 3: Compute the bias of the SVM decision function.
  if net.use2norm,
    % The bias can be found from the SVM output for the Support Vectors.
    % For those vectors, the output should be 1-alpha/C resp. -1+alpha/C.
    workSV = find(SV & workset);
    if ~isempty(workSV),
      net.bias = mean((1-net.alpha(workSV)./C(workSV)).*Y(workSV)- ...
                      SVMout(workSV));
    end
  else
    % normal 1norm SVM:
    % The bias can be found from Support Vectors whose alpha is not at the
    % upper bound. For those vectors, the SVM output should be +1
    % resp. -1.
    workSV = find(SVnonbound & workset);
    if ~isempty(workSV),
      net.bias = mean(Y(workSV)-SVMout(workSV));
    end
  end
  % The nasty case that no SVs to determine the bias have been found:
  % the only sensible thing to do is to leave the bias unchanged.
  if isempty(workSV) & (dodisplay>0),
    disp('No Support Vectors in the current working set.');
    disp('Leaving the bias unchanged.');
  end

  
  % Step 4: Compute the values of the Karush-Kuhn-Tucker conditions
  % of the quadratic program. If no violations of these conditions are
  % found, the optimal solution has been found, and we are finished.
  % KKT describes how correctly each example is classified. KKT must be
  %   positive for all examples that are on the correct side and that are
  %     not Support Vectors
  %   0 for all Support Vectors
  %   negative for all examples on the wrong side of the hyperplane
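  % Written out for the 1norm case, with f_i = SVMout(i)+net.bias (a sketch):
  %   alpha_i = 0        =>  y_i*f_i >= 1   (KKT_i = y_i*f_i - 1 >= 0)
  %   0 < alpha_i < C_i  =>  y_i*f_i  = 1   (KKT_i = 0)
  %   alpha_i = C_i      =>  y_i*f_i <= 1   (KKT_i <= 0)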
  if net.use2norm,
    KKT = (SVMout+net.bias).*Y-1+net.alpha./C;
    KKTviolations = logical(uint8((SV & (abs(KKT)>net.kkttol)) | ...
                                  (~SV & (KKT<-net.kkttol))));
  else
    KKT = (SVMout+net.bias).*Y-1;
    KKTviolations = logical(uint8((SVnonbound & (abs(KKT)>net.kkttol)) | ...
                                  (SVbound & (KKT>net.kkttol)) | ...
                                  (~SV & (KKT<-net.kkttol))));
  end
  ind = find(KKTviolations & workset);
  if ~isempty(ind),
    % The coefficients alpha for the current working set have just been
    % optimised; none of them should violate the KKT conditions.
    if dodisplay>0,
      fprintf('KKT conditions not met in working set (value %g)', ...
              max(abs(KKT(ind))));
    end
  end
  if dodisplay>0,
    fprintf('%i violations of the KKT conditions.\n', ... 
	    length(find(KKTviolations)));
    fprintf(['(%i violations from positive examples, %i from negative' ...
	     ' examples)\n'], length(find(KKTviolations & class1)), ...
	    length(find(KKTviolations & class0)));
    if (dodisplay>1) & ~isempty(find(KKTviolations)),
      disp('The following examples violate the KKT conditions:');
      fprintf(' %i', find(KKTviolations));
      fprintf('\n');
    end
  end
  % Check how many violations of the KKT-conditions have been found. If
  % none, we are finished.
  if length(find(KKTviolations)) == 0,
    break;
  end
  
  % Step 5: Determine a new working set. To this end, a linear
  % approximation of the objective function is made. The new working set
  % consists of the QPSIZE largest elements of the gradient of the
  % linear approximation. The gradient of the linear approximation can be
  % expressed using the output of the SVM on all training examples.
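  % For the 1norm case the gradient of the dual objective is
  %   dW/dalpha_i = 1 - y_i*SVMout(i),
  % so searchDir = SVMout-Y below equals -Y.*gradient (a sketch of
  % Joachims' steepest feasible direction criterion).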
  if net.use2norm,
    searchDir = SVMout+Y.*(net.alpha./C-1);
    set1 = logical(uint8(SV | class0));
    set2 = logical(uint8(SV | class1));
  else
    searchDir = SVMout-Y;
    set1 = logical(uint8((SV | class0) & (~SVbound | class1)));
    set2 = logical(uint8((SV | class1) & (~SVbound | class0)));
  end
  % During the very first iteration: If no initial values for net.alpha
  % are given, perform a random working set selection
  if randomWS,
    searchDir = rand([N 1]);
    set1 = class1;
    set2 = class0;
    randomWS = 0;
  end

  % Step 6: Select the working set.
  % Goal is to select an equal number of examples from set1 and set2
  % (QPsize/2 examples from set1, QPsize/2 from set2). The examples from
  % set1 are the QPsize/2 highest elements of searchDir for set1,
  % the examples from set2 are the QPsize/2 smallest elements of searchDir
  % for set2.
  worksetOld = workset;
  workset = logical(uint8(zeros(N, 1)));
  if length(find(set1 | set2)) <= QPsize,
    workset(set1 | set2) = 1;
    % Less than QPsize examples to select from: Use them all
  elseif length(find(set1)) <= floor(QPsize/2),
    workset(set1) = 1;
    % set1 has less than half QPsize examples: Use all of set1, fill the
    % rest with examples from set2 starting with the ones that have low
    % values for searchDir
    set2 = find(set2 & ~workset);
    [dummy, ind] = sort(searchDir(set2));
    from2 = min(length(set2), QPsize-length(find(workset)));
    workset(set2(ind(1:from2))) = 1;
  elseif length(find(set2)) <= floor(QPsize/2),
    % set2 has less than half QPsize examples: Use all of set2, fill the
    % rest with examples from set1 starting with the ones that have high
    % values for searchDir
    workset(set2) = 1;
    set1 = find(set1 & ~workset);
    [dummy, ind] = sort(-searchDir(set1));
    from1 = min(length(set1), QPsize-length(find(workset)));
    workset(set1(ind(1:from1))) = 1;
  else
    set1 = find(set1);
    [dummy, ind] = sort(-searchDir(set1));
    from1 = min(length(set1), floor(QPsize/2));
    workset(set1(ind(1:from1))) = 1;
    % Use the QPsize/2 highest values for searchDir from set1
    set2 = find(set2 & ~workset);
    % Make sure that no examples are added twice
    [dummy, ind] = sort(searchDir(set2));
    from2 = min(length(set2), QPsize-length(find(workset)));
    workset(set2(ind(1:from2))) = 1;
    % Use the QPsize/2 lowest values for searchDir from set2
  end
  worksetind = find(workset);
  % Workaround for Matlab bug when indexing sparse arrays with logicals:
  % use index set instead
  
  % Emergency exit: If we end up with the same work set in 2 subsequent
  % iterations, something strange must have happened (for example, the
  % accuracy of the QP solver is insufficient as compared to the required
  % precision given by NET.alphatol and NET.kkttol)
  % Exit immediately if 'loqo' is used: loqo ignores the starting values,
  % so another iteration will not improve the results.
  if all(workset==worksetOld),
    sameWS = sameWS+1;
    if ((sameWS==3) | strcmp(qpsolver, 'loqo')),
      warnstr = 'Working set not changed - check accuracy. Exiting.';
      if dodisplay>0,
        disp(warnstr);
      end
      warning(warnstr);
      break;
    end
  else
    sameWS = 0;
  end
  worksize = length(find(workset));
  nonworkset = ~workset;
  if dodisplay>1,
    disp('Working set consists of examples ');
    fprintf(' %i', find(workset));
    fprintf('\n');
  end

  
  % Step 7: Determine the linear part of the quadratic program. We have
  % already determined the working set. The linear term of the quadratic
  % program is made up of the kernel evaluations K(Support Vectors
  % outside of the working set, Support Vectors in the working set)
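  % Written out (a sketch): for each example i in the working set,
  %   qBN(i) = y_i * sum_j alpha_j*y_j*K(x_i,x_j),  j over SVs outside
  % the working set, and the subproblem's linear term is f = qBN - 1.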
  nonworkSV = find(nonworkset & SV);
  % All Support Vectors outside of the working set
  qBN = 0;
  if length(nonworkSV)>0,
    % The kernel matrix over all of nonworkSV may be very large. Split it
    % up into smaller chunks.
    chunks = ceil(length(nonworkSV)/chsize);
    for ch = 1:chunks,
      % Indices of the current chunk in NONWORKSV
      ind = (1+(ch-1)*chsize):min(length(nonworkSV), ch*chsize);
      % Evaluate kernel function for working set and the current chunk of
      % non-working set
      Ki = svmkernel(net, X(worksetind, :), X(nonworkSV(ind), :));
      % The linear term qBN for the quadratic program is a column vector
      % given by summing up the kernel evaluations multiplied by the
      % corresponding alpha's and the class labels.
      qBN = qBN+Ki*(net.alpha(nonworkSV(ind)).*Y(nonworkSV(ind)));
    end
    qBN = qBN.*Y(workset);
  end
  % Second linear term is a vector of ones
  f = qBN-ones(worksize, 1);

  
  % Step 8: Solve the quadratic program. The quadratic term of the
  % objective function is made up of the examples in the working set, the
  % linear term comes from examples outside of the working set. The values
  % WORKALPHA found this way replace the old values NET.alpha for the
  % examples in the working set.
  % Quadratic term H is given by the kernel evaluations for the working
  % set
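  % Written out, the subproblem solved below is (a sketch, 1norm case):
  %   min_a  1/2*a'*H*a + f'*a
  %   s.t.   Y(workset)'*a = eqconstr,  0 <= a <= C(workset)
  % with H(i,j) = y_i*y_j*K(x_i,x_j) over the working set.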
  H = svmkernel(net, X(worksetind,:), X(worksetind,:));
  if net.use2norm,
    % with 2norm of the slack variables: the quadratic program has values
    % 1/C in the diagonal. Additionally, this makes H better conditioned.
    H = H+diag(1./C(workset));
  else
    % Suggested by Mathworks support for improving the condition
    % number. The condition number should not be much larger than
    % 1/sqrt(eps) to avoid numerical problems. Condition number of H will
    % now be < eps^(-2/3)
    H = H+diag(ones(worksize, 1)*eps^(2/3));
  end
  H = H.*(Y(workset)*Y(workset)');
  % Matrix A for the equality constraint
  A = Y(workset)';
  % If there are Support Vectors outside of the working set, the equality
  % constraint must equal minus the weighted sum of the class labels of
  % all these vectors. Otherwise the equality constraint is zero.
  if length(nonworkSV)>0,
    eqconstr = -net.alpha(nonworkSV)'*Y(nonworkSV);
  else
    eqconstr = 0;
  end
  % Lower and upper bound for the coefficients alpha in the
  % current working set
  VLB = zeros(worksize, 1);
  if net.use2norm,
    % no upper bound in the 2norm case
    VUB = [];
  else
    % normal 1norm SVM: error weights C are the upper bounds
    VUB = C(workset);
  end
  tic;
  % Solve the quadratic program with 1 equality constraint.
  % Initial guess for the solution of the QP problem.
  startVal = net.alpha(workset);
  switch qpsolver
    case 'quadprog'
      workalpha = quadprog(H, f, [], [], A, eqconstr, VLB, VUB, startVal, ...
                           quadprogopt);
    case 'qp'
      workalpha = qp(H, f, A, eqconstr, VLB, VUB, startVal, 1);
    case 'loqo'
      if isempty(VUB),
        % LOQO crashes if upper bound is missing
        % Use a relatively low value (instead of Inf) for faster
        % convergence
        VUB = repmat(1e7, size(VLB));
      end
      workalpha = loqo(H, f, A, eqconstr, VLB, VUB, startVal, 1);
  end
  t = toc;
  if dodisplay>1,
    fprintf('QP subproblem solved after %i minutes, %2.1f seconds.\n', ...
	  floor(t/60), mod(t, 60));
  end
  % Sometimes QUADPROG returns a solution with a small imaginary part
  % (usually with ill-posed problems, badly conditioned H)
  if any(imag(workalpha)~=0),
    warning(['The QP solver returned a complex solution. '...
             'Check condition number cond(H).']);
    workalpha = real(workalpha);
  end
  % Update the newly found coefficients in NET.alpha
  alphaOld = net.alpha;
  net.alpha(workset) = workalpha;
  
  iteration = iteration+1;
end

% Finished! Store the Support Vectors and the coefficients given by
% NET.alpha and the corresponding labels.
net.svcoeff = net.alpha(net.svind).*Y(net.svind);
net.sv = X(net.svind, :);

if dodisplay>0,
  fprintf('\n\n\nTraining finished.\n');
  disp('Information about the trained Support Vector Machine:');
  svmstat(net,1);
  % output statistics over SVs and separating hyperplane
end



function [net, SVthresh, SV, SVbound, SVnonbound] = findSV(net, C)
% FINDSV - Select the Support Vectors from the current coefficients NET.alpha

% Threshold for selecting Support Vectors
maxalpha = max(net.alpha);
if maxalpha > net.alphatol,
  % For most cases, net.alphatol is a reasonable choice (approx 1e-2)
  SVthresh = net.alphatol;
else
  % For complex kernels on small data sets all alphas will be very small.
  % Use the geometric mean of the maximum alpha and EPS (the mean of their
  % logarithms) as the threshold.
  SVthresh = exp((log(max(eps,maxalpha))+log(eps))/2);
end
% All examples that have a value of NET.alpha above this threshold are
% assumed to be Support Vectors.
SV = logical(uint8(net.alpha>=SVthresh));
% All Support Vectors that have a value at their upper bound C
if net.use2norm,
  % There is no such thing in the 2norm case!
  SVbound = logical(repmat(uint8(0), size(net.alpha)));
else
  SVbound = logical(uint8(net.alpha>(C-net.alphatol)));
end
% The Support Vectors not at the upper bound
SVnonbound = SV & (~SVbound);
% The actual indices of the Support Vectors in the training set
net.svind = find(SV);