/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    CVParameterSelection.java
 *    Copyright (C) 1999 Len Trigg
 *
 */

package weka.classifiers;

import java.io.*;
import java.util.*;
import weka.core.*;

/**
 * Class for performing parameter selection by cross-validation for any
 * classifier. For more information, see<p>
 *
 * R. Kohavi (1995). <i>Wrappers for Performance
 * Enhancement and Oblivious Decision Graphs</i>. PhD
 * Thesis. Department of Computer Science, Stanford University. <p>
 *
 * Valid options are:<p>
 *
 * -D <br>
 * Turn on debugging output.<p>
 *
 * -W classname <br>
 * Specify the full class name of classifier to perform cross-validation
 * selection on.<p>
 *
 * -X num <br>
 * Number of folds used for cross validation (default 10). <p>
 *
 * -S seed <br>
 * Random number seed (default 1).<p>
 *
 * -P "N 1 5 10" <br>
 * Sets an optimisation parameter for the classifier with name -N,
 * lower bound 1, upper bound 5, and 10 optimisation steps.
 * The upper bound may be the character 'A' or 'I' to substitute 
 * the number of attributes or instances in the training data,
 * respectively.
 * This parameter may be supplied more than once to optimise over
 * several classifier options simultaneously. <p>
 *
 * Options after -- are passed to the designated sub-classifier. <p>
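 *
 * A minimal programmatic sketch (the base classifier and the parameter
 * range shown are illustrative assumptions, not defaults of this
 * class):<p>
 *
 * <pre>
 * Instances train = ...;             // training data, loaded elsewhere
 * CVParameterSelection cvps = new CVParameterSelection();
 * cvps.setClassifier(new weka.classifiers.IBk());
 * cvps.addCVParameter("K 1 10 10");  // search -K from 1 to 10 in 10 steps
 * cvps.buildClassifier(train);
 * System.out.println(cvps.toSummaryString());
 * </pre><p>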
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.12 $ 
 */
public class CVParameterSelection extends Classifier 
  implements OptionHandler, Summarizable {

  /*
   * A data structure to hold values associated with a single
   * cross-validation search parameter
   */
  protected class CVParameter {

    /**  Char used to identify the option of interest */
    private char m_ParamChar;    

    /**  Lower bound for the CV search */
    private double m_Lower;      

    /**  Upper bound for the CV search */
    private double m_Upper;      

    /**  Increment during the search */
    private double m_Steps;      

    /**  The parameter value with the best performance */
    private double m_ParamValue; 

    /**  True if the parameter should be added at the end of the argument list */
    private boolean m_AddAtEnd;  

    /**  True if the parameter should be rounded to an integer */
    private boolean m_RoundParam;

    /**
     * Constructs a CVParameter.
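     * For example, the specifier "C 0.1 2 10" (using a hypothetical -C
     * option) requests a search from 0.1 to 2 in 10 steps; a trailing
     * "R" token requests rounding of each value to an integer.
     *
     * @param param the parameter specifier string
     * @exception Exception if the specifier cannot be parsed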
     */
    public CVParameter(String param) throws Exception {
     
      // Tokenize the string into its parts
      StreamTokenizer st = new StreamTokenizer(new StringReader(param));
      if (st.nextToken() != StreamTokenizer.TT_WORD) {
	throw new Exception("CVParameter " + param 
			    + ": Character parameter identifier expected");
      }
      m_ParamChar = st.sval.charAt(0);
      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
	throw new Exception("CVParameter " + param 
			    + ": Numeric lower bound expected");
      }
      m_Lower = st.nval;
      if (st.nextToken() == StreamTokenizer.TT_NUMBER) {
	m_Upper = st.nval;
	if (m_Upper < m_Lower) {
	  throw new Exception("CVParameter " + param
			      + ": Upper bound is less than lower bound");
	}
      } else if (st.ttype == StreamTokenizer.TT_WORD) {
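	// The placeholders 'A' (number of attributes) and 'I' (number of
	// instances) are encoded as offsets below the lower bound and
	// resolved against the training data at build time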
	if (st.sval.toUpperCase().charAt(0) == 'A') {
	  m_Upper = m_Lower - 1;
	} else if (st.sval.toUpperCase().charAt(0) == 'I') {
	  m_Upper = m_Lower - 2;
	} else {
	  throw new Exception("CVParameter " + param 
	      + ": Upper bound must be numeric, or 'A' or 'I'");
	}
      } else {
	throw new Exception("CVParameter " + param 
	      + ": Upper bound must be numeric, or 'A' or 'I'");
      }
      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
	throw new Exception("CVParameter " + param 
			    + ": Numeric number of steps expected");
      }
      m_Steps = st.nval;
      if (st.nextToken() == StreamTokenizer.TT_WORD) {
	if (st.sval.toUpperCase().charAt(0) == 'R') {
	  m_RoundParam = true;
	}
      }
    }

    /**
     * Returns a CVParameter as a string.
     */
    public String toString() {

      String result = m_ParamChar + " " + m_Lower + " ";
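      // Decode the placeholder encoding used by the constructor: an upper
      // bound one below the lower bound means 'A', two below means 'I'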
      switch ((int)(m_Lower - m_Upper + 0.5)) {
      case 1:
	result += "A";
	break;
      case 2:
	result += "I";
	break;
      default:
	result += m_Upper;
	break;
      }
      result += " " + m_Steps;
      if (m_RoundParam) {
	result += " R";
      }
      return result;
    }
  }

  /** The generated base classifier */
  protected Classifier m_Classifier = new weka.classifiers.ZeroR();

  /**
   * The base classifier options (not including those being set
   * by cross-validation)
   */
  protected String [] m_ClassifierOptions;

  /** The set of all classifier options as determined by cross-validation */
  protected String [] m_BestClassifierOptions;

  /** The cross-validated performance of the best options */
  protected double m_BestPerformance;

  /** The set of parameters to cross-validate over */
  protected FastVector m_CVParams = new FastVector();

  /** The number of attributes in the data */
  protected int m_NumAttributes;

  /** The number of instances in a training fold */
  protected int m_TrainFoldSize;
  
  /** The number of folds used in cross-validation */
  protected int m_NumFolds = 10;

  /** Random number seed */
  protected int m_Seed = 1;

  /** Debugging mode, gives extra output if true */
  protected boolean m_Debug;

  /**
   * Create the options array to pass to the classifier. The parameter
   * values and positions are taken from m_ClassifierOptions and
   * m_CVParams.
   *
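   * For example (hypothetical values): a single search parameter -C whose
   * current value is 0.5, combined with static options {"-M", "2"},
   * produces the array {"-C", "0.5", "-M", "2"}.
   *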
   * @return the options array
   */
  protected String [] createOptions() {
    
    String [] options = new String [m_ClassifierOptions.length 
				   + 2 * m_CVParams.size()];
    int start = 0, end = options.length;

    // Add the cross-validation parameters and their values
    for (int i = 0; i < m_CVParams.size(); i++) {
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
      double paramValue = cvParam.m_ParamValue;
      if (cvParam.m_RoundParam) {
	paramValue = (double)((int) (paramValue + 0.5));
      }
      if (cvParam.m_AddAtEnd) {
	options[--end] = "" + Utils.doubleToString(paramValue, 4);
	options[--end] = "-" + cvParam.m_ParamChar;
      } else {
	options[start++] = "-" + cvParam.m_ParamChar;
	options[start++] = "" + Utils.doubleToString(paramValue, 4);
      }
    }
    // Add the static parameters
    System.arraycopy(m_ClassifierOptions, 0,
		     options, start,
		     m_ClassifierOptions.length);

    return options;
  }

  /**
   * Finds the best parameter combination (recursing over each parameter
   * being optimised).
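   * The recursion enumerates every combination of the supplied search
   * parameters; for example, one -P option with 5 steps and another with
   * 10 steps give 5 * 10 = 50 option settings, each evaluated by a full
   * cross-validation run.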
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data on which the cross-validation is performed
   * @exception Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData)
    throws Exception {

    if (depth < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth);

      double upper;
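      // Resolve the 'A'/'I' placeholder upper bounds against the training
      // data (number of attributes / size of one training fold)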
      switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
      case 1:
	upper = m_NumAttributes;
	break;
      case 2:
	upper = m_TrainFoldSize;
	break;
      default:
	upper = cvParam.m_Upper;
	break;
      }
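      // Try m_Steps evenly spaced values from the lower bound up to the
      // resolved upper bound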
      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
      for(cvParam.m_ParamValue = cvParam.m_Lower; 
	  cvParam.m_ParamValue <= upper; 
	  cvParam.m_ParamValue += increment) {
	findParamsByCrossValidation(depth + 1, trainData);
      }
    } else {
      
      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String [] options = createOptions();
      if (m_Debug) {
	System.err.print("Setting options for " 
			 + m_Classifier.getClass().getName() + ":");
	for (int i = 0; i < options.length; i++) {
	  System.err.print(" " + options[i]);
	}
	System.err.println("");
      }
      ((OptionHandler)m_Classifier).setOptions(options);
      for (int j = 0; j < m_NumFolds; j++) {
	Instances train = trainData.trainCV(m_NumFolds, j);
	Instances test = trainData.testCV(m_NumFolds, j);
	m_Classifier.buildClassifier(train);
	evaluation.setPriors(train);
	evaluation.evaluateModel(m_Classifier, test);
      }
      double error = evaluation.errorRate();
      if (m_Debug) {
	System.err.println("Cross-validated error rate: " 
			   + Utils.doubleToString(error, 6, 4));
      }
      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
	
	m_BestPerformance = error;
	m_BestClassifierOptions = createOptions();
      }
    }
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(5);

    newVector.addElement(new Option(
	      "\tTurn on debugging output.",
	      "D", 0, "-D"));
    newVector.addElement(new Option(
	      "\tFull name of classifier to perform parameter selection on.\n"
	      + "\teg: weka.classifiers.NaiveBayes",
	      "W", 1, "-W <classifier class name>"));
    newVector.addElement(new Option(
	      "\tNumber of folds used for cross validation (default 10).",
	      "X", 1, "-X <number of folds>"));
    newVector.addElement(new Option(
	      "\tClassifier parameter options.\n"
	      + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n"
	      + "\tclassifier with name -N, with lower bound 1, upper bound\n"
	      + "\t5, and 10 optimisation steps. The upper bound may be the\n"
	      + "\tcharacter 'A' or 'I' to substitute the number of\n"
	      + "\tattributes or instances in the training data,\n"
	      + "\trespectively. This parameter may be supplied more than\n"
	      + "\tonce to optimise over several classifier options\n"
	      + "\tsimultaneously.",
	      "P", 1, "-P <classifier parameter>"));
    newVector.addElement(new Option(
	      "\tSets the random number seed (default 1).",
	      "S", 1, "-S <random number seed>"));

    if ((m_Classifier != null) &&
	(m_Classifier instanceof OptionHandler)) {
      newVector.addElement(new Option("",
	        "", 0,
		"\nOptions specific to sub-classifier "
	        + m_Classifier.getClass().getName()
		+ ":\n(use -- to signal start of sub-classifier options)"));
      Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
      while (enu.hasMoreElements()) {
	newVector.addElement(enu.nextElement());
      }
    }
    return newVector.elements();
  }


  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -D <br>
   * Turn on debugging output.<p>
   *
   * -W classname <br>
   * Specify the full class name of classifier to perform cross-validation
   * selection on.<p>
   *
   * -X num <br>
   * Number of folds used for cross validation (default 10). <p>
   *
   * -S seed <br>
   * Random number seed (default 1).<p>
   *
   * -P "N 1 5 10" <br>
   * Sets an optimisation parameter for the classifier with name -N,
   * lower bound 1, upper bound 5, and 10 optimisation steps.
   * The upper bound may be the character 'A' or 'I' to substitute 
   * the number of attributes or instances in the training data,
   * respectively.
   * This parameter may be supplied more than once to optimise over
   * several classifier options simultaneously. <p>
   *
   * Options after -- are passed to the designated sub-classifier. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    
    setDebug(Utils.getFlag('D', options));

    String foldsString = Utils.getOption('X', options);
    if (foldsString.length() != 0) {
      setNumFolds(Integer.parseInt(foldsString));
    } else {
      setNumFolds(10);
    }

    String randomString = Utils.getOption('S', options);
    if (randomString.length() != 0) {
      setSeed(Integer.parseInt(randomString));
    } else {
      setSeed(1);
    }

    String cvParam;
    m_CVParams = new FastVector();
    do {
      cvParam = Utils.getOption('P', options);
      if (cvParam.length() != 0) {
	addCVParameter(cvParam);
      }
    } while (cvParam.length() != 0);
    if (m_CVParams.size() == 0) {
      throw new Exception("A parameter specifier must be given with"
			  + " the -P option.");
    }

    String classifierName = Utils.getOption('W', options);
    if (classifierName.length() == 0) {
      throw new Exception("A classifier must be specified with"
			  + " the -W option.");
    }
    setClassifier(Classifier.forName(classifierName,
				     Utils.partitionOptions(options)));
    if (!(m_Classifier instanceof OptionHandler)) {
      throw new Exception("Base classifier must accept options");
    }
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] classifierOptions = new String [0];
    if ((m_Classifier != null) && 
	(m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler)m_Classifier).getOptions();
    }

    int current = 0;
    String [] options = new String [classifierOptions.length + 8];
    if (m_CVParams != null) {
      options = new String [m_CVParams.size() * 2 + options.length];
      for (int i = 0; i < m_CVParams.size(); i++) {
	options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
      }
    }

    if (getDebug()) {
      options[current++] = "-D";
    }
    options[current++] = "-X"; options[current++] = "" + getNumFolds();
    options[current++] = "-S"; options[current++] = "" + getSeed();

    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "--";

    System.arraycopy(classifierOptions, 0, options, current, 
		     classifierOptions.length);
    current += classifierOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data 
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();
    if (trainData.numInstances() == 0) {
      throw new Exception("No training instances without missing class.");
    }
    if (trainData.numInstances() < m_NumFolds) {
      throw new Exception("Number of training instances smaller than number of folds.");
    }
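    // Shuffle the data and, for nominal class attributes, stratify it so
    // that each fold is representative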
    trainData.randomize(new Random(m_Seed));
    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
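    // -99 is a sentinel meaning that no performance has been recorded yet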
    m_BestPerformance = -99;
    m_BestClassifierOptions = null;
    m_NumAttributes = trainData.numAttributes();
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();
    
    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
		      m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData);

    String [] options = (String [])m_BestClassifierOptions.clone();
    ((OptionHandler)m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }


  /**
   * Predicts the class value for the given test instance.
   *
   * @param instance the instance to be classified
   * @return the predicted class value
   * @exception Exception if an error occurred during the prediction
   */
  public double classifyInstance(Instance instance) throws Exception {
    
    return m_Classifier.classifyInstance(instance);
  }

  /**
   * Sets the seed for random number generation.
   *
   * @param seed the random number seed
   */
  public void setSeed(int seed) {
    
    m_Seed = seed;
  }

  /**
   * Gets the random number seed.
   * 
   * @return the random number seed
   */
  public int getSeed() {

    return m_Seed;
  }

  /**
   * Adds a scheme parameter to the list of parameters to be set
   * by cross-validation
   *
   * @param cvParam the string representation of a scheme parameter. The
   * format is: <br>
   * param_char lower_bound upper_bound number_of_steps <br>
   * eg to search a parameter -P from 1 to 10 in 10 steps: <br>
   * P 1 10 10 <br>
   * @exception Exception if the parameter specifier is of the wrong format
   */
  public void addCVParameter(String cvParam) throws Exception {

    CVParameter newCV = new CVParameter(cvParam);
    
    m_CVParams.addElement(newCV);
  }

  /**
   * Gets the scheme parameter with the given index.
   *
   * @param index the index of the parameter to retrieve
   * @return the string representation of the parameter, or the empty
   * string if the index is out of range
   */
  public String getCVParameter(int index) {

    if (m_CVParams.size() <= index) {
      return "";
    }
    return ((CVParameter)m_CVParams.elementAt(index)).toString();
  }

  /**
   * Sets debugging mode
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {

    m_Debug = debug;
  }

  /**
   * Gets whether debugging is turned on
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Get the number of folds used for cross-validation.
   *
   * @return the number of folds used for cross-validation.
   */
  public int getNumFolds() {
    
    return m_NumFolds;
  }
  
  /**
   * Set the number of folds used for cross-validation.
   *
   * @param newNumFolds the number of folds used for cross-validation.
   */
  public void setNumFolds(int newNumFolds) {
    
    m_NumFolds = newNumFolds;
  }

  /**
   * Sets the classifier to perform the parameter selection on.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {

    m_Classifier = newClassifier;
  }

  /**
   * Gets the classifier used as the base classifier.
   *
   * @return the base classifier
   */
  public Classifier getClassifier() {

    return m_Classifier;
  }

 
  /**
   * Returns description of the cross-validated classifier.
   *
   * @return description of the cross-validated classifier as a string
   */
  public String toString() {

    if (m_BestClassifierOptions == null)
      return "CVParameterSelection: No model built yet.";

    String result = "Cross-validated Parameter selection.\n"
      + "Classifier: " + m_Classifier.getClass().getName() + "\n";
    try {
      for (int i = 0; i < m_CVParams.size(); i++) {
	CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
	result += "Cross-validation Parameter: '-" 
	  + cvParam.m_ParamChar + "'"
	  + " ranged from " + cvParam.m_Lower 
	  + " to ";
	switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
	case 1:
	  result += m_NumAttributes;
	  break;
	case 2:
	  result += m_TrainFoldSize;
	  break;
	default:
	  result += cvParam.m_Upper;
	  break;
	}
	result += " with " + cvParam.m_Steps + " steps\n";
      }
    } catch (Exception ex) {
      result += ex.getMessage();
    }
    result += "Classifier Options: "
      + Utils.joinOptions(m_BestClassifierOptions)
      + "\n\n" + m_Classifier.toString();
    return result;
  }

  /**
   * Returns a one-line summary of the best option settings found.
   *
   * @return a summary of the selected classifier options as a string
   */
  public String toSummaryString() {

    String result = "Selected values: "
      + Utils.joinOptions(m_BestClassifierOptions);
    return result + '\n';
  }
  
  /**
   * Main method for testing this class.
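   *
   * A typical invocation (dataset name and parameter range are
   * illustrative only):
   * <pre>
   * java weka.classifiers.CVParameterSelection -t train.arff \
   *     -W weka.classifiers.IBk -P "K 1 10 10"
   * </pre>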
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new CVParameterSelection(), 
						  argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}