001    package org.maltparser.parser.guide.instance;
002    
003    import java.io.BufferedReader;
004    import java.io.BufferedWriter;
005    import java.io.IOException;
006    import java.util.ArrayList;
007    import java.util.Collections;
008    import java.util.HashMap;
009    import java.util.HashSet;
010    import java.util.LinkedList;
011    import java.util.List;
012    import java.util.Map;
013    import java.util.Set;
014    import java.util.SortedMap;
015    import java.util.TreeMap;
016    import java.util.Map.Entry;
017    import java.util.regex.Pattern;
018    
019    import org.maltparser.core.config.ConfigurationDir;
020    import org.maltparser.core.exception.MaltChainedException;
021    import org.maltparser.core.feature.FeatureException;
022    import org.maltparser.core.feature.FeatureVector;
023    import org.maltparser.core.feature.function.FeatureFunction;
024    import org.maltparser.core.feature.function.Modifiable;
025    import org.maltparser.core.feature.value.SingleFeatureValue;
026    import org.maltparser.core.syntaxgraph.DependencyStructure;
027    import org.maltparser.parser.guide.ClassifierGuide;
028    import org.maltparser.parser.guide.GuideException;
029    import org.maltparser.parser.guide.Model;
030    import org.maltparser.parser.history.action.SingleDecision;
031    
032    /**
 * This class implements a decision tree model. The class is recursive: an
 * instance can be the root model or belong to another decision tree model.
 * Every node in the decision tree is represented by an instance of the class.
 * A node is in one of three states: branch model, leaf model or not decided.
 * A branch model has several sub decision tree models, and a leaf model owns
 * an atomic model that is used to classify instances. A model in the not
 * decided state has both sub decision tree models and an atomic model. A node
 * can be in this state during training, before cross validation has tested
 * whether the sub decision tree models provide better accuracy than the atomic
 * model.
043     * 
044     * 
045     * @author Kjell Winblad
046     */
047    public class DecisionTreeModel implements InstanceModel {
048    
        /*
         * Counter intended to give every leaf node an int index that is unique
         * among all leaf nodes, because each leaf owns an AtomicModel which needs
         * such an index. It is incremented by the private sub-model constructor in
         * BATCH mode. NOTE(review): the AtomicModels created in this part of the
         * file are all given index -1, and the identifier is misspelled
         * ("Conter"); it is kept as-is because other code may reference it.
         */
        private static int leafModelIndexConter = 0;

        // Branch key of the catch-all "other" branch. When loading, a branch with
        // this key is created without any further divide features.
        // (Integer.MAX_VALUE was used here previously.)
        private final static int OTHER_BRANCH_ID = 1000000;

        // The number of folds used when doing the cross validation test. May be
        // overridden by the --guide-tree_number_of_cross_validation_divisions option.
        private int numberOfCrossValidationSplits = 10;
        /*
         * Cross validation accuracy of this node, calculated during training. It
         * holds CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE (-1.0) until it has been
         * calculated.
         */
        private final static double CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE = -1.0;
        private double crossValidationAccuracy = CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE;
        // The parent model
        private Model parent = null;
        // An ordered list of features to divide on. The first element is the one
        // matched against the feature vector to build divideFeatureIndexVector.
        private LinkedList<FeatureFunction> divideFeatures = null;
        /*
         * The branches (sub decision tree models) of this node, keyed by branch
         * id. It is null if this is a leaf node.
         */
        private SortedMap<Integer, DecisionTreeModel> branches = null;

        /*
         * The atomic model used to classify instances if this is a leaf node. It
         * is null if this is a branch node.
         */
        private AtomicModel leafModel = null;
        // Number of training instances added to this node (when loading, the sum
        // of the frequencies recorded for its branches)
        private int frequency = 0;
        /*
         * Minimum number of instances for a sub node to exist. All sub nodes with
         * fewer instances are merged into one sub node. Read from the
         * --guide-tree_split_threshold option.
         */
        private int divideThreshold = 0;
        // The feature vector for this problem
        private FeatureVector featureVector;

        // Feature vector handed to non-"other" sub models when loading.
        // NOTE(review): presumably the vector without the current divide
        // feature -- confirm in getSubFeatureVector().
        private FeatureVector subFeatureVector = null;

        // Used to indicate that the modelIndex field is not set
        private static final int MODEL_INDEX_NOT_SET = Integer.MIN_VALUE;
        /*
         * Model index is the identifier used to distinguish this model from other
         * models at the same level. It is not used by the root model, which keeps
         * the value MODEL_INDEX_NOT_SET.
         */
        private int modelIndex = MODEL_INDEX_NOT_SET;
        // Positions in featureVector of the features equal to the first divide
        // feature
        private ArrayList<Integer> divideFeatureIndexVector;

        // True when the --guide-tree_automatic_split_order option is "yes"
        private boolean automaticSplit = false;
        // True when the --guide-tree_force_divide option is "yes"
        private boolean treeForceDivide = false;
105    
106            /**
107             * Constructs a feature divide model.
108             * 
109             * @param featureVector
110             *            the feature vector used by the decision tree model
111             * @param parent
112             *            the parent guide model.
113             * @throws MaltChainedException
114             */
115            public DecisionTreeModel(FeatureVector featureVector, Model parent)
116                            throws MaltChainedException {
117    
118                    this.featureVector = featureVector;
119                    this.divideFeatures = new LinkedList<FeatureFunction>();
120                    setParent(parent);
121                    setFrequency(0);
122                    initDecisionTreeParam();
123    
124                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
125    
126                            // Prepare for training
127    
128                            branches = new TreeMap<Integer, DecisionTreeModel>();
129                            leafModel = new AtomicModel(-1, featureVector, this);
130    
131                    } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
132                            load();
133                    }
134            }
135    
136            /*
137             * This constructor is used from within objects of the class to create sub decision tree models
138             * 
139             *
140             */
141            private DecisionTreeModel(int modelIndex, FeatureVector featureVector,
142                            Model parent, LinkedList<FeatureFunction> divideFeatures,
143                            int divideThreshold) throws MaltChainedException {
144    
145                    this.featureVector = featureVector;
146    
147                    setParent(parent);
148                    setFrequency(0);
149    
150                    this.modelIndex = modelIndex;
151                    this.divideFeatures = divideFeatures;
152                    this.divideThreshold = divideThreshold;
153    
154                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
155    
156                            //Create the divide feature index vector
157                            if (divideFeatures.size() > 0) {
158    
159                                    divideFeatureIndexVector = new ArrayList<Integer>();
160                                    for (int i = 0; i < featureVector.size(); i++) {
161                                            if (featureVector.get(i).equals(divideFeatures.get(0))) {
162                                                    divideFeatureIndexVector.add(i);
163                                            }
164                                    }
165    
166                            }
167                            leafModelIndexConter++;
168    
169    
170                            // Prepare for training
171                            branches = new TreeMap<Integer, DecisionTreeModel>();
172                            leafModel = new AtomicModel(-1, featureVector, this);
173    
174                    } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
175                            load();
176                    }
177            }
178    
179            /**
180             * Loads the feature divide model settings .fsm file.
181             * 
182             * @throws MaltChainedException
183             */
184            protected void load() throws MaltChainedException {
185    
186                    ConfigurationDir configDir = getGuide().getConfiguration()
187                                    .getConfigurationDir();
188    
189    
190                    // load the dsm file
191    
192                    try {
193    
194                            final BufferedReader in = new BufferedReader(
195                                            configDir.getInputStreamReaderFromConfigFile(getModelName()
196                                                            + ".dsm"));
197                            final Pattern tabPattern = Pattern.compile("\t");
198    
199                            boolean first = true;
200                            while (true) {
201                                    String line = in.readLine();
202                                    if (line == null)
203                                            break;
204                                    String[] cols = tabPattern.split(line);
205                                    if (cols.length != 2) {
206                                            throw new GuideException("");
207                                    }
208                                    int code = -1;
209                                    int freq = 0;
210                                    try {
211                                            code = Integer.parseInt(cols[0]);
212                                            freq = Integer.parseInt(cols[1]);
213                                    } catch (NumberFormatException e) {
214                                            throw new GuideException(
215                                                            "Could not convert a string value into an integer value when loading the feature divide model settings (.fsm). ",
216                                                            e);
217                                    }
218    
219                                    if (code == MODEL_INDEX_NOT_SET) {
220                                            if (!first)
221                                                    throw new GuideException(
222                                                                    "Error in config file '"
223                                                                                    + getModelName()
224                                                                                    + ".dsm"
225                                                                                    + "'. If the index in the .dsm file is MODEL_INDEX_NOT_SET it should be the first.");
226    
227                                            first = false;
228                                            // It is a leaf node
229                                            // Create atomic model for the leaf node
230                                            leafModel = new AtomicModel(-1, featureVector, this);
231    
232                                            // setIsLeafNode();
233    
234                                    } else {
235                                            if (first) {
236                                                    // Create the branches holder
237    
238                                                    branches = new TreeMap<Integer, DecisionTreeModel>();
239    
240                                                    // setIsBranchNode();
241    
242                                                    first = false;
243                                            }
244    
245                                            if (branches == null)
246                                                    throw new GuideException(
247                                                                    "Error in config file '"
248                                                                                    + getModelName()
249                                                                                    + ".dsm"
250                                                                                    + "'. If MODEL_INDEX_NOT_SET is the first model index in the .dsm file it should be the only.");
251    
252                                            if (code == OTHER_BRANCH_ID)
253                                                    branches.put(code, new DecisionTreeModel(code,
254                                                                    featureVector, this,
255                                                                    new LinkedList<FeatureFunction>(),
256                                                                    divideThreshold));
257                                            else
258                                                    branches.put(code, new DecisionTreeModel(code,
259                                                                    getSubFeatureVector(), this,
260                                                                    createNextLevelDivideFeatures(),
261                                                                    divideThreshold));
262    
263                                            branches.get(code).setFrequency(freq);
264    
265                                            setFrequency(getFrequency() + freq);
266    
267                                    }
268    
269                            }
270                            in.close();
271    
272                    } catch (IOException e) {
273                            throw new GuideException(
274                                            "Could not read from the guide model settings file '"
275                                                            + getModelName() + ".dsm" + "', when "
276                                                            + "loading the guide model settings. ", e);
277                    }
278    
279            }
280    
281            private void initDecisionTreeParam() throws MaltChainedException {
282                    String treeSplitColumns = getGuide().getConfiguration().getOptionValue(
283                                    "guide", "tree_split_columns").toString();
284                    String treeSplitStructures = getGuide().getConfiguration()
285                                    .getOptionValue("guide", "tree_split_structures").toString();
286                    
287                    automaticSplit  = getGuide().getConfiguration()
288                    .getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes");
289                    
290                    treeForceDivide = getGuide().getConfiguration()
291                    .getOptionValue("guide", "tree_force_divide").toString().equals("yes");
292                    
293                    if(automaticSplit){
294                            divideFeatures = new LinkedList<FeatureFunction>();
295                            for(FeatureFunction feature:featureVector){
296                                    if(feature.getFeatureValue() instanceof SingleFeatureValue)
297                                            divideFeatures.add(feature);
298                            }
299                            
300                            
301                    }else{
302                    
303                    if (treeSplitColumns == null || treeSplitColumns.length() == 0) {
304                            throw new GuideException(
305                                            "The option '--guide-tree_split_columns' cannot be found, when initializing the decision tree model. ");
306                    }
307    
308                    if (treeSplitStructures == null || treeSplitStructures.length() == 0) {
309                            throw new GuideException(
310                                            "The option '--guide-tree_split_structures' cannot be found, when initializing the decision tree model. ");
311                    }
312    
313                    String[] treeSplitColumnsArray = treeSplitColumns.split("@");
314                    String[] treeSplitStructuresArray = treeSplitStructures.split("@");
315    
316                    if (treeSplitColumnsArray.length != treeSplitStructuresArray.length)
317                            throw new GuideException(
318                                            "The option '--guide-tree_split_structures' and '--guide-tree_split_columns' must be followed by a ; separated lists of the same length");
319    
320                    try {
321    
322                            for (int n = 0; n < treeSplitColumnsArray.length; n++) {
323    
324                                    final String spec = "InputColumn("
325                                                    + treeSplitColumnsArray[n].trim() + ", "
326                                                    + treeSplitStructuresArray[n].trim() + ")";
327    
328                                    divideFeatures.addLast(featureVector.getFeatureModel()
329                                                    .identifyFeature(spec));
330                            }
331    
332                    } catch (FeatureException e) {
333                            throw new GuideException("The data split feature 'InputColumn("
334                                            + getGuide().getConfiguration().getOptionValue("guide",
335                                                            "data_split_column").toString()
336                                            + ", "
337                                            + getGuide().getConfiguration().getOptionValue("guide",
338                                                            "data_split_structure").toString()
339                                            + ") cannot be initialized. ", e);
340                    }
341    
342                    for (FeatureFunction divideFeature : divideFeatures) {
343                            if (!(divideFeature instanceof Modifiable)) {
344                                    throw new GuideException("The data split feature 'InputColumn("
345                                                    + getGuide().getConfiguration().getOptionValue("guide",
346                                                                    "data_split_column").toString()
347                                                    + ", "
348                                                    + getGuide().getConfiguration().getOptionValue("guide",
349                                                                    "data_split_structure").toString()
350                                                    + ") does not implement Modifiable interface. ");
351                            }
352                    }
353    
354                    divideFeatureIndexVector = new ArrayList<Integer>();
355                    for (int i = 0; i < featureVector.size(); i++) {
356    
357                            if (featureVector.get(i).equals(divideFeatures.get(0))) {
358    
359                                    divideFeatureIndexVector.add(i);
360                            }
361                    }
362    
363                    if (divideFeatureIndexVector.size() == 0) {
364                            throw new GuideException(
365                                            "Could not match the given divide features to any of the available features.");
366                    }
367    
368    
369    
370                    }
371    
372                    try {
373    
374                            String treeSplitTreshold = getGuide().getConfiguration()
375                                            .getOptionValue("guide", "tree_split_threshold").toString();
376    
377                            if (treeSplitTreshold != null && treeSplitTreshold.length() > 0) {
378    
379                                    divideThreshold = Integer.parseInt(treeSplitTreshold);
380    
381                            } else {
382                                    divideThreshold = 0;
383                            }
384                    } catch (NumberFormatException e) {
385                            throw new GuideException(
386                                            "The --guide-tree_split_threshold option is not an integer value. ",
387                                            e);
388                    }
389    
390                    try {
391    
392                            String treeNumberOfCrossValidationDivisions = getGuide()
393                                            .getConfiguration().getOptionValue("guide",
394                                                            "tree_number_of_cross_validation_divisions")
395                                            .toString();
396    
397                            if (treeNumberOfCrossValidationDivisions != null
398                                            && treeNumberOfCrossValidationDivisions.length() > 0) {
399    
400                                    numberOfCrossValidationSplits = Integer
401                                                    .parseInt(treeNumberOfCrossValidationDivisions);
402    
403                            } else {
404                                    divideThreshold = 0;
405                            }
406                    } catch (NumberFormatException e) {
407                            throw new GuideException(
408                                            "The --guide-tree_number_of_cross_validation_divisions option is not an integer value. ",
409                                            e);
410                    }
411            
412            }
413    
414            @Override
415            public void addInstance(SingleDecision decision)
416                            throws MaltChainedException {
417    
418                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
419                            throw new GuideException("Can only add instance during learning. ");
420                    } else if (divideFeatures.size() > 0) {
421                            //FeatureFunction divideFeature = divideFeatures.getFirst();
422    
423                            for (FeatureFunction divideFeature : divideFeatures) {
424                                    if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) {
425                                            throw new GuideException(
426                                                            "The divide feature does not have a single value. ");
427                                    }
428                                    // Is this necessary?
429                                    divideFeature.update();
430                            }
431                            leafModel.addInstance(decision);
432    
433                            //Update statistics data
434                            updateStatistics(decision);
435                            
436                            
437                    } else {
438                            // Model has already been decided. It is a leaf node
439                            if (branches != null)
440                                    setIsLeafNode();
441    
442                            leafModel.addInstance(decision);
443                            
444                            //Update statistics data
445                            updateStatistics(decision);
446    
447                    }
448                    
449                    
450                    
451    
452            }
453    
454            /*
455            private class StatisticsItem{
456    
457                    private int columnValue;
458                    
459                    private int classValue;
460                    
461                    public StatisticsItem(int columnValue, int classValue) {
462                            super();
463                            this.columnValue = columnValue;
464                            this.classValue = classValue;
465                    }
466                    
467                    public int getColumnValue() {
468                            return columnValue;
469                    }
470    
471                    public int getClassValue() {
472                            return classValue;
473                    }
474                    
475                    @Override
476                    public int hashCode() {
477                            return new Integer(columnValue/2).hashCode() + new Integer(classValue/2).hashCode();
478                    }
479    
480                    @Override
481                    public boolean equals(Object obj) {
482                            
483                            StatisticsItem compItem = (StatisticsItem)obj;
484                            
485                            return compItem.getClassValue()==this.getClassValue() && compItem.getColumnValue()==this.getColumnValue();
486                    }
487            }
488            */
489    
490            /*
491             * Helper method used for automatic division by gain ratio
492             * @param n
493             * @return
494             */
495            private double log2(double n){
496                    return Math.log(n)/Math.log(2.0);
497            }
498            
        /*
         * Split statistics gathered by updateStatistics() while training instances
         * are added. For every divide feature they count how often each feature
         * value code occurs, and how often each (feature value code, class)
         * combination occurs. They also hold the overall count per class. The maps
         * are null until the first instance is recorded.
         */
        // Unused earlier design, kept for reference:
        //private HashMap<FeatureFunction, HashMap<StatisticsItem, Integer>> statisticsForDivideFatureMap = null;
        // Class id (decision code) -> number of training instances with that class
        private HashMap<Integer,Integer> classIdToCountMap = null;
        
        // Divide feature -> (feature value code -> number of training instances
        // with that value)
        private HashMap<FeatureFunction, HashMap<Integer,Integer>> featureIdToCountMap = null;
        
        // Unused earlier design, kept for reference:
        //private HashMap<FeatureFunction, HashMap<Integer,Integer>> classIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer,Integer>>();
        
        // Divide feature -> (feature value code -> (class id -> number of training
        // instances with that value and class))
        private HashMap<FeatureFunction, HashMap<Integer,HashMap<Integer,Integer>>> featureIdToClassIdToCountMap = null;
514            
515            private void updateStatistics(SingleDecision decision)
516                            throws MaltChainedException {
517    
518                    // if(statisticsForDivideFatureMap==null){
519                    // statisticsForDivideFatureMap = new HashMap<FeatureFunction,
520                    // HashMap<StatisticsItem, Integer>>();
521                    //                      
522                    // for(FeatureFunction columnsDivideFeature : divideFeatures)
523                    // statisticsForDivideFatureMap.put(columnsDivideFeature, new
524                    // HashMap<StatisticsItem, Integer>());
525                    // }
526                    //              
527                    //              
528                    // int instanceClass = decision.getDecisionCode();
529                    //              
530                    // Integer classCount = classCountStatistics.get(instanceClass);
531                    //              
532                    // if(classCount==null){
533                    // classCount=0;
534                    // }
535                    //              
536                    // classCountStatistics.put(instanceClass, classCount+1);
537                    //              
538                    // for(FeatureFunction columnsDivideFeature : featureVector){
539                    //                      
540                    // int featureCode =
541                    // ((SingleFeatureValue)columnsDivideFeature.getFeatureValue()).getCode();
542                    // HashMap<StatisticsItem, Integer> statisticsMap =
543                    // statisticsForDivideFatureMap.get(columnsDivideFeature);
544                    // if(statisticsMap!=null){
545                    //                              
546                    // StatisticsItem item = new StatisticsItem(featureCode, instanceClass);
547                    //                              
548                    // Integer count = statisticsMap.get(item);
549                    //                              
550                    // if(count==null){
551                    // //Add the statistic item to the map
552                    // count = 0;
553                    // }
554                    //                              
555                    // statisticsMap.put(item, count + 1);
556                    //                              
557                    // }
558                    //                      
559                    // }
560    
561                    // If it is not done initialize the statistics maps
562                    if (featureIdToCountMap == null) {
563    
564                            featureIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, Integer>>();
565    
566                            for (FeatureFunction columnsDivideFeature : divideFeatures)
567                                    featureIdToCountMap.put(columnsDivideFeature,
568                                                    new HashMap<Integer, Integer>());
569    
570    
571                            featureIdToClassIdToCountMap = new HashMap<FeatureFunction, HashMap<Integer, HashMap<Integer, Integer>>>();
572    
573                            for (FeatureFunction columnsDivideFeature : divideFeatures)
574                                    featureIdToClassIdToCountMap.put(columnsDivideFeature,
575                                                    new HashMap<Integer, HashMap<Integer, Integer>>());
576    
577                            classIdToCountMap = new HashMap<Integer, Integer>();
578    
579                    }
580    
581                    int instanceClass = decision.getDecisionCode();
582    
583                    // Increase classCountStatistics
584    
585                    Integer classCount = classIdToCountMap.get(instanceClass);
586    
587                    if (classCount == null) {
588                            classCount = 0;
589                    }
590    
591                    classIdToCountMap.put(instanceClass, classCount + 1);
592    
593                    // Increase featureIdToCountMap
594    
595                    for (FeatureFunction columnsDivideFeature : divideFeatures) {
596    
597                            int featureCode = ((SingleFeatureValue) columnsDivideFeature
598                                            .getFeatureValue()).getCode();
599    
600                            HashMap<Integer, Integer> statisticsMap = featureIdToCountMap
601                                            .get(columnsDivideFeature);
602    
603                            Integer count = statisticsMap.get(featureCode);
604    
605                            if (count == null) {
606                                    // Add the statistic item to the map
607                                    count = 0;
608                            }
609    
610                            statisticsMap.put(featureCode, count + 1);
611    
612                    }
613                    
614                    // Increase featureIdToClassIdToCountMap
615                    
616                    for (FeatureFunction columnsDivideFeature : divideFeatures) {
617    
618                            int featureCode = ((SingleFeatureValue) columnsDivideFeature
619                                            .getFeatureValue()).getCode();
620                            
621                            HashMap<Integer, HashMap<Integer, Integer>> featureIdToclassIdToCountMapTmp = featureIdToClassIdToCountMap
622                                            .get(columnsDivideFeature);
623    
624                            HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToclassIdToCountMapTmp.get(featureCode);
625    
626                            if (classIdToCountMapTmp == null) {
627                                    // Add the statistic item to the map
628                                    classIdToCountMapTmp = new HashMap<Integer, Integer>();
629                                    
630                                    featureIdToclassIdToCountMapTmp.put(featureCode, classIdToCountMapTmp);
631                            }
632                            
633                            Integer count = classIdToCountMapTmp.get(instanceClass);
634    
635                            if (count == null) {
636                                    // Add the statistic item to the map
637                                    count = 0;
638                            }
639    
640                            classIdToCountMapTmp.put(instanceClass, count + 1);
641    
642                    }
643    
644            }
645    
646            @SuppressWarnings("unchecked")
647            private LinkedList<FeatureFunction> createNextLevelDivideFeatures() {
648    
649                    LinkedList<FeatureFunction> nextLevelDivideFeatures = (LinkedList<FeatureFunction>) divideFeatures
650                                    .clone();
651    
652                    nextLevelDivideFeatures.removeFirst();
653    
654                    return nextLevelDivideFeatures;
655            }
656    
657            /*
658             * Removes the current divide feature from the feature vector so it is not
659             * present in the sub node
660             */
661            private FeatureVector getSubFeatureVector() {
662    
663                    if (subFeatureVector != null)
664                            return subFeatureVector;
665    
666                    FeatureFunction divideFeature = divideFeatures.getFirst();
667    
668                    ArrayList<Integer> divideFeatureIndexVector = new ArrayList<Integer>();
669                    for (int i = 0; i < featureVector.size(); i++) {
670                            if (featureVector.get(i).equals(divideFeature)) {
671                                    divideFeatureIndexVector.add(i);
672                            }
673                    }
674    
675                    FeatureVector divideFeatureVector = (FeatureVector) featureVector
676                                    .clone();
677    
678                    for (Integer i : divideFeatureIndexVector) {
679                            divideFeatureVector.remove(divideFeatureVector.get(i));
680                    }
681    
682                    subFeatureVector = divideFeatureVector;
683    
684                    return divideFeatureVector;
685            }
686    
687            @Override
688            public FeatureVector extract() throws MaltChainedException {
689    
690                    return getCurrentAtomicModel().extract();
691    
692            }
693    
694            /*
695             * Returns the atomic model that is effected by this parsing step
696             */
697            private AtomicModel getCurrentAtomicModel() throws MaltChainedException {
698    
699                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
700                            throw new GuideException("Can only predict during parsing. ");
701                    }
702    
703                    if (branches == null && leafModel != null)
704                            return leafModel;
705    
706                    FeatureFunction divideFeature = divideFeatures.getFirst();
707    
708                    if (!(divideFeature.getFeatureValue() instanceof SingleFeatureValue)) {
709                            throw new GuideException(
710                                            "The divide feature does not have a single value. ");
711                    }
712    
713                    if (branches != null
714                                    && branches.containsKey(((SingleFeatureValue) divideFeature
715                                                    .getFeatureValue()).getCode())) {
716                            return branches.get(
717                                            ((SingleFeatureValue) divideFeature.getFeatureValue())
718                                                            .getCode()).getCurrentAtomicModel();
719                    } else if (branches.containsKey(OTHER_BRANCH_ID)
720                                    && branches.get(OTHER_BRANCH_ID).getFrequency() > 0) {
721                            return branches.get(OTHER_BRANCH_ID).getCurrentAtomicModel();
722                    } else {
723                            getGuide()
724                                            .getConfiguration()
725                                            .getConfigLogger()
726                                            .info(
727                                                            "Could not predict the next parser decision because there is "
728                                                                            + "no divide or master model that covers the divide value '"
729                                                                            + ((SingleFeatureValue) divideFeature
730                                                                                            .getFeatureValue()).getCode()
731                                                                            + "', as default"
732                                                                            + " class code '1' is used. ");
733                    }
734                    return null;
735            }
736    
737            /**
738             * Increase the frequency by 1
739             */
740            public void increaseFrequency() {
741                    frequency++;
742            }
743    
744            public void decreaseFrequency() {
745                    frequency--;
746            }
747    
748            @Override
749            public boolean predict(SingleDecision decision) throws MaltChainedException {
750    
751                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
752                            throw new GuideException("Can only predict during parsing. ");
753                    } else if (divideFeatures.size() > 0
754                                    && !(divideFeatures.getFirst().getFeatureValue() instanceof SingleFeatureValue)) {
755                            throw new GuideException(
756                                            "The divide feature does not have a single value. ");
757                    }
758    
759                    
760                    if (branches != null
761                                    && branches.containsKey(((SingleFeatureValue) divideFeatures
762                                                    .getFirst().getFeatureValue()).getCode())) {
763                            
764                            return branches.get(
765                                            ((SingleFeatureValue) divideFeatures.getFirst()
766                                                            .getFeatureValue()).getCode()).predict(decision);
767                    } else if (branches != null && branches.containsKey(OTHER_BRANCH_ID)) {
768                            
769                            return branches.get(OTHER_BRANCH_ID).predict(decision);
770                    } else if (leafModel != null) {
771                            
772                            return leafModel.predict(decision);
773                    } else {
774    
775                            getGuide()
776                                            .getConfiguration()
777                                            .getConfigLogger()
778                                            .info(
779                                                            "Could not predict the next parser decision because there is "
780                                                                            + "no divide or master model that covers the divide value '"
781                                                                            + ((SingleFeatureValue) divideFeatures
782                                                                                            .getFirst().getFeatureValue())
783                                                                                            .getCode() + "', as default"
784                                                                            + " class code '1' is used. ");
785    
786                            decision.addDecision(1); // default prediction
787                            // classCodeTable.getEmptyKBestList().addKBestItem(1);
788                    }
789                    return true;
790            }
791    
792            @Override
793            public FeatureVector predictExtract(SingleDecision decision)
794                            throws MaltChainedException {
795                    return getCurrentAtomicModel().predictExtract(decision);
796            }
797    
798            /*
799             * Decides if this is a branch or leaf node by doing cross validation and
800             * returns the cross validation score for this node
801             */
802            private double decideNodeType() throws MaltChainedException {
803    
804                    // We don't want to do this twice test
805                    if (crossValidationAccuracy != CROSS_VALIDATION_ACCURACY_NOT_SET_VALUE)
806                            return crossValidationAccuracy;
807    
808                    if (modelIndex == MODEL_INDEX_NOT_SET)
809                            if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
810                                    getGuide().getConfiguration().getConfigLogger().info(
811                                                    "Starting deph first pruning of the decision tree\n");
812                            }
813    
814                    long start = System.currentTimeMillis();
815    
816                    double leafModelCrossValidationAccuracy = 0.0;
817                    
818                    if(treeForceDivide)
819                            if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
820                                    getGuide().getConfiguration().getConfigLogger().info(
821                                                    "Skipping cross validation of the root node since the flag tree_force_divide is set to yes. " +
822                                                    "The cross validation score for the root node is set to zero.\n");
823                            }
824                    
825                    if(!treeForceDivide)
826                            leafModelCrossValidationAccuracy = leafModel.getMethod()
827                                    .crossValidate(featureVector, numberOfCrossValidationSplits);
828    
829                    long stop = System.currentTimeMillis();
830    
831                    if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
832                            getGuide().getConfiguration().getConfigLogger().info(
833                                            "Cross Validation Time: " + (stop - start) + " ms"
834                                                            + " for model " + getModelName() + "\n");
835                    }
836    
837                    if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
838                            getGuide().getConfiguration().getConfigLogger().info(
839                                            "Cross Validation Accuracy as leaf node = "
840                                                            + leafModelCrossValidationAccuracy + " for model "
841                                                            + getModelName() + "\n");
842                    }
843    
844                    if (branches == null && leafModel != null) {// If it is already decided
845                                                                                                            // that this is a leaf node
846    
847                            crossValidationAccuracy = leafModelCrossValidationAccuracy;
848    
849                            return crossValidationAccuracy;
850    
851                    }
852    
853                    int totalFrequency = 0;
854                    double totalAccuracyCount = 0.0;
855                    // Calculate crossValidationAccuracy for branch nodes
856                    for (DecisionTreeModel b : branches.values()) {
857    
858                            double bAccuracy = b.decideNodeType();
859    
860                            totalFrequency = totalFrequency + b.getFrequency();
861    
862                            totalAccuracyCount = totalAccuracyCount + bAccuracy
863                                            * b.getFrequency();
864    
865                    }
866    
867                    double branchModelCrossValidationAccuracy = totalAccuracyCount
868                                    / totalFrequency;
869    
870                    if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
871                            getGuide().getConfiguration().getConfigLogger().info(
872                                            "Total Cross Validation Accuracy for branches = "
873                                                            + branchModelCrossValidationAccuracy
874                                                            + " for model " + getModelName() + "\n");
875                    }
876    
877                    // Finally decide which model to use
878                    if (branchModelCrossValidationAccuracy > leafModelCrossValidationAccuracy) {
879    
880                            setIsBranchNode();
881    
882                            crossValidationAccuracy = branchModelCrossValidationAccuracy;
883    
884                            return crossValidationAccuracy;
885    
886                    } else {
887    
888                            setIsLeafNode();
889    
890                            crossValidationAccuracy = leafModelCrossValidationAccuracy;
891    
892                            return crossValidationAccuracy;
893    
894                    }
895    
896            }
897    
898            @Override
899            public void train() throws MaltChainedException {
900    
901                    // Decide node type
902                    // This operation is more expensive than the training itself
903                    decideNodeType();
904    
905                    // Do the training depending on which type of node this is
906                    if (branches == null && leafModel != null) {
907    
908                            // If it is a leaf node
909    
910                            leafModel.train();
911    
912                            save();
913    
914                            leafModel.terminate();
915    
916                    } else {
917                            // It is a branch node
918    
919                            for (DecisionTreeModel b : branches.values())
920                                    b.train();
921    
922                            save();
923    
924                            for (DecisionTreeModel b : branches.values())
925                                    b.terminate();
926    
927                    }
928                    terminate();
929    
930            }
931    
932            /**
933             * Saves the decision tree model settings .dsm file.
934             * 
935             * @throws MaltChainedException
936             */
937            private void save() throws MaltChainedException {
938                    try {
939    
940                            final BufferedWriter out = new BufferedWriter(getGuide()
941                                            .getConfiguration().getConfigurationDir()
942                                            .getOutputStreamWriter(getModelName() + ".dsm"));
943    
944                            if (branches != null) {
945                                    for (DecisionTreeModel b : branches.values()) {
946                                            out.write(b.getModelIndex() + "\t" + b.getFrequency()
947                                                            + "\n");
948                                    }
949                            } else {
950                                    out.write(MODEL_INDEX_NOT_SET + "\t" + getFrequency() + "\n");
951                            }
952    
953                            out.close();
954    
955                    } catch (IOException e) {
956                            throw new GuideException(
957                                            "Could not write to the guide model settings file '"
958                                                            + getModelName() + ".dsm"
959                                                            + "' or the name mapping file '" + getModelName()
960                                                            + ".nmf" + "', when "
961                                                            + "saving the guide model settings to files. ", e);
962                    }
963            }
964    
965            @Override
966            public void finalizeSentence(DependencyStructure dependencyGraph)
967                            throws MaltChainedException {
968    
969                    if (branches != null) {
970    
971                            for (DecisionTreeModel b : branches.values()) {
972                                    b.finalizeSentence(dependencyGraph);
973                            }
974    
975                    } else if (leafModel != null) {
976    
977                            leafModel.finalizeSentence(dependencyGraph);
978    
979                    } else {
980    
981                            throw new GuideException(
982                                            "The feature divide models cannot be found. ");
983    
984                    }
985    
986            }
987    
988            @Override
989            public ClassifierGuide getGuide() {
990                    return parent.getGuide();
991            }
992    
993            @Override
994            public String getModelName() throws MaltChainedException {
995                    try {
996    
997                            return parent.getModelName()
998                                            + (modelIndex == MODEL_INDEX_NOT_SET ? ""
999                                                            : ("_" + modelIndex));
1000                    } catch (NullPointerException e) {
1001                            throw new GuideException(
1002                                            "The parent guide model cannot be found. ", e);
1003                    }
1004            }
1005    
1006            /*
1007             * This is called to define this node as to be in the leaf state. It sets branches to null.
1008             */
1009            private void setIsLeafNode() throws MaltChainedException {
1010    
1011                    if (branches == null && leafModel != null)
1012                            return;
1013    
1014                    if (branches != null && leafModel != null) {
1015    
1016                            for (DecisionTreeModel t : branches.values())
1017                                    t.terminate();
1018    
1019                            branches = null;
1020    
1021                    } else
1022                            throw new MaltChainedException(
1023                                            "Can't set a node that have aleready been set to a leaf node.");
1024    
1025            }
1026            /*
1027             * This is called to define this node as to be in the branch state. It sets leafModel to null.
1028             */
1029            private void setIsBranchNode() throws MaltChainedException {
1030                    if (branches != null && leafModel != null) {
1031    
1032                            leafModel.terminate();
1033    
1034                            leafModel = null;
1035    
1036                    } else
1037                            throw new MaltChainedException(
1038                                            "Can't set a node that have aleready been set to a branch node.");
1039    
1040            }
1041    
1042    
        /*
         * Called when no more training instances will be added. It notifies this
         * node's atomic model (leafModel.noMoreInstances()) and, if the node still
         * has branches, splits the collected instances on the first divide feature:
         *
         * - Every divide value seen at least divideThreshold times gets its own
         *   branch. That branch uses getSubFeatureVector() (the feature vector
         *   without the divide feature) and the remaining divide features.
         * - All rarer values are grouped into an OTHER_BRANCH_ID branch. It keeps
         *   the full feature vector and has no further divide features.
         * - If there is too little data to split, the node becomes a leaf.
         *
         * The created branches are then processed recursively.
         *
         * @throws GuideException if the node has no atomic model, or if (with
         *         automaticSplit) the chosen divide feature is not found in the
         *         feature vector
         */
        @Override
        public void noMoreInstances() throws MaltChainedException {

                if (leafModel == null)
                        throw new GuideException(
                                        "The model in tree node is null in a state where it is not allowed");

                leafModel.noMoreInstances();

                // Nothing left to divide on: make this a leaf. setIsLeafNode() sets
                // branches to null, so the split below is skipped.
                if (divideFeatures.size() == 0)
                        setIsLeafNode();

                if (branches != null) {
                        
                        // Reorder the divide features by gain ratio and recompute where
                        // the new first divide feature occurs in featureVector.
                        // NOTE(review): without automaticSplit, divideFeatureIndexVector is
                        // presumably initialised elsewhere (e.g. in the constructor) — confirm.
                        if(automaticSplit){
                                
                                divideFeatures = createGainRatioSplitList(divideFeatures);
                                
                                divideFeatureIndexVector = new ArrayList<Integer>();
                                for (int i = 0; i < featureVector.size(); i++) {

                                        if (featureVector.get(i).equals(divideFeatures.get(0))) {

                                                divideFeatureIndexVector.add(i);
                                        }
                                }

                                if (divideFeatureIndexVector.size() == 0) {
                                        throw new GuideException(
                                                        "Could not match the given divide features to any of the available features.");
                                }
                                
                        }

                        FeatureFunction divideFeature = divideFeatures.getFirst();

                        divideFeature.updateCardinality();

                        // NOTE(review): leafModel.noMoreInstances() was already called at the
                        // top of this method; this second call looks redundant — confirm it
                        // is safe to repeat before removing either call.
                        leafModel.noMoreInstances();

                        // Instance count per divide feature value, computed by the learning
                        // method over the divide feature positions.
                        Map<Integer, Integer> divideFeatureIdToCountMap = leafModel
                                        .getMethod().createFeatureIdToCountMap(
                                                        divideFeatureIndexVector);

                        int totalInOther = 0;

                        Set<Integer> featureIdsToCreateSeparateBranchesForSet = new HashSet<Integer>();

                        List<Integer> removeFromDivideFeatureIdToCountMap = new LinkedList<Integer>();

                        // Partition the values: frequent ones get their own branch, and the
                        // counts of the rare ones are summed into totalInOther.
                        for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap
                                        .entrySet())
                                if (entry.getValue() >= divideThreshold) {
                                        featureIdsToCreateSeparateBranchesForSet
                                                        .add(entry.getKey());
                                } else {
                                        removeFromDivideFeatureIdToCountMap.add(entry.getKey());
                                        totalInOther = totalInOther + entry.getValue();
                                }

                        // Removed after the loop to avoid modifying the map while iterating
                        // over it; afterwards the map only holds the frequent values.
                        for (int removeIndex : removeFromDivideFeatureIdToCountMap)
                                divideFeatureIdToCountMap.remove(removeIndex);

                        boolean otherExists = false;

                        if (totalInOther > 0)
                                otherExists = true;

                        if ((totalInOther < divideThreshold && featureIdsToCreateSeparateBranchesForSet
                                        .size() <= 1)
                                        || featureIdsToCreateSeparateBranchesForSet.size() == 0) {
                                // Not enough instances, make this a leaf node
                                setIsLeafNode();
                        } else {

                                // If total in other is less than divideThreshold then add the
                                // smallest of the other parts to other. Every separate value has
                                // at least divideThreshold instances, so other then reaches it.
                                if (otherExists && totalInOther < divideThreshold) {
                                        int smallestSoFar = Integer.MAX_VALUE;
                                        int smallestSoFarId = Integer.MAX_VALUE;
                                        for (Entry<Integer, Integer> entry : divideFeatureIdToCountMap
                                                        .entrySet()) {
                                                if (entry.getValue() < smallestSoFar) {
                                                        smallestSoFar = entry.getValue();
                                                        smallestSoFarId = entry.getKey();
                                                }
                                        }

                                        featureIdsToCreateSeparateBranchesForSet
                                                        .remove(smallestSoFarId);
                                }

                                // Create new files for all feature ids with count value greater
                                // than divideThreshold and one for the
                                // other branch (presumably ids not in the set end up in the part
                                // named by OTHER_BRANCH_ID — defined by the learning method).
                                leafModel.getMethod().divideByFeatureSet(
                                                featureIdsToCreateSeparateBranchesForSet,
                                                divideFeatureIndexVector, "" + OTHER_BRANCH_ID);

                                for (int id : featureIdsToCreateSeparateBranchesForSet) {
                                        DecisionTreeModel newBranch = new DecisionTreeModel(id,
                                                        getSubFeatureVector(), this,
                                                        createNextLevelDivideFeatures(), divideThreshold);
                                        branches.put(id, newBranch);

                                }
                                // The other branch keeps the full feature vector and gets no
                                // divide features, so its own noMoreInstances() makes it a leaf.
                                if (otherExists) {
                                        DecisionTreeModel newBranch = new DecisionTreeModel(
                                                        OTHER_BRANCH_ID, featureVector, this,
                                                        new LinkedList<FeatureFunction>(), divideThreshold);
                                        branches.put(OTHER_BRANCH_ID, newBranch);

                                }

                                // Recursively split the newly created sub trees.
                                for (DecisionTreeModel b : branches.values())
                                        b.noMoreInstances();

                        }

                }

        }
1165    
1166            @Override
1167            public void terminate() throws MaltChainedException {
1168                    if (branches != null) {
1169                            for (DecisionTreeModel branch : branches.values()) {
1170                                    branch.terminate();
1171                            }
1172                            branches = null;
1173                    }
1174                    if (leafModel != null) {
1175                            leafModel.terminate();
1176                            leafModel = null;
1177                    }
1178    
1179            }
1180    
1181            public void setParent(Model parent) {
1182                    this.parent = parent;
1183            }
1184    
1185            public Model getParent() {
1186                    return parent;
1187            }
1188    
1189            public void setFrequency(int frequency) {
1190                    this.frequency = frequency;
1191            }
1192    
1193            public int getFrequency() {
1194                    return frequency;
1195            }
1196    
1197            public int getModelIndex() {
1198                    return modelIndex;
1199            }
1200    
1201    
1202            private LinkedList<FeatureFunction> createGainRatioSplitList(LinkedList<FeatureFunction> divideFeatures) {
1203                    
1204                    if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
1205                            
1206                            getGuide().getConfiguration().getConfigLogger().info(
1207                                            "Start calculating gain ratio for all posible divide features");
1208                    }
1209                    
1210                    //Calculate the root entropy
1211                    
1212                    double total = 0;
1213                    
1214                    for(int count: classIdToCountMap.values()){
1215                            double fraction = ((double)count) / getFrequency();
1216                            total = total + fraction*log2(fraction);
1217                    }
1218                    
1219                    double rootEntropy = -total;
1220                    
1221                    
1222                    class FeatureFunctionInformationGainPair implements Comparable<FeatureFunctionInformationGainPair>{
1223                            double informationGain;
1224                            FeatureFunction featureFunction;
1225                            double splitInfo;
1226                            
1227                            public FeatureFunctionInformationGainPair(
1228                                            FeatureFunction featureFunction) {
1229                                    super();
1230                                    this.featureFunction = featureFunction;
1231                            }
1232                            
1233                            public double getGainRatio(){
1234                                    return informationGain/splitInfo;
1235                            }
1236                            
1237                            @Override
1238                            public int compareTo(FeatureFunctionInformationGainPair o) {
1239                                    
1240                                    int result = 0;
1241                                    
1242                                    if((this.getGainRatio() - o.getGainRatio()) <0)
1243                                            result = -1;
1244                                    else if ((this.getGainRatio() - o.getGainRatio()) >0)
1245                                            result = 1;
1246                                    
1247                                    return result;
1248                            }
1249                    }
1250    
1251                    ArrayList<FeatureFunctionInformationGainPair> gainRatioList = new ArrayList<FeatureFunctionInformationGainPair>();
1252                    
1253                    for(FeatureFunction f: divideFeatures)
1254                            gainRatioList.add(new FeatureFunctionInformationGainPair(f));
1255                    
1256                    //For all divide features calculate the gain ratio
1257                    
1258                    for(FeatureFunctionInformationGainPair p : gainRatioList){
1259    
1260                            HashMap<Integer, Integer> featureIdToCountMapTmp = featureIdToCountMap.get(p.featureFunction);
1261                            
1262                            HashMap<Integer, HashMap<Integer, Integer>> featureIdToClassIdToCountMapTmp = featureIdToClassIdToCountMap.get(p.featureFunction);
1263    
1264                            double sum = 0;
1265                            
1266                            for(Entry<Integer, Integer> entry:featureIdToCountMapTmp.entrySet()){
1267                                    int featureId = entry.getKey();
1268                                    int numberOfElementsWithFeatureId = entry.getValue();
1269                                    HashMap<Integer, Integer> classIdToCountMapTmp = featureIdToClassIdToCountMapTmp.get(featureId);
1270                                    
1271                                    double sumImpurityMesure = 0;
1272                                    int totalElementsWithIdAndClass = 0;
1273                                    for(int elementsWithIdAndClass : classIdToCountMapTmp.values()){
1274                                    
1275                                            double fractionOfInstancesBelongingToClass = ((double)elementsWithIdAndClass)/numberOfElementsWithFeatureId;
1276                                    
1277                                            totalElementsWithIdAndClass = totalElementsWithIdAndClass + elementsWithIdAndClass;
1278                                            
1279                                            sumImpurityMesure= sumImpurityMesure+fractionOfInstancesBelongingToClass*log2(fractionOfInstancesBelongingToClass);
1280                                    
1281                                    }
1282                                    
1283                                    double impurityMesure = -sumImpurityMesure;
1284                                    
1285                                    sum = sum + (((double)numberOfElementsWithFeatureId)/getFrequency())*impurityMesure;
1286                                    
1287                            }
1288                            p.informationGain = rootEntropy - sum;
1289                            
1290                            //Calculate split info
1291                            
1292                            double splitInfoTotal = 0;
1293                            
1294                            for(int nrOfElementsWithFeatureId:featureIdToCountMapTmp.values()){
1295                                    double fractionOfTotal = ((double)nrOfElementsWithFeatureId)/getFrequency();
1296                                    splitInfoTotal = splitInfoTotal + fractionOfTotal*log2(fractionOfTotal);
1297                            }
1298                            p.splitInfo= splitInfoTotal;
1299                            
1300                            
1301                    }
1302                    Collections.sort(gainRatioList);
1303    
1304    
1305                    
1306                    //Log the result if info is enabled
1307                    if (getGuide().getConfiguration().getConfigLogger().isInfoEnabled()) {
1308                            
1309                            getGuide().getConfiguration().getConfigLogger().info(
1310                                            "Gain ratio calculation finished the result follows:\n");
1311                            getGuide().getConfiguration().getConfigLogger().info(
1312                            "Divide Feature\tGain Ratio\tInformation Gain\tSplit Info\n");
1313                            
1314                            for(FeatureFunctionInformationGainPair p :gainRatioList)
1315                                    getGuide().getConfiguration().getConfigLogger().info(
1316                                     p.featureFunction + "\t" + p.getGainRatio() + "\t" +  p.informationGain + "\t" +  p.splitInfo  +"\n");
1317                    }
1318                    
1319                    LinkedList<FeatureFunction> divideFeaturesNew = new LinkedList<FeatureFunction>();
1320                    
1321                    for(FeatureFunctionInformationGainPair p :gainRatioList)
1322                            divideFeaturesNew.add(p.featureFunction);
1323                    
1324                    
1325                    return divideFeaturesNew;
1326    
1327            }
1328    
1329    }