// Source: www.pudn.com > zuixindeID3.zip > DecisionTree.java


package id3; 
import java.io.*; 
import shared.*; 
import shared.Error; 
 
/** DecisionTrees are RootedCatGraphs in which each node other than the root
 * has exactly one parent.  The root has no parents.
 * @author James Louis 5/29/2001   Ported to Java. 
 * @author Eric Eros 4/18/96 Added delete_subtree 
 * @author Ronny Kohavi  4/16/96 Added treeviz display 
 * @author Richard Long 9/02/93 Initial revision (.c,.h) 
 */ 
 
public class DecisionTree extends RootedCatGraph { 
    /** Indicates if this DecisionTree is sparsely populated. 
     */ 
    boolean isGraphSparse = false; 
     
    /** Constructor. 
     */ 
    public DecisionTree() { 
        super(false); 
    } 
     
    /** Constructor. 
     * @param grph The CGraph object to be used to maintain the DecisionTree. 
     */ 
    public DecisionTree(CGraph grph) { 
        super(grph, false); 
    } 
     
    /** Distribute instances to a subtree. This function is used whenever we 
     * replace a node with its child.  The distributions of the child include 
     * only the instances there while if we replace, we must update all the 
     * counts. This function is also the backfitting function for decision trees. 
     * @param subtree The subtree over which Instances will be distributed. 
     * @param il InstanceList to be distributed over the DecisionTree. 
     * @param pruningFactor The amount of pruning to be done on this tree. 
     * @param pessimisticErrors Number of errors estimated for the new distribution. 
     * @param ldType Leaf Distribution Type. 
     * @param leafDistParameter The distribution of instances that reach this leaf node. 
     * @param parentWeightDist The weight distribution of the parent node. 
     */ 
    public void distribute_instances(Node subtree, 
    InstanceList il, 
    double pruningFactor, 
    DoubleRef pessimisticErrors, 
    int ldType, 			//TDDTInducer.LeafDistType 
    double leafDistParameter, 
    double[] parentWeightDist) { 
        distribute_instances(subtree,il,pruningFactor,pessimisticErrors,ldType, 
        leafDistParameter,parentWeightDist,false); 
    } 
     
    /** Distribute instances to a subtree. This function is used whenever we 
     * replace a node with its child.  The distributions of the child include 
     * only the instances there while if we replace, we must update all the 
     * counts. This function is also the backfitting function for decision trees. 
     * @param subtree The subtree over which Instances will be distributed. 
     * @param il InstanceList to be distributed over the DecisionTree. 
     * @param pruningFactor The amount of pruning to be done on this tree. 
     * @param pessimisticErrors Number of errors estimated for the new distribution. 
     * @param ldType Leaf Distribution Type. 
     * @param leafDistParameter The distribution of instances that reach this leaf node. 
     * @param parentWeightDist The weight distribution of the parent node. 
     * @param saveOriginalDistr TRUE if the original instance distribution should be preserved, FALSE otherwise. 
     */ 
    public void distribute_instances(Node subtree, 
    InstanceList il, 
    double pruningFactor, 
    DoubleRef pessimisticErrors, 
    int ldType, 			//TDDTInducer.LeafDistType 
    double leafDistParameter, 
    double[] parentWeightDist, 
    boolean saveOriginalDistr) { 
        //   DBGSLOW(check_node_in_graph(subtree, TRUE)); 
        NodeCategorizer splitCat = ((NodeInfo)cGraph.inf(subtree)).get_categorizer(); 
         
        logOptions.LOG(3, "Distributing instances: " + il + '\n' + "categorizer is " 
        +splitCat.description()+'\n'); 
         
        splitCat.distribute_instances(il, pruningFactor, pessimisticErrors, ldType, 
        leafDistParameter, parentWeightDist, 
        saveOriginalDistr); 
    } 
     
