In a previous post, I explored how one might apply decision trees to solve a complex problem. This post will explore the code necessary to implement that decision tree. If you would like a full copy of the source code, it is available here in zip format.

Entropy.java – In Entropy.java, we are concerned with calculating the amount of entropy, that is, the amount of uncertainty or randomness in a particular variable. For example, consider a classifier with two classes, YES and NO. If a particular variable or attribute, say x, has three training examples of class YES and three training examples of class NO (for a total of six), the entropy would be 1. This is because the two classes are equally represented, which is the most mixed up a two-class set can get. Likewise, if x had all six training examples of a particular class, say YES, then entropy would be 0 because this particular variable would be pure, thus making it a leaf node in our decision tree.
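
To make the two cases above concrete, here is a minimal sketch that computes entropy directly from class counts. The class name EntropyExample and its entropy method are hypothetical helpers for illustration only and are not part of the downloadable source.

```java
// Hypothetical helper, not part of the post's source: entropy from raw class counts.
class EntropyExample {
    // Shannon entropy in bits for a list of per-class counts.
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        if (total == 0) {
            return 0.0; // an empty set carries no uncertainty
        }
        double h = 0.0;
        for (int c : counts) {
            if (c > 0) { // skip empty classes; 0 * log(0) is taken as 0
                double p = c / (double) total;
                h -= p * (Math.log(p) / Math.log(2));
            }
        }
        return h;
    }

    public static void main(String[] args) {
        System.out.println("3 YES / 3 NO: " + entropy(3, 3)); // evenly mixed set
        System.out.println("6 YES / 0 NO: " + entropy(6, 0)); // pure set
    }
}
```

The evenly mixed set should come out at 1 bit and the pure set at 0, matching the description above.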

Entropy may be calculated in the following way:


import java.util.ArrayList;

public class Entropy {	
	public static double calculateEntropy(ArrayList<Record> data) {
		double entropy = 0;
		
		if(data.size() == 0) {
			// nothing to do
			return 0;
		}
		
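		// loop over each class value of the target attribute (0 = No, 1 = Yes)
		for(int i = 0; i < Hw1.setSize("PlayTennis"); i++) {
			int count = 0;
			for(int j = 0; j < data.size(); j++) {
				Record record = data.get(j);
				
				// index 4 is the PlayTennis attribute, the last one added in FileReader
				if(record.getAttributes().get(4).getValue() == i) {
					count++;
				}
			}
				
			// skip classes with no records: 0 * log(0) is treated as 0
			double probability = count / (double)data.size();
			if(count > 0) {
				entropy += -probability * (Math.log(probability) / Math.log(2));
			}
		}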
		
		return entropy;
	}
	
	public static double calculateGain(double rootEntropy, ArrayList<Double> subEntropies, ArrayList<Integer> setSizes, int data) {
		double gain = rootEntropy; 
		
		for(int i = 0; i < subEntropies.size(); i++) {
			gain += -((setSizes.get(i) / (double)data) * subEntropies.get(i));
		}
		
		return gain;
	}
}

Tree.java – This class contains all our code for building our decision tree. Note that at each level, we choose the attribute that presents the best gain for that node. The gain is simply the expected reduction in the entropy of X achieved by learning the state of the random variable A. In the decision tree setting, gain is the mutual information between X and A, which can also be expressed as an expected Kullback-Leibler divergence. Gain can be calculated in the following way:
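
For reference, the quantity that calculateGain computes can be written as follows. The notation is an assumption borrowed from common textbook usage rather than from this post: X is the set of records at the node, A is the candidate attribute, X_v is the subset of X where A takes the value v, and p_c is the fraction of records in class c.

```latex
H(X) = -\sum_{c} p_c \log_2 p_c
\qquad
\mathrm{Gain}(X, A) = H(X) - \sum_{v \in \mathrm{Values}(A)} \frac{|X_v|}{|X|} \, H(X_v)
```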

Notice that gain is calculated as a function of all the values of the attribute.

import java.io.*;
import java.util.*;

public class Tree {
	public Node buildTree(ArrayList<Record> records, Node root, LearningSet learningSet) {
		int bestAttribute = -1;
		double bestGain = 0;
		root.setEntropy(Entropy.calculateEntropy(root.getData()));
		
		if(root.getEntropy() == 0) {
			return root;
		}
		
		for(int i = 0; i < Hw1.NUM_ATTRS - 2; i++) {
			if(!Hw1.isAttributeUsed(i)) {
				double entropy = 0;
				ArrayList<Double> entropies = new ArrayList<Double>();
				ArrayList<Integer> setSizes = new ArrayList<Integer>();
				
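				// iterate over the values this attribute can actually take
				for(int j = 0; j < Hw1.setSize(Hw1.attrMap.get(i)); j++) {
					ArrayList<Record> subset = subset(root, i, j);
					
					// add size and entropy together so the two lists stay aligned
					if(subset.size() != 0) {
						setSizes.add(subset.size());
						entropy = Entropy.calculateEntropy(subset);
						entropies.add(entropy);
					}
				}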
				
				double gain = Entropy.calculateGain(root.getEntropy(), entropies, setSizes, root.getData().size());
				
				if(gain > bestGain) {
					bestAttribute = i;
					bestGain = gain;
				}
			}
		}
		if(bestAttribute != -1) {
			int setSize = Hw1.setSize(Hw1.attrMap.get(bestAttribute));
			root.setTestAttribute(new DiscreteAttribute(Hw1.attrMap.get(bestAttribute), 0));
			root.children = new Node[setSize];
			root.setUsed(true);
			Hw1.usedAttributes.add(bestAttribute);
			
			for (int j = 0; j< setSize; j++) {
				root.children[j] = new Node();
				root.children[j].setParent(root);
				root.children[j].setData(subset(root, bestAttribute, j));
				root.children[j].getTestAttribute().setName(Hw1.getLeafNames(bestAttribute, j));
				root.children[j].getTestAttribute().setValue(j);
			}

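			for (int j = 0; j < setSize; j++) {
				buildTree(root.children[j].getData(), root.children[j], learningSet);
			}

			// release the attribute so that sibling branches may still split on it
			Hw1.usedAttributes.remove(Integer.valueOf(bestAttribute));

			root.setData(null);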
		}
		else {
			return root;
		}
		
		return root;
	}
	
	public ArrayList<Record> subset(Node root, int attr, int value) {
		ArrayList<Record> subset = new ArrayList<Record>();
		
		for(int i = 0; i < root.getData().size(); i++) {
			Record record = root.getData().get(i);
			
			if(record.getAttributes().get(attr).getValue() == value) {
				subset.add(record);
			}
		}
		return subset;
	}
	
	public double calculateSurrogates(ArrayList<Record> records) {
		return 0;
	}
}
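
The attribute selection step in buildTree can be seen in isolation with a small, self-contained sketch. It does not use the Record or Node classes above. The class GainExample and its methods are hypothetical, and the table is my own integer encoding of the 14 examples from Mitchell's book, using the same value numbering as DiscreteAttribute.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, self-contained sketch; it does not use the post's Record/Node classes.
class GainExample {
    static final String[] NAMES = {"Outlook", "Temperature", "Humidity", "Wind"};
    static final int[] NUM_VALUES = {3, 3, 2, 2}; // distinct values per attribute
    static final int TARGET = 4;                  // column holding PlayTennis

    // Columns: {outlook, temperature, humidity, wind, playTennis}
    static final int[][] DATA = {
        {0, 0, 0, 0, 0}, // D1  Sunny    Hot  High   Weak   No
        {0, 0, 0, 1, 0}, // D2  Sunny    Hot  High   Strong No
        {1, 0, 0, 0, 1}, // D3  Overcast Hot  High   Weak   Yes
        {2, 1, 0, 0, 1}, // D4  Rain     Mild High   Weak   Yes
        {2, 2, 1, 0, 1}, // D5  Rain     Cool Normal Weak   Yes
        {2, 2, 1, 1, 0}, // D6  Rain     Cool Normal Strong No
        {1, 2, 1, 1, 1}, // D7  Overcast Cool Normal Strong Yes
        {0, 1, 0, 0, 0}, // D8  Sunny    Mild High   Weak   No
        {0, 2, 1, 0, 1}, // D9  Sunny    Cool Normal Weak   Yes
        {2, 1, 1, 0, 1}, // D10 Rain     Mild Normal Weak   Yes
        {0, 1, 1, 1, 1}, // D11 Sunny    Mild Normal Strong Yes
        {1, 1, 0, 1, 1}, // D12 Overcast Mild High   Strong Yes
        {1, 0, 1, 0, 1}, // D13 Overcast Hot  Normal Weak   Yes
        {2, 1, 0, 1, 0}, // D14 Rain     Mild High   Strong No
    };

    // Two-class entropy (in bits) of the PlayTennis column.
    static double entropy(List<int[]> rows) {
        if (rows.isEmpty()) {
            return 0.0;
        }
        int yes = 0;
        for (int[] r : rows) {
            yes += r[TARGET];
        }
        double h = 0.0;
        for (int count : new int[] {yes, rows.size() - yes}) {
            if (count > 0) {
                double p = count / (double) rows.size();
                h -= p * (Math.log(p) / Math.log(2));
            }
        }
        return h;
    }

    // Parent entropy minus the size-weighted entropy of each value's subset.
    static double gain(List<int[]> rows, int attr) {
        double g = entropy(rows);
        for (int v = 0; v < NUM_VALUES[attr]; v++) {
            List<int[]> subset = new ArrayList<>();
            for (int[] r : rows) {
                if (r[attr] == v) {
                    subset.add(r);
                }
            }
            g -= (subset.size() / (double) rows.size()) * entropy(subset);
        }
        return g;
    }

    // Index of the attribute with the highest positive gain, or -1 if none.
    static int bestAttribute(List<int[]> rows) {
        int best = -1;
        double bestGain = 0.0;
        for (int a = 0; a < NAMES.length; a++) {
            double g = gain(rows, a);
            if (g > bestGain) {
                best = a;
                bestGain = g;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<int[]> rows = Arrays.asList(DATA);
        for (int a = 0; a < NAMES.length; a++) {
            System.out.println(NAMES[a] + " gain: " + gain(rows, a));
        }
        System.out.println("Best attribute: " + NAMES[bestAttribute(rows)]);
    }
}
```

On this table the sketch selects Outlook, which matches the root of the tree in Mitchell's example.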

DiscreteAttribute.java


public class DiscreteAttribute extends Attribute {
	public static final int Sunny = 0;
	public static final int Overcast = 1;
	public static final int Rain = 2;

	public static final int Hot = 0;
	public static final int Mild = 1;
	public static final int Cool = 2;
	
	public static final int High = 0;
	public static final int Normal = 1;
	
	public static final int Weak = 0;
	public static final int Strong = 1;
	
	public static final int PlayNo = 0;
	public static final int PlayYes = 1;
	
	enum PlayTennis {
		No,
		Yes
	}
	
	enum Wind {
		Weak,
		Strong
	}
	
	enum Humidity {
		High,
		Normal
	}
	
	enum Temp {
		Hot,
		Mild,
		Cool
	}
	
	enum Outlook {
		Sunny,
		Overcast,
		Rain
	}

	public DiscreteAttribute(String name, double value) {
		super(name, value);
	}

	public DiscreteAttribute(String name, String value) {
		super(name, value);
	}
}

Hw1.java – This is our main driver class.

import java.util.*;

public class Hw1 {
	public static int NUM_ATTRS = 6;
	public static ArrayList<String> attrMap;
	public static ArrayList<Integer> usedAttributes = new ArrayList<Integer>();

	public static void main(String[] args) {
		populateAttrMap();

		Tree t = new Tree();
		ArrayList<Record> records;
		LearningSet learningSet = new LearningSet();
		
		// read in all our data
		records = FileReader.buildRecords();
		
		Node root = new Node();
		
		for(Record record : records) {
			root.getData().add(record);
		}
		
		t.buildTree(records, root, learningSet);
		traverseTree(records.get(12), root);
		return;
	}
	
	public static void traverseTree(Record r, Node root) {
		Node node = root;
		
		// walk down the tree until we reach a leaf
		while(node.children != null) {
			int nodeValue = 0;
			for(int i = 0; i < r.getAttributes().size(); i++) {
				if(r.getAttributes().get(i).getName().equalsIgnoreCase(node.getTestAttribute().getName())) {
					nodeValue = (int) r.getAttributes().get(i).getValue();
					break;
				}
			}
			// buildTree stores the child for attribute value j at children[j]
			node = node.children[nodeValue];
		}
		
		// a leaf keeps its training records, so predict the majority class among them
		int yes = 0;
		int no = 0;
		for(Record record : node.getData()) {
			if(record.getAttributes().get(4).getValue() == DiscreteAttribute.PlayYes) {
				yes++;
			}
			else {
				no++;
			}
		}
		
		System.out.print("Prediction for Play Tennis: ");
		if(yes == 0 && no == 0) {
			System.out.println("Unknown (no training examples reached this leaf)");
		}
		else if(yes >= no) {
			System.out.println("Yes");
		}
		else {
			System.out.println("No");
		}
	}
	
	public static boolean isAttributeUsed(int attribute) {
		return usedAttributes.contains(attribute);
	}
	
	public static int setSize(String set) {
		if(set.equalsIgnoreCase("Outlook")) {
			return 3;
		}
		else if(set.equalsIgnoreCase("Wind")) {
			return 2;
		}
		else if(set.equalsIgnoreCase("Temperature")) {
			return 3;
		}
		else if(set.equalsIgnoreCase("Humidity")) {
			return 2;
		}
		else if(set.equalsIgnoreCase("PlayTennis")) {
			return 2;
		}
		return 0;
	}
	
	public static String getLeafNames(int attributeNum, int valueNum) {
		if(attributeNum == 0) {
			if(valueNum == 0) {
				return "Sunny";
			}
			else if(valueNum == 1) {
				return "Overcast";
			}
			else if(valueNum == 2) {
				return "Rain";
			}
		}
		else if(attributeNum == 1) {
			if(valueNum == 0) {
				return "Hot";
			}
			else if(valueNum == 1) {
				return "Mild";
			}
			else if(valueNum == 2) {
				return "Cool";
			}
		}
		else if(attributeNum == 2) {
			if(valueNum == 0) {
				return "High";
			}
			else if(valueNum == 1) {
				return "Normal";
			}
		}
		else if(attributeNum == 3) {
			if(valueNum == 0) {
				return "Weak";
			}
			else if(valueNum == 1) {
				return "Strong";
			}
		}
		
		return null;
	}
	
	public static void populateAttrMap() {
		attrMap = new ArrayList<String>();
		attrMap.add("Outlook");
		attrMap.add("Temperature");
		attrMap.add("Humidity");
		attrMap.add("Wind");
		attrMap.add("PlayTennis");
	}
}
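
The if/else chains in setSize and getLeafNames repeat information that the enums in DiscreteAttribute already hold. As a possible alternative, here is a minimal sketch that derives both from an enum type. The class EnumLookupExample and its two methods are hypothetical, and the enums are redeclared here only so the example compiles on its own.

```java
// Hypothetical sketch: the enum names mirror those declared in DiscreteAttribute.
class EnumLookupExample {
    enum Outlook { Sunny, Overcast, Rain }
    enum Wind { Weak, Strong }

    // Number of distinct values an enum-backed attribute can take.
    static <E extends Enum<E>> int setSize(Class<E> type) {
        return type.getEnumConstants().length;
    }

    // Name of the value with the given ordinal, or null if it is out of range.
    static <E extends Enum<E>> String leafName(Class<E> type, int valueNum) {
        E[] values = type.getEnumConstants();
        if (valueNum < 0 || valueNum >= values.length) {
            return null;
        }
        return values[valueNum].name();
    }

    public static void main(String[] args) {
        System.out.println(setSize(Outlook.class));     // size of the Outlook set
        System.out.println(leafName(Outlook.class, 1)); // name of Outlook value 1
        System.out.println(leafName(Wind.class, 1));    // name of Wind value 1
    }
}
```

With this approach, adding a new attribute value means editing only the enum, at the cost of relying on the declaration order matching the integer constants.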

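Node.java – This class holds the information for each node in the tree: its parent, its children, the records that reach it, its entropy, and the attribute it tests.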


import java.util.*;

public class Node {
	private Node parent;
	public Node[] children;
	private ArrayList<Record> data;
	private double entropy;
	private boolean isUsed;
	private DiscreteAttribute testAttribute;

	public Node() {
		this.data = new ArrayList<Record>();
		setEntropy(0.0);
		setParent(null);
		setChildren(null);
		setUsed(false);
		setTestAttribute(new DiscreteAttribute("", 0));
	}

	public void setParent(Node parent) {
		this.parent = parent;
	}

	public Node getParent() {
		return parent;
	}

	public void setData(ArrayList<Record> data) {
		this.data = data;
	}

	public ArrayList<Record> getData() {
		return data;
	}

	public void setEntropy(double entropy) {
		this.entropy = entropy;
	}

	public double getEntropy() {
		return entropy;
	}

	public void setChildren(Node[] children) {
		this.children = children;
	}

	public Node[] getChildren() {
		return children;
	}

	public void setUsed(boolean isUsed) {
		this.isUsed = isUsed;
	}

	public boolean isUsed() {
		return isUsed;
	}

	public void setTestAttribute(DiscreteAttribute testAttribute) {
		this.testAttribute = testAttribute;
	}

	public DiscreteAttribute getTestAttribute() {
		return testAttribute;
	}
}

FileReader.java – The least interesting class in the code

import java.io.*;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class FileReader {
	public static final String PATH_TO_DATA_FILE = "playtennis.data";

    public static ArrayList<Record> buildRecords() {
		BufferedReader reader = null;
		DataInputStream dis = null;
		ArrayList<Record> records = new ArrayList<Record>();

        try { 
           File f = new File(PATH_TO_DATA_FILE);
           FileInputStream fis = new FileInputStream(f); 
           reader = new BufferedReader(new InputStreamReader(fis));
           
           // read the file one record (line) at a time
           String line;
           Record r = null;
           ArrayList<DiscreteAttribute> attributes;
           while ((line = reader.readLine()) != null) {
              StringTokenizer st = new StringTokenizer(line, ",");
              attributes = new ArrayList<DiscreteAttribute>();
              r = new Record();
              
              if(Hw1.NUM_ATTRS != st.countTokens()) {
            	  throw new Exception("Unknown number of attributes!");
              }
              	
			  @SuppressWarnings("unused")
			  String day = st.nextToken();
			  String outlook = st.nextToken();
			  String temperature = st.nextToken();
			  String humidity = st.nextToken();
			  String wind = st.nextToken();
			  String playTennis = st.nextToken();
			  
			  if(outlook.equalsIgnoreCase("overcast")) {
				  attributes.add(new DiscreteAttribute("Outlook", DiscreteAttribute.Overcast));
			  }
			  else if(outlook.equalsIgnoreCase("sunny")) {
				  attributes.add(new DiscreteAttribute("Outlook", DiscreteAttribute.Sunny));
			  }
			  else if(outlook.equalsIgnoreCase("rain")) {
				  attributes.add(new DiscreteAttribute("Outlook", DiscreteAttribute.Rain));
			  }
			  
			  if(temperature.equalsIgnoreCase("hot")) {
				  attributes.add(new DiscreteAttribute("Temperature", DiscreteAttribute.Hot));
			  }
			  else if(temperature.equalsIgnoreCase("mild")) {
				  attributes.add(new DiscreteAttribute("Temperature", DiscreteAttribute.Mild));
			  }
			  else if(temperature.equalsIgnoreCase("cool")) {
				  attributes.add(new DiscreteAttribute("Temperature", DiscreteAttribute.Cool));
			  }
			  
			  if(humidity.equalsIgnoreCase("high")) {
				  attributes.add(new DiscreteAttribute("Humidity", DiscreteAttribute.High));
			  }
			  else if(humidity.equalsIgnoreCase("normal")) {
				  attributes.add(new DiscreteAttribute("Humidity", DiscreteAttribute.Normal));
			  }
			  
			  if(wind.equalsIgnoreCase("weak")) {
				  attributes.add(new DiscreteAttribute("Wind", DiscreteAttribute.Weak));

			  }
			  else if(wind.equalsIgnoreCase("strong")) {
				  attributes.add(new DiscreteAttribute("Wind", DiscreteAttribute.Strong));

			  }
			  
			  if(playTennis.equalsIgnoreCase("no")) {
				  attributes.add(new DiscreteAttribute("PlayTennis", DiscreteAttribute.PlayNo));
			  }
			  else if(playTennis.equalsIgnoreCase("yes")) {
				  attributes.add(new DiscreteAttribute("PlayTennis", DiscreteAttribute.PlayYes));
			  }
			    		    
			  r.setAttributes(attributes);
			  records.add(r);
           }

        } 
        catch (IOException e) { 
           System.out.println("Uh oh, got an IOException error: " + e.getMessage()); 
        } 
        catch (Exception e) {
            System.out.println("Uh oh, got an Exception error: " + e.getMessage()); 
        }
        finally { 
           // close the reader that was actually opened above (dis is never assigned)
           if (reader != null) {
              try {
                 reader.close();
              } catch (IOException ioe) {
                 System.out.println("IOException error trying to close the file: " + ioe.getMessage()); 
              }
           }
        }
		return records;
	}
}

**EDIT:**

playtennis.data is a simple text file in which each line describes one learning episode together with its learned attribute, “play tennis.” I modeled my playtennis.data file on Tom Mitchell’s play tennis example in his book “Machine Learning.” Each line contains the attributes describing one episode: a day label, outlook (sunny, overcast, rain), temperature (hot, mild, cool), humidity (high, normal), wind (weak, strong) and play tennis (yes, no). Based on this information, one can trace the decision tree to derive the learned attribute. A sample decision tree is below. All you have to do is create a text file that describes each particular situation, one per line. Note that the parser above splits each line on commas and expects exactly six columns in that order, so make sure you match the columns in the text file to the parsing that occurs in the Java.
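
As a rough illustration of that layout, the sketch below splits one row the way the file is expected to be organized. The class DataFileExample is hypothetical, and the sample row, including the “D1” day label, is an assumption based on the first example in Mitchell's table rather than a line copied from my file.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the row layout that FileReader.buildRecords() expects.
class DataFileExample {
    static final List<String> COLUMNS =
        Arrays.asList("Day", "Outlook", "Temperature", "Humidity", "Wind", "PlayTennis");

    // Splits one comma-delimited row and checks that it has all six columns.
    static String[] parseRow(String line) {
        String[] fields = line.split(",");
        if (fields.length != COLUMNS.size()) {
            throw new IllegalArgumentException(
                "Expected " + COLUMNS.size() + " columns but found " + fields.length);
        }
        for (int i = 0; i < fields.length; i++) {
            fields[i] = fields[i].trim(); // tolerate spaces around the commas
        }
        return fields;
    }

    public static void main(String[] args) {
        // Assumed sample row; the day label format is a guess.
        String[] row = parseRow("D1,Sunny,Hot,High,Weak,No");
        for (int i = 0; i < row.length; i++) {
            System.out.println(COLUMNS.get(i) + " = " + row[i]);
        }
    }
}
```

Unlike this sketch, the StringTokenizer code in FileReader does not trim whitespace, so the real file should not have spaces around the commas.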

**EDIT 2:**

After so many requests, I have worked up a small playtennis.data file based on Tom Mitchell’s “Machine Learning” book. This follows his example data exactly and can be found here (rename the file to “playtennis.data” before running, as WordPress wouldn’t let me upload a .data file due to “security restrictions”). A small note of caution: the above code was put together very quickly for purposes of my own learning. I am not claiming that it is by any means complete or fault tolerant; however, I do believe that the entropy and gain calculation is sound and correct.