// www.pudn.com > weka.rar > EvaluationClient.java, change:2002-09-03,size:17907b


/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    EvaluationClient.java
 *    Copyright (C) 2002 Dave Musicant, Sebastian Celis
 *
 */

package weka.classifiers;

import java.io.*;
import java.net.*;
import java.util.*;
import weka.core.*;
import weka.estimators.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/** Used to classify each fold by whether or not it has been completed. */
interface Status
{
    /** The fold has not been completed yet. */
    int NOT_DONE = 0;

    /** The fold has been completed. */
    int DONE = 1;
}

/**
 * Codes the client sends to the server to say whether it is being run
 * from the command line or from the graphical interface.
 */
interface Connection
{
    /** The client is being run from the command line. */
    int CV_CLI = 0;

    /** The client is being run from the graphical interface. */
    int CV_GUI = 1;
}

/**
 * Class for running cross-validation over multiple machines. <p>
 *
 * This is the class that will send the necessary data over to
 * other machines and then wait for those machines to send the results
 * back.  The machine running the client evaluates folds as well while
 * it is waiting. <p>
 *
 * Folds are handed out round-robin to every worker (one thread per server
 * plus one local thread).  A fold that is still being worked on may be
 * handed out again to another worker; only the first result that arrives
 * for a fold is aggregated into the final evaluation. <p>
 *
 * Locking: the fold status array and the shared evaluation object are only
 * updated while holding the lock on <code>status</code> and then the lock on
 * <code>evaluation</code>, always in that order. <p>
 *
 * NOTE(review): the one classifier instance is trained in place by the
 * local worker thread while the connection threads serialize that same
 * object for the servers.  Confirm that this sharing is intended. <p>
 *
 * ------------------------------------------------------------------- <p>
 *
 * Example usage from within an application:
 * <code> <pre>
 * int numFolds = ... the number of folds to do
 * Instances data = ... the data we are examining
 * Classifier classifier = ... the classifier we are running
 * StringBuffer otherComputers = ... will end up holding the names of the
 *                                   computers that the client ends up
 *                                   getting data from
 * Evaluation evaluation = ... the evaluation object that will end up holding
 *                             all of the final data attained from this
 *                             cross-validation
 *
 * EvaluationClient ec = new EvaluationClient(numFolds, data, classifier,
 *                                            otherComputers, evaluation);
 * ec.start();
 * </pre> </code>
 *
 *
 * @author Dave Musicant (dmusican@mathcs.carleton.edu)
 * @author Sebastian Celis (celiss@mathcs.carleton.edu)
 */
public class EvaluationClient
{
    /** The port the client will connect on (-1 until it is configured). */
    private int port;

    /**
     * A linked list of Strings containing the domain name of each server
     * that the client will attempt to connect to.
     */
    private LinkedList computers;

    /** The data to be cross-validated */
    private Instances data;

    /** The status of each fold (whether it is DONE or NOT_DONE). */
    private int status[];

    /** The number of folds to be done in the cross-validation */
    private int numFolds;

    /** The index of the fold that was handed out most recently. */
    private int lastIndexSent;

    /** The classifier used in the cross-validation. */
    private Classifier classifier;

    /**
     * The evaluation object that will eventually contain all of the
     * information gathered during cross-validation.
     */
    private Evaluation evaluation;

    /**
     * Holds the name of all of the computers that the client received
     * data from.
     */
    private StringBuffer otherComputers;

    /**
     * The number of worker threads launched by start() that have not
     * exited yet.  Only read and written while holding the lock on
     * <code>evaluation</code>.
     */
    private int activeWorkers;

    /**
     * Initializes all of the EvaluationClient's main variables.
     *
     * @param numFolds the number of folds that need to be calculated
     * @param data the data to run the cross-validation on
     * @param classifier the classifier to use for the cross-validation
     * @param otherComputers will eventually hold all of the names of the
     * computers that this client received data from
     * @param evaluation the object that will eventually hold the data
     * gathered from the cross-validation
     */
    public EvaluationClient(int numFolds, Instances data,
                            Classifier classifier,
                            StringBuffer otherComputers,
                            Evaluation evaluation)
    {
        // Set class variables
        this.numFolds = numFolds;
        this.classifier = classifier;
        this.evaluation = evaluation;
        this.otherComputers = otherComputers;
        port = -1;
        computers = new LinkedList();

        // determineIndex() advances before it tests, so starting on the
        // last fold makes fold 0 the first one handed out
        lastIndexSent = numFolds - 1;
        status = new int[numFolds];
        for(int i = 0; i < numFolds; i++)
            status[i] = Status.NOT_DONE;

        // Work on a copy of the data, stratified if the class is nominal
        this.data = new Instances(data);
        if(this.data.classAttribute().isNominal())
            this.data.stratify(numFolds);
    }