    /** Removes a subtree recursively. This is used to                           
* a) Remove a node and all nodes below it if the second parameter is NULL,
* b) Remove just the nodes under a particular node, if both parameters are * the same (the named node remains in the graph).
* c) Replace the subtree rooted at the first parameter with the subtree * rooted at the second parameter, if the two parameters are not equal, and * are non-null.

* We allow replacing node X with a child of node X (or a node related * through comman ancestors) or, in general, replacing a subtree with * another subtree. In both cases, we disconnect the parents of the new node * node from the new node.

* We do not allow replacing node X with an ancester (parent, etc.) of * node X, as this would make no sense.

* The method is as follows:

* 1) If 'node' is to be deleted, delete the edges connecting it to its * parents.
* 2) If 'node' is to be replaced by 'newNode', delete the edges connecting * 'newNode' to its parents.
* 3) Delete the edges from 'node' to all its children.
* 4) If 'node' is to be deleted, since it's now completely disconnected, * delete it.
* 5) If 'node' is to be replaced by 'newNode',
* 5a) Connect all of 'newNode's children to 'node' (adding edges),
* 5b) Delete all the edges from 'newNode' to its children.
* 5c) Since 'newNode' is now completely disconnected, delete it.
* 6) For all the children discovered in step 3, recurse to delete them. * @param node Node to be replaced. * @param newNode New Node to be used for replacement. */ public void delete_subtree(Node node, Node newNode) { if (node == null) Error.fatalErr("DecisionTree::delete_subtree: node is NULL"); // Delete a subtree, given the starting node. The second parameter // NULL means the top-most node is deleted--it is set NULL for all // recursive calls from here, so that all children are deleted. If the // second parameter is non-NULL, it needs to point to a node in the // same cGraph as the first. boolean deleteNode = (newNode == null); boolean replaceWithSelf = (node == newNode); boolean replaceWithOther = !deleteNode && !replaceWithSelf; // We can extend this routine to support the new node being the root, // but it seems very strange to do so, since we usually delete the // newNode. // One would need to set the new node to be the root. For safely // it's better to abort in this case that can't happen right now. // Note that replacing the root with a child is OK and the root // will be the new node because it's the categorizer that's replace, // and the root reference remains valid if (!replaceWithSelf && newNode == get_root()) Error.fatalErr("DecisionTree::delete_subtree: new node cannot be root"); if (deleteNode) logOptions.LOG(5, " 1. Deleting the node " + node + '\n'); else logOptions.LOG(5, " 2. Removing the subtree from node " + node + '\n'); if (!deleteNode && !replaceWithSelf) logOptions.LOG(5, " 3. Replacing it with the node " + newNode + '\n'); // Ensure specified node(s) in graph (check_node_in_graph(node, TRUE) // aborts when node isn't in graph.) // DBGSLOW(check_node_in_graph(node, TRUE)); if (replaceWithOther) { check_node_in_graph(newNode, true); // 'node' is to be replaced with 'newNode'. This is only legal when // 'newNode' is NOT an ancester of 'node'. // The following function is only called once, as newNode is NULL in // all recursive calls. 
// DBG(if (check_node_reachable(newNode, node)) // err << "DecisionTree::delete_subtree: attempt to replace a " // "node with its own ancestor" << fatal_error); } Edge iterEdge; Edge oldEdge; if (deleteNode) { // If 'node' is to be deleted, remove the edges from its parent(s). iterEdge = node.first_in_edge(); while (iterEdge != null) { oldEdge = iterEdge; iterEdge = oldEdge.in_succ(oldEdge); // oldEdge.entry() = null; cGraph.del_edge(oldEdge); } MLJ.ASSERT(node.indeg() == 0,"DecisionTree.delete_subtree: node.indeg() != 0"); } // 'node' is to be replaced with 'newNode'. That means that the // current incoming edges to 'newNode' are extraneous, and need // to be removed. if (replaceWithOther) { iterEdge = newNode.first_in_edge(); while (iterEdge != null) { oldEdge = iterEdge; iterEdge = oldEdge.in_succ(oldEdge); Node parentNode = oldEdge.source(); logOptions.LOG(5, " 4. Removing parent " + parentNode + " from " + newNode + " (deleting edge " + oldEdge + ")" + '\n'); // cGraph[oldEdge] = null; cGraph.del_edge(oldEdge); } MLJ.ASSERT(newNode.indeg() == 0,"DecisionTree.delete_subtree: newNode.indeg() != 0"); } // Disconnect 'node' (the old node) from all outgoing edges. Save references // to the targets of these edges so we can (effectively) follow them, int numChildren = node.outdeg(); int currentChild = 0; MLJ.ASSERT(numChildren >= 0,"DecisionTree.delete_subtree: numChildren < 0"); // Declared before the loop because we use it after the if // for replaceWithSelf. Node[] children = new Node[numChildren]; if (numChildren > 0) { // We're not a leaf, we've got children to delete. // a) Copy the (references to the) children nodes. // b) Delete the edges. iterEdge = node.first_adj_edge(); while (iterEdge != null) { logOptions.LOG(5, " 7. Disconnecting edge " + iterEdge + " from node " + node + " to its child " + '\n' + iterEdge.target() + '\n'); oldEdge = iterEdge; iterEdge = oldEdge.adj_succ(); // Save the other node attached to this edge. 
Node childNode = oldEdge.target(); children[currentChild++] = childNode; // Delete the connection. // cGraph[oldEdge] = null; cGraph.del_edge(oldEdge); } } MLJ.ASSERT(currentChild == numChildren,"DecisionTree.delete_subtree: currentChild != numChildren"); MLJ.ASSERT(node.outdeg() == 0,"DecisionTree.delete_subtree: node.outdeg() != 0"); // Delete the node. if (deleteNode) { logOptions.LOG(5, " 8. Deleting the node " + node + '\n'); MLJ.ASSERT(node.indeg() == 0,"DecisionTree.delete_subtree: node.indeg() != 0"); MLJ.ASSERT(node.outdeg() == 0,"DecisionTree.delete_subtree: node.outdeg() != 0"); // cGraph[node] = null; cGraph.del_node(node); } else if (replaceWithOther) { // Delete 'newNode' after moving all its children over to 'node', // and assigning its categorizer to 'node'. cGraph.assign_categorizer(node, newNode); iterEdge = newNode.first_adj_edge(); while (iterEdge != null) { oldEdge = iterEdge; iterEdge = oldEdge.adj_succ(); Node childNode = oldEdge.target(); AugCategory aug = new AugCategory(cGraph.edge_info(oldEdge).num(), cGraph.edge_info(oldEdge).description()); cGraph.new_edge(node, childNode, aug); // cGraph[oldEdge] = null; cGraph.del_edge(oldEdge); } MLJ.ASSERT(newNode.indeg() == 0,"DecisionTree.delete_subtree: newNode.indeg() != 0"); MLJ.ASSERT(newNode.outdeg() == 0,"DecisionTree.delete_subtree: newNode.outdeg() != 0"); // cGraph.entry(newNode) = null; cGraph.del_node(newNode); // Re-assign the levels of each node in the subtree we just moved. if (get_graph().node_info(node).level() != CGraph.DEFAULT_LEVEL) assign_subtree_levels(node, get_graph().node_info(node).level()); } // Recurse--all children must delete themselves. for (currentChild = 0; currentChild < numChildren; currentChild++) { logOptions.LOG(5, " 9. 
Now to delete child " + currentChild + " of " + numChildren + " children" + '\n'); delete_subtree(children[currentChild], null); } } /** Creates NodeInfo objects for every Node in the branch starting at the * given Node and assigns each NodeInfo its appropriate level in the tree. * @param node The base node where level assignment will start. * @param baseLevel The initial level for the base Node. */ public void assign_subtree_levels(Node node, int baseLevel) { // MLJ.ASSERT(baseLevel != DEFAULT_LEVEL); NodeInfo rootInfo = cGraph.node_info(node); logOptions.LOG(5, "Replacing level "+rootInfo.level()+" with "+baseLevel+'\n'); cGraph.node_info(node).set_level(baseLevel); Edge iterEdge; Edge oldEdge; iterEdge = node.first_adj_edge(); int nextLevel; // if (get_categorizer(node).class_id() == CLASS_MULTI_SPLIT_CATEGORIZER) // nextLevel = baseLevel; // else nextLevel = baseLevel + 1; while (iterEdge != null) { oldEdge = iterEdge; iterEdge = oldEdge.adj_succ(); Node childNode = oldEdge.target(); assign_subtree_levels(childNode, nextLevel); } } /*************************************************************************** Display this DecisionTree. @param @param @param @param *************************************************************************** public void display(boolean hasNodeLosses, boolean hasLossMatrix, Writer stream, DisplayPref dp) { stream.write(display(hasNodeLosses, hasLossMatrix, dp)); } /*************************************************************************** Display this DecisionTree. 
@param @param @param *************************************************************************** public String display(boolean hasNodeLosses, boolean hasLossMatrix, DisplayPref dp) { String return_value = new String(); // Note that if the display is XStream, our virtual function gets it // if (stream.output_type() == XStream || // dp.preference_type() != DisplayPref::TreeVizDisplay) // RootedCatGraph.display(hasNodeLosses, hasLossMatrix, stream, dp); else { String dataName = stream.description() + ".data"; MLCOStream data(dataName); convertToTreeVizFormat(stream, data, dp, hasNodeLosses, hasLossMatrix); } } */ /* /*************************************************************************** Displays the DecisionTree in TreeVizFormat. @param @param @param @param @param *************************************************************************** public void convertToTreeVizFormat(Writer conf, Writer data, DisplayPref displayPref, boolean hasNodeLosses, boolean hasLossMatrix) throws IOException { Node rootNode = get_root(true); NodeCategorizer cat = get_categorizer(rootNode); Schema schema = cat.get_schema(); NominalAttrInfo nai = schema.nominal_label_info(); int numLabelValues = nai.num_values(); MLJ.ASSERT(numLabelValues >= 1,"DecisionTree::convertToTreeVizFormat:numLabelValues < 1"); // Avoid log of 1, which is a scale of zero, and causes division // by zero. 
double scale = MLJ.log_bin(Math.max(numLabelValues, 2)); int[] permLabels = schema.sort_labels(); // permuted labels boolean dispBackfitDisks = displayPref.typecast_to_treeViz().get_display_backfit_disks(); write_subtree(get_log_options(), scale, data, permLabels, Globals.EMPTY_STRING, Globals.EMPTY_STRING, this, rootNode, dispBackfitDisks, hasNodeLosses, hasLossMatrix); String protectedLabelName = new String(Globals.SINGLE_QUOTE + MLJ.protect(nai.name(),"`\\") + Globals.SINGLE_QUOTE); conf.write(minesetVersionStr + "\n"); conf.write("# MLC++ generated file for MineSet Tree Visualizer.\n" + "input {\n" + "\t file \"" + data.description() + "\";\n" + "\t options backslash on;\n" + "\t key string " + protectedLabelName + " {\n"); for (int i = 0; i < numLabelValues; i++) { conf.write("\t\t " + nai.get_value(permLabels[i]).quote()); if (i != numLabelValues - 1) conf.write(","); conf.write("\n"); } permLabels = null; conf.write("\t };\n" + "\t expression `Node label`[] separator ':';\n" + "\t string `Test attribute`;\n" + "\t string `Test value`;\n" + "\t float `Subtree weight` [" + protectedLabelName + "] separator ',' ;\n" + "\t float Percent [" + protectedLabelName + "] separator ',' ;\n"); if (dispBackfitDisks) conf.write("\t float OriginalDist [" + protectedLabelName + "] separator ',' ;\n"); conf.write("\t float Purity;\n"); if (hasNodeLosses) { conf.write("\t float `Test-set subtree weight`;\n"); if (hasLossMatrix) conf.write("\t float `Test-set loss`;\n" + "\t float `Mean loss std-dev`;\n"); else conf.write("\t float `Test-set error`;\n" + "\t float `Mean err std-dev`;\n"); } conf.write("}\n\n"); conf.write("hierarchy {\n" + "\t levels `Node label`;\n" + "\t key `Subtree weight`;\n" + "\t aggregate base {\n" + "\t\t sum `Subtree weight`;\n"); if (dispBackfitDisks) conf.write("\t\t sum `OriginalDist`;\n"); conf.write("\t\t any Purity;\n" + "\t\t any `Test attribute`;\n" + "\t\t any `Test value`;\n"); if (hasNodeLosses) { conf.write("\t\t any `Test-set subtree 
weight`;\n"); if (hasLossMatrix) conf.write("\t\t any `Test-set loss`;\n" + "\t\t any `Mean loss std-dev`;\n"); else conf.write("\t\t any `Test-set error`;\n" + "\t\t any `Mean err std-dev`;\n"); } conf.write("\t }\n" + "\t options organization same;\n" + "}\n"); // Pick the midpoint entropy color to be 3/4 versus 1/4 for two class probs. // This just makes the color scale much better then 50, which requires // 89% versus 11% to be the middle color. double[] typicalMix = new double[2]; typicalMix[0] = 3; typicalMix[1] = 1; DoubleRef midPointEnt = new DoubleRef(100 - Entropy.entropy(typicalMix)*100 / scale); MLJ.clamp_to_range(midPointEnt, 0, 100, "DecisionTree::convertToTreeVizFormat: mid-point does " + "not clamp to range [0-100]"); MLJ.ASSERT(schema.num_label_values() > 0,"DecisionTree::" + "convertToTreeVizFormat:schema.num_label_values() <= 0"); // Even though nulls are never used, we want to distinguish // them in case somebody changes anything. They're therefore hidden. conf.write("view hierarchy landscape {\n" + "\t height `Subtree weight`, normalize, max 5.0;\n"); if (dispBackfitDisks) conf.write("\t disk height `OriginalDist`;\n"); conf.write("\t base height max 2.0;\n" + "\t base label `Test attribute`;\n" + "\t line label `Test value`;\n" + "\t color key;\n"); // "\t base color legend label \"Purity\";\n" // "\t base color Purity, " // "colors \"red\" \"yellow\" \"green\"" // ", scale 0 " << midPointEnt << " 100, legend on;\n" // "\t base color legend \"impure\" \"mixed\" \"pure\";\n" if (hasNodeLosses) { double min = 0; double max = 0; loss_min_max(this, min, max); if (max - min < 0.01) max += 0.01; // Avoid cases where both are zero and we rely // on a treeviz tiebreaker (happens in mushroom). 
NodeLoss rootLoss = get_categorizer(rootNode).get_loss(); double medColor = suggest_mid(min, max, rootLoss.totalWeight, rootLoss.totalLoss); if (!hasLossMatrix) { min *= 100; max *= 100; medColor *=100; } if (hasLossMatrix) conf.write("\t base color legend label \"Test-set loss\";\n" + "\t base color `Test-set loss`, "); else conf.write("\t base color legend label \"Test-set error\";\n" + "\t base color `Test-set error`, "); conf.write("colors \"green\" \"yellow\" \"red\"" + ", scale " + min + " " + medColor + " " + max + ", legend on;\n" + "\t base color legend \"low (" + MLJ.numberToString(min,2) + ")\" \"medium (" + MLJ.numberToString(medColor,2) + ")\" \"high (" + MLJ.numberToString(max,2) + ")\";\n"); } conf.write("\t options rows 1;\n" + "\t options root label \"\";\n" + "\t options initial depth 4;\n" // Don't show bar labels, so the level of details is far + "\t options lod bar label 10000;\n" + "\t options zero outline;\n" + "\t options null hidden;\n"); conf.write("\t base message \"Subtree weight:%.2f, "); String lossMetric = hasLossMatrix ? "loss" : "error"; String shortLossMetric = hasLossMatrix ? "loss" : "err"; if (hasNodeLosses) conf.write("test-set " + lossMetric + ":%.2f+-%.2f, " + " test-set weight:%.2f, "); if (dispBackfitDisks) conf.write("training-set weight: %.2f, "); conf.write("purity:%.2f\", `Subtree weight`, "); if (hasNodeLosses) conf.write("`Test-set " + lossMetric + "`, " + "`Mean " + shortLossMetric + " std-dev`, " + "`Test-set subtree weight`, "); if (dispBackfitDisks) conf.write("`OriginalDist`, "); conf.write("Purity;\n"); conf.write("\t message \"Subtree weight for label value:%.2f, percent:%.2f"); if (dispBackfitDisks) conf.write(", training-set weight:%.2f"); conf.write("\", `Subtree weight`, Percent"); if (dispBackfitDisks) conf.write(", `OriginalDist`"); conf.write(";\n}\n"); } */ }