www.pudn.com > snippets(1).rar > crossvalidation.m, change:2009-10-13,size:10487b


function [confusion_matrix, confusion_details, rate, confusion_matrix_seq, confusion_details_seq, rate_seq] = crossvalidation(experiment_name)
% CROSSVALIDATION  Test classification performance with cross-validation.
%
%   [CM, CD, RATE, CM_SEQ, CD_SEQ, RATE_SEQ] = CROSSVALIDATION(EXPERIMENT_NAME)
%   loads the data set named by EXPERIMENT_NAME (via read_exp), then runs
%   Globals.crossval_iterations repetitions of Globals.crossval_folds-fold
%   cross-validation. In each fold, form and flow features are learned (or
%   read from disk) on the training part, an SVM classifier is trained on
%   them, and the held-out part is classified. Results are scored per frame
%   and per sequence (sequence labels come from majority_vote over the frame
%   results of each sequence).
%
%   All settings are hard-coded in the "define global parameters" section.
%
%   Outputs:
%     confusion_matrix      Globals.classes x Globals.classes frame-level
%                           counts summed over all folds; rows = estimated
%                           class, columns = true class.
%     confusion_details     (iterations*folds) x Globals.classes per-fold,
%                           per-class frame recognition rate in percent.
%                           NaN where a class has no test frames in a fold.
%     rate                  [mean std] over the non-NaN entries of
%                           confusion_details.
%     confusion_matrix_seq  As confusion_matrix, at sequence level.
%     confusion_details_seq As confusion_details, at sequence level.
%     rate_seq              [mean std] over the non-NaN entries of
%                           confusion_details_seq.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%     define global parameters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  % basic information
  Globals.DISPLAY = false; % detailed screen output or summary?
  Globals.class_names = {'bend', 'jack', 'jump', 'pjump', 'side', 'run', 'walk', 'wave1', 'wave2'}; % Weizmann
  % Globals.class_names = {'box', 'clp', 'jog', 'run', 'wlk', 'wav'}; % KTH
  Globals.classes = length(Globals.class_names);

  % cross-validation settings
  Globals.crossval_folds = 9; % Weizmann
  % Globals.crossval_folds = 5; % KTH
  Globals.crossval_iterations = 1; % number of repeats
  Globals.crossval_scramble = false; % don't scramble if working with precomputed basis!!!

  % dimensionality reduction parameters
  Globals.pca_preserve_flow = -500; % 0.5; % negative number: # basis vectors, positive number: fraction of variance
  Globals.pca_preserve_form = -500; % 0.7;
  Globals.pca_normalize = true; % normalize length of image vectors?

  % read features from file, or compute?
  Globals.features_tofile = true;
  Globals.features_fromfile = ~Globals.features_tofile;

  % Cached features are written/read per fold index, i.e. they are tied to
  % a fixed subject order, so disk I/O is disabled when scrambling.
  if Globals.crossval_scramble && (Globals.features_tofile || Globals.features_fromfile)
    fprintf('WARNING: Features from disk make no sense with scrambled subject order - turning disk I/O off\n');
    Globals.features_tofile = false;
    Globals.features_fromfile = false;
  end

  % support vector machine parameters
  Globals.svm_ova = true;
  Globals.svm_kernel = 'linear';
  Globals.svm_kernelpar = [];
  Globals.svm_maxweight = repmat([Globals.classes-1 1],Globals.classes,1);

  % weight of the flow features; form features get (1 - relative_weight).
  % <= 0 uses form only, >= 1 uses flow only (see cv_fuse_features).
  Globals.relative_weight = 0.7;

  % number of frames per snippet
  Globals.concat_frames = 1;

  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  %   data loading and preprocessing
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  % path to svm files (using code by Schwaighofer)
  addpath('./svm/');

  % path to low-level filter files (code by Kovesi and Konrad)
  addpath('./filters/');

  % read data ('labels' instead of 'class' so the builtin class() is not shadowed)
  fprintf('\n');
  if Globals.DISPLAY, fprintf('Reading filenames and ground truth from file...'); end
  [data, labels] = read_exp(experiment_name);
  if Globals.DISPLAY, fprintf('DONE\n'); end

  % allocate confusion matrices and per-fold, per-class rates
  n_runs = Globals.crossval_iterations*Globals.crossval_folds;
  confusion_matrix = zeros(Globals.classes);
  confusion_details = zeros(n_runs, Globals.classes);
  confusion_matrix_seq = zeros(Globals.classes);
  confusion_details_seq = zeros(n_runs, Globals.classes);
  conf_det_ct = 1;

  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  %    run multiple evaluations with different data order
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  for q = 1:Globals.crossval_iterations

    if Globals.DISPLAY, fprintf('\n'); end
    fprintf('ITERATION %d / %d - running %d-fold cross validation\n', q, Globals.crossval_iterations, Globals.crossval_folds);

    % randomly permute subjects to avoid biases (governed by Globals.crossval_scramble)
    [data, labels] = scramble_subjects(data, labels, Globals.crossval_scramble);

    % make cross-validation data sets
    [f_start, f_end] = crossval_boundaries(length(data), Globals.crossval_folds);

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %  repeat training and testing with different parts of data
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    for u = 1:Globals.crossval_folds

      %--------------------------------------
      %         SPLIT INPUT DATA
      %--------------------------------------

      if Globals.DISPLAY, fprintf('\n'); end
      fprintf('  CROSS-VALIDATION, fold %d / %d: test data = %d--%d\n', ...
              u, Globals.crossval_folds, f_start(u), f_end(u));
      if Globals.DISPLAY, fprintf('\n'); end

      % split data into training and test
      if Globals.DISPLAY, fprintf('    Removing test data from data set...'); end
      [train_data, train_labels, test_data, test_labels] = split_data(data, labels, f_start(u), f_end(u));
      if Globals.DISPLAY, fprintf('DONE\n'); end

      %--------------------------------------
      %             TRAINING
      %--------------------------------------

      % learn feature extraction from training data
      if Globals.DISPLAY, fprintf('    Learning image representation...'); end

      if Globals.features_fromfile
        if Globals.DISPLAY, fprintf('read features from file...'); end
        [train_form, train_flow, framesperseq] = read_features(data, u, 'train', Globals.concat_frames);
      else
        % pca_form / pca_flow are only needed in this branch: in file mode
        % the test features are read from disk as well.
        [train_form, pca_form, framesperseq] = learn_pca_representation(train_data, Globals, 'form');
        [train_flow, pca_flow, framesperseq] = learn_pca_representation(train_data, Globals, 'flow');
        if Globals.features_tofile
          write_features(data, u, train_form, train_flow, framesperseq, 'train', Globals.concat_frames);
        end
      end
      if Globals.DISPLAY, fprintf('DONE\n'); end

      train_data_red = cv_fuse_features(train_form, train_flow, Globals.relative_weight);

      % Features hold one column per frame; every sequence is assumed to
      % contribute the same number of frames.
      frpseq = size(train_data_red,2)/numel(train_labels);
      if ~isfinite(frpseq) || frpseq < 1 || frpseq ~= fix(frpseq)
        error('crossvalidation:framesPerSequence', ...
              'Training data has %d frame columns for %d sequences - expected a fixed number of frames per sequence.', ...
              size(train_data_red,2), numel(train_labels));
      end
      train_class_frame = cv_expand_labels(train_labels, frpseq);

      % train classifier on training features
      if Globals.DISPLAY, fprintf('    Training classifier'); end
      if Globals.svm_ova
        classifier = learn_classification_ova(train_data_red, train_class_frame, Globals);
      else
        classifier = learn_classification_npairs(train_data_red, train_class_frame, Globals);
      end
      if Globals.DISPLAY, fprintf('DONE\n'); end

      %--------------------------------------
      %             TESTING
      %--------------------------------------
      if Globals.DISPLAY, fprintf('    Testing classifier...'); end

      % extract features from test data
      if Globals.features_fromfile
        if Globals.DISPLAY, fprintf('read features from file...'); end
        [test_form, test_flow, framesperseq] = read_features(data, u, 'test', Globals.concat_frames);
      else
        if Globals.DISPLAY, fprintf('extracting descriptors...'); end
        test_form = extract_pca_descriptors(test_data, pca_form, Globals, 'form');
        test_flow = extract_pca_descriptors(test_data, pca_flow, Globals, 'flow');
        if Globals.features_tofile
          write_features(data, u, test_form, test_flow, framesperseq, 'test', Globals.concat_frames);
        end
      end

      test_data_red = cv_fuse_features(test_form, test_flow, Globals.relative_weight);

      % per-frame ground truth, using the frames-per-sequence found in training
      test_class_frame = cv_expand_labels(test_labels, frpseq);

      % test classifier on test data (classifiers take one sample per row)
      if Globals.DISPLAY, fprintf('classifying test data...'); end
      if Globals.svm_ova
        [result, curr_strength] = classify_ova(classifier, test_data_red');
      else
        result = classify_npairs(classifier, test_data_red');
      end
      if Globals.DISPLAY, fprintf('DONE\n'); end

      % consensus for sequence-level tests
      result_seq = majority_vote(result, frpseq);

      %--------------------------------------
      %         TEST STATISTICS
      %--------------------------------------

      % frame level: compare to ground truth
      n_good = nnz(result(:) == test_class_frame(:));
      n_bad = nnz(result(:) ~= test_class_frame(:));
      rate_fold = 100*n_good/numel(test_class_frame);
      fprintf('    %d correct / %d incorrect classifications...FRAME    classification rate = %4.1f...DONE\n', ...
              n_good, n_bad, rate_fold)

      confusion_matrix_curr = cv_confusion(result, test_class_frame, Globals.classes);
      confusion_matrix = confusion_matrix + confusion_matrix_curr;
      confusion_details(conf_det_ct,:) = cv_class_rates(confusion_matrix_curr);

      % sequence level: compare to ground truth
      n_good_seq = nnz(result_seq(:) == test_labels(:));
      n_bad_seq = nnz(result_seq(:) ~= test_labels(:));
      rate_seq_fold = 100*n_good_seq/numel(test_labels);
      fprintf('    %d correct / %d incorrect classifications...SEQUENCE classification rate = %4.1f...DONE\n', ...
              n_good_seq, n_bad_seq, rate_seq_fold)

      confusion_matrix_curr_seq = cv_confusion(result_seq, test_labels, Globals.classes);
      confusion_matrix_seq = confusion_matrix_seq + confusion_matrix_curr_seq;
      confusion_details_seq(conf_det_ct,:) = cv_class_rates(confusion_matrix_curr_seq);

      conf_det_ct = conf_det_ct + 1;

    end % for u = 1:Globals.crossval_folds

  end % for q = 1:Globals.crossval_iterations

  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  %          display results
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  % averages over all folds and classes; NaN entries (class absent from a
  % fold's test set) are skipped so they do not poison the summary
  rate = cv_mean_std(confusion_details(:))';
  rate_seq = cv_mean_std(confusion_details_seq(:))';

  % output results
  fprintf('\n');
  fprintf('CROSS-VALIDATION result: overall FRAME    classification rate = %4.1f +/- %4.1f\n', rate(1), rate(2));
  fprintf('CROSS-VALIDATION result: overall SEQUENCE classification rate = %4.1f +/- %4.1f\n', rate_seq(1), rate_seq(2));
  fprintf('\n');

  % display confusion matrix
  CONFUSION_MATRIX = confusion_matrix,

  % display per-class results (row 1: mean, row 2: std)
  CONFUSION_DETAILS = cv_mean_std(confusion_details),

  % display confusion matrix
  CONFUSION_MATRIX_SEQ = confusion_matrix_seq,

  % display per-class results (row 1: mean, row 2: std)
  CONFUSION_DETAILS_SEQ = cv_mean_std(confusion_details_seq),


function fused = cv_fuse_features(form, flow, w)
% CV_FUSE_FEATURES  Combine form and flow feature matrices.
%   W <= 0 returns FORM only, W >= 1 returns FLOW only; otherwise FORM is
%   scaled by (1-W), FLOW by W, and the two are stacked vertically (they
%   must have the same number of columns).
  if w <= 0
    fused = form;
  elseif w >= 1
    fused = flow;
  else
    fused = [(1-w)*form ; w*flow];
  end


function frame_labels = cv_expand_labels(seq_labels, n)
% CV_EXPAND_LABELS  Repeat each sequence label N times consecutively.
%   Returns the row vector [l1 ... l1 l2 ... l2 ...] of length
%   N*numel(SEQ_LABELS), regardless of the orientation of SEQ_LABELS.
  frame_labels = reshape(repmat(seq_labels(:)', n, 1), 1, []);


function cm = cv_confusion(estimated, truth, nclasses)
% CV_CONFUSION  NCLASSES x NCLASSES confusion counts.
%   cm(i,j) is the number of samples with estimated class i and true class j.
%   Labels must be integers in 1..NCLASSES; ESTIMATED and TRUTH must have
%   the same number of elements.
  cm = accumarray([estimated(:) truth(:)], 1, [nclasses nclasses]);


function rates = cv_class_rates(cm)
% CV_CLASS_RATES  Per-class recognition rate in percent, as a row vector.
%   rates(j) = 100*cm(j,j)/sum(cm(:,j)); NaN when class j has no samples.
  rates = 100*diag(cm)'./sum(cm,1);


function stats = cv_mean_std(x)
% CV_MEAN_STD  Column-wise [mean; std] of X, ignoring NaN entries.
%   Returns a 2 x size(X,2) matrix; columns without any non-NaN entry give
%   NaN. Written without 'omitnan' to stay compatible with older MATLAB.
  stats = nan(2, size(x,2));
  for c = 1:size(x,2)
    v = x(~isnan(x(:,c)), c);
    if ~isempty(v)
      stats(:,c) = [mean(v); std(v)];
    end
  end