    /**
     * Configures the client by looking at the file ~/.weka-parallel.
     * This file tells the client the addresses of the computers to connect
     * to and what port to connect on.  The first line must have the form
     * <code>PORT=number</code>; every following non-blank line is taken as
     * the name of one server. <p>
     *
     * If the home directory cannot be determined or the file cannot be
     * read, the stack trace is printed and the method returns, leaving the
     * client with whatever it managed to read up to that point.
     *
     * @exception Exception if the configuration file has a syntax error or
     * if the file does not exist
     */
    public void configure() throws Exception
    {
        File inFile;
        BufferedReader fileInStream;
        String line;

        try
        {
            String home = System.getProperty("user.home");
            if(home == null)
                throw new IllegalStateException(
                    "The user.home system property is not set.");

            // File(parent, child) supplies the platform's own separator,
            // so no check of the operating system is needed
            inFile = new File(home, ".weka-parallel");
        }
        catch (Exception e)
        {
            e.printStackTrace();
            return;
        }

        // Check to make sure the file actually exists
        if(!inFile.exists())
        {
            throw new Exception("Config file does not exist.");
        }

        try
        {
            // Places a wrapper around the input stream from the file
            fileInStream = new BufferedReader(new FileReader(inFile));
        }
        catch(Exception e)
        {
            e.printStackTrace();
            return;
        }

        try
        {
            // The first line holds the port number
            try
            {
                line = fileInStream.readLine();
            }
            catch(IOException e)
            {
                e.printStackTrace();
                return;
            }
            port = parsePort(line);

            // start() calls configure() itself, so forget the servers read
            // by an earlier call; otherwise each would be contacted twice
            computers.clear();

            try
            {
                // Every remaining non-blank line names one server
                line = fileInStream.readLine();
                while(line != null)
                {
                    line = line.trim();
                    if(line.length() > 0)
                        computers.add(line);
                    line = fileInStream.readLine();
                }
            }
            catch(IOException e)
            {
                e.printStackTrace();
            }
        }
        finally
        {
            // Release the file handle on every path out of this method
            try
            {
                fileInStream.close();
            }
            catch(IOException e)
            {
                // Nothing useful can be done if closing fails
            }
        }
    }

    /**
     * Extracts the port number from the first line of the config file.
     *
     * @param line the first line of the file, or null if the file is empty
     * @return the port number given after <code>PORT=</code>
     * @exception Exception if the line is missing, does not start with
     * <code>PORT=</code>, or does not hold an integer
     */
    private static int parsePort(String line) throws Exception
    {
        String trimmed = (line == null) ? "" : line.trim();

        if(trimmed.startsWith("PORT="))
        {
            try
            {
                return Integer.parseInt(trimmed.substring(5).trim());
            }
            catch(NumberFormatException e)
            {
                // Reported below as a syntax error
            }
        }

        throw new Exception("Syntax error in config file for "+
                            "parallelization.");
    }

    /**
     * Determines which fold should start to be calculated next.
     * A Round-Robin algorithm is used to try and maximize efficiency.
     * A fold that has been handed out but is not DONE yet can be returned
     * again.
     *
     * @return the index of the fold that should be started next or
     * -1 if all of the folds have already been calculated
     */
    private int determineIndex()
    {
        synchronized(status)
        {
            for(int i = 0; i < numFolds; i++)
            {
                lastIndexSent = (lastIndexSent + 1) % numFolds;
                if(status[lastIndexSent] != Status.DONE)
                {
                    return lastIndexSent;
                }
            }
        }

        return -1;
    }

    /**
     * Tells whether every fold has been marked DONE.  The caller must hold
     * the lock on <code>evaluation</code>.  The lock on <code>status</code>
     * is deliberately not taken here, because the workers acquire the two
     * locks in the order status, evaluation.
     *
     * @return true if no fold is left to calculate
     */
    private boolean allFoldsDone()
    {
        for(int i = 0; i < numFolds; i++)
        {
            if(status[i] != Status.DONE)
                return false;
        }
        return true;
    }

    /**
     * Adds the result of one fold to the final evaluation, unless a result
     * for that fold has already been recorded, and marks the fold as DONE.
     *
     * @param index the fold the result belongs to
     * @param result the evaluation of that fold
     * @exception Exception if the result cannot be aggregated
     */
    private void recordResult(int index, Evaluation result) throws Exception
    {
        synchronized(status)
        {
            synchronized(evaluation)
            {
                if(status[index] != Status.DONE)
                    evaluation.aggregate(result);
                status[index] = Status.DONE;
            }
        }
    }

    /**
     * Called by every worker thread just before it exits, whether it
     * finished normally or gave up, so that start() can notice when no
     * worker is left.
     */
    private void workerFinished()
    {
        synchronized(evaluation)
        {
            activeWorkers--;
            evaluation.notifyAll();
        }
    }

    /**
     * Runs the distributed client.  Starts one thread per configured server
     * plus one local thread and blocks until every fold has been evaluated.
     * If the waiting thread is interrupted, the stack trace is printed, the
     * interrupt flag is restored and the method returns.
     *
     * @exception Exception any problems during configuration, or if every
     * worker stopped before all of the folds were evaluated
     */
    public void start() throws Exception
    {
        ConnectionService [] connect;
        ClientSideComputations cs = new ClientSideComputations();
        int created = 0;
        boolean finished;

        // Configure the client
        configure();

        // Create a thread for each server; a server whose thread cannot be
        // created is skipped
        connect = new ConnectionService[computers.size()];
        for(int i = 0; i < connect.length; i++)
        {
            try
            {
                connect[i] = new ConnectionService((String)computers.get(i));
                created++;
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
        }

        // Count the workers before any of them can exit
        synchronized(evaluation)
        {
            activeWorkers = created + 1;
        }

        for(int i = 0; i < connect.length; i++)
        {
            if(connect[i] != null)
                connect[i].start();
        }

        // Let the computer running the client also do some work to speed
        // things up
        cs.start();

        try
        {
            synchronized(evaluation)
            {
                // Waiting in a loop on the actual condition protects
                // against notifications sent before this thread got here
                // and against spurious wakeups.  The loop also ends when
                // no worker is left, so a failure cannot block forever.
                while(!allFoldsDone() && activeWorkers > 0)
                    evaluation.wait();
                finished = allFoldsDone();
            }
        }
        catch (InterruptedException e)
        {
            e.printStackTrace();
            // Keep the interrupt visible to the caller
            Thread.currentThread().interrupt();
            return;
        }

        if(!finished)
        {
            throw new Exception("Cross-validation stopped before every "+
                                "fold was evaluated.");
        }
    }

    /**
     * Performs part of the cross-fold validation on the computer running
     * Weka.  This yields a speed increase since this computer will
     * be calculating data at the same time the other computers are doing
     * so.
     *
     * @author Sebastian Celis (celiss@mathcs.carleton.edu)
     * @author Dave Musicant (dmusican@mathcs.carleton.edu)
     */
    private class ClientSideComputations extends Thread
    {

        /**
         * This is the actual thread which performs the necessary
         * computations.  It keeps taking folds until none is left.  If a
         * fold fails, the stack trace is printed and the thread stops.
         */
        public void run()
        {
            try
            {
                // Determine which section we should train on
                int index = determineIndex();

                while(index != -1)
                {
                    // Do the cross validation
                    Evaluation anEvaluation = new Evaluation(data, null);
                    Instances train = data.trainCV(numFolds, index);
                    anEvaluation.setPriors(train);
                    classifier.buildClassifier(train);
                    Instances test = data.testCV(numFolds, index);
                    anEvaluation.evaluateModel(classifier, test);

                    recordResult(index, anEvaluation);

                    // Determine which section we should train on
                    index = determineIndex();
                }

                // Every fold is done, so wake up the waiting main thread
                synchronized(evaluation)
                {
                    evaluation.notifyAll();
                }
            }
            catch (Exception e)
            {
                e.printStackTrace();
            }
            finally
            {
                workerFinished();
            }
        }
    }

    /**
     * One instance of this class will connect to a single computer
     * and send it the information necessary for computing cross-validations.
     * This class instructs the other computer concerning which folds it
     * should work on.  When the other computer finishes a fold, it sends
     * the results back to this class, which then compiles that information.
     *
     * @author Dave Musicant (dmusican@mathcs.carleton.edu)
     * @author Sebastian Celis (celiss@mathcs.carleton.edu)
     */
    private class ConnectionService extends Thread
    {
        /**
         * The address of the computer to connect to, or null if the
         * server's name could not be resolved.
         */
        private InetAddress serverAddress;

        /** The current fold to work on */
        private int index;

        /**
         * Gets the internet address of the computer to connect to by
         * looking up the server's name.
         *
         * @param serverName the name of computer to connect to
         */
        ConnectionService(String serverName)
        {
            try
            {
                this.serverAddress = InetAddress.getByName(serverName);
            }
            catch(UnknownHostException e)
            {
                // The server does not exist; run() will exit right away
                this.serverAddress = null;
            }
        }

        /**
         * The actual thread that will connect to the other computer,
         * send data to it, and receive data from it.  Any problem with
         * the connection silently ends this thread; folds it did not
         * finish stay NOT_DONE and are picked up by the other workers.
         */
        public void run()
        {
            try
            {
                if(serverAddress != null)
                    communicate();
            }
            catch(Exception e)
            {
                // We could not connect, or there was a problem with this
                // connection.  One possibility is that the server ran out
                // of memory.  Thus, we simply give up on this server.
            }
            finally
            {
                workerFinished();
            }
        }

        /**
         * Connects to the server, sends it the data and the classifier
         * once, and then keeps sending fold numbers and reading back the
         * evaluation of each fold until no fold is left.
         *
         * @exception Exception if connecting fails or anything goes wrong
         * while talking to the server
         */
        private void communicate() throws Exception
        {
            /** Used to send bytes to the server */
            BufferedOutputStream bos;

            /** Used to send objects to the server */
            ObjectOutputStream oos;

            /** Used to send ints to the server */
            DataOutputStream dos;

            /** Used to receive objects from the server */
            ObjectInputStream ois;

            /** The result the server sent back for one fold */
            Evaluation newEvaluation;

            /** The connection to the server */
            Socket sock = new Socket(serverAddress, port);

            try
            {
                index = determineIndex();
                if(index == -1)
                    return;

                // Create the output streams.  Both write into the same
                // buffered stream, so the order of the writes and flushes
                // below determines the byte order on the wire.
                // NOTE(review): presumably the server reads in exactly
                // this order; do not reorder without checking the server.
                bos = new BufferedOutputStream(sock.getOutputStream());
                oos = new ObjectOutputStream(bos);
                dos = new DataOutputStream(bos);

                // Write all necessary info to the server
                dos.writeInt(Connection.CV_CLI);
                oos.writeObject(data);
                oos.writeObject(classifier);
                oos.flush();
                dos.writeInt(numFolds);
                dos.writeInt(index);
                dos.flush();

                // Create the input stream and get the results
                // back from the server
                ois = new ObjectInputStream(
                          new BufferedInputStream(
                              sock.getInputStream()));
                newEvaluation = (Evaluation)ois.readObject();
                recordResult(index, newEvaluation);

                // This server has delivered a result, so remember it
                synchronized(otherComputers)
                {
                    otherComputers.append(serverAddress+"\n");
                }

                // Determine which section we should train on
                index = determineIndex();

                while(index != -1)
                {
                    // Write this number to the server
                    dos.writeInt(index);
                    dos.flush();

                    // Get the results from the server
                    newEvaluation = (Evaluation)ois.readObject();
                    recordResult(index, newEvaluation);

                    // Determine which section we should train on
                    index = determineIndex();
                }

                // Every fold is done, so wake up the waiting main thread
                synchronized(evaluation)
                {
                    evaluation.notifyAll();
                }
            }
            finally
            {
                // Close the connection no matter what exceptions were
                // thrown
                sock.close();
            }
        }
    }
}