001    package org.maltparser.parser.guide.decision;
002    
003    import java.lang.reflect.Constructor;
004    import java.lang.reflect.InvocationTargetException;
005    import java.util.HashMap;
006    
007    import org.maltparser.core.exception.MaltChainedException;
008    import org.maltparser.core.feature.FeatureModel;
009    import org.maltparser.core.feature.FeatureVector;
010    import org.maltparser.core.syntaxgraph.DependencyStructure;
011    import org.maltparser.parser.DependencyParserConfig;
012    import org.maltparser.parser.guide.ClassifierGuide;
013    import org.maltparser.parser.guide.GuideException;
014    import org.maltparser.parser.guide.instance.AtomicModel;
015    import org.maltparser.parser.guide.instance.DecisionTreeModel;
016    import org.maltparser.parser.guide.instance.FeatureDivideModel;
017    import org.maltparser.parser.guide.instance.InstanceModel;
018    import org.maltparser.parser.history.action.GuideDecision;
019    import org.maltparser.parser.history.action.MultipleDecision;
020    import org.maltparser.parser.history.action.SingleDecision;
021    import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
022    /**
023    *
024    * @author Johan Hall
025    * @since 1.1
026    **/
027    public class BranchedDecisionModel implements DecisionModel {
028            private ClassifierGuide guide;
029            private String modelName;
030            private FeatureModel featureModel;
031            private InstanceModel instanceModel;
032            private int decisionIndex;
033            private DecisionModel parentDecisionModel;
034            private HashMap<Integer,DecisionModel> children;
035            private String branchedDecisionSymbols;
036            
037            public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
038                    this.branchedDecisionSymbols = "";
039                    setGuide(guide);
040                    setFeatureModel(featureModel);
041                    setDecisionIndex(0);
042                    setModelName("bdm"+decisionIndex);
043                    setParentDecisionModel(null);
044            }
045            
046            public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
047                    if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
048                            this.branchedDecisionSymbols = branchedDecisionSymbol;
049                    } else {
050                            this.branchedDecisionSymbols = "";
051                    }
052                    setGuide(guide);
053                    setParentDecisionModel(parentDecisionModel);
054                    setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
055                    setFeatureModel(parentDecisionModel.getFeatureModel());
056                    if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
057                            setModelName("bdm"+decisionIndex+branchedDecisionSymbols);
058                    } else {
059                            setModelName("bdm"+decisionIndex);
060                    }
061                    this.parentDecisionModel = parentDecisionModel;
062            }
063            
064            public void updateFeatureModel() throws MaltChainedException {
065                    featureModel.update();
066            }
067            
068            public void updateCardinality() throws MaltChainedException {
069                    featureModel.updateCardinality();
070            }
071            
072    
073            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
074                    if (instanceModel != null) {
075                            instanceModel.finalizeSentence(dependencyGraph);
076                    }
077                    if (children != null) {
078                            for (DecisionModel child : children.values()) {
079                                    child.finalizeSentence(dependencyGraph);
080                            }
081                    }
082            }
083            
084            public void noMoreInstances() throws MaltChainedException {
085                    if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
086                            throw new GuideException("The decision model could not create it's model. ");
087                    }
088                    featureModel.updateCardinality();
089                    if (instanceModel != null) {
090                            instanceModel.noMoreInstances();
091                            instanceModel.train();
092                    }
093                    if (children != null) {
094                            for (DecisionModel child : children.values()) {
095                                    child.noMoreInstances();
096                            }
097                    }
098            }
099    
100            public void terminate() throws MaltChainedException {
101                    if (instanceModel != null) {
102                            instanceModel.terminate();
103                            instanceModel = null;
104                    }
105                    if (children != null) {
106                            for (DecisionModel child : children.values()) {
107                                    child.terminate();
108                            }
109                    }
110            }
111            
112            public void addInstance(GuideDecision decision) throws MaltChainedException {
113                    if (decision instanceof SingleDecision) {
114                            throw new GuideException("A branched decision model expect more than one decisions. ");
115                    }
116                    updateFeatureModel();
117                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
118                    if (instanceModel == null) {
119                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
120                    }
121                    
122                    instanceModel.addInstance(singleDecision);
123                    if (decisionIndex+1 < decision.numberOfDecisions()) {
124                            if (singleDecision.continueWithNextDecision()) {
125                                    if (children == null) {
126                                            children = new HashMap<Integer,DecisionModel>();
127                                    }
128                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
129                                    if (child == null) {
130                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
131                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
132                                            children.put(singleDecision.getDecisionCode(), child);
133                                    }
134                                    child.addInstance(decision);
135                            }
136                    }
137            }
138            
139            public boolean predict(GuideDecision decision) throws MaltChainedException {
140                    if (decision instanceof SingleDecision) {
141                            throw new GuideException("A branched decision model expect more than one decisions. ");
142                    }
143                    updateFeatureModel();
144                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
145                    if (instanceModel == null) {
146                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
147                    }
148                    instanceModel.predict(singleDecision);
149                    if (decisionIndex+1 < decision.numberOfDecisions()) {
150                            if (singleDecision.continueWithNextDecision()) {
151                                    if (children == null) {
152                                            children = new HashMap<Integer,DecisionModel>();
153                                    }
154                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
155                                    if (child == null) {
156                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
157                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
158                                            children.put(singleDecision.getDecisionCode(), child);
159                                    }
160                                    child.predict(decision);
161                            }
162                    }
163    
164                    return true;
165            }
166            
167            public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
168                    if (decision instanceof SingleDecision) {
169                            throw new GuideException("A branched decision model expect more than one decisions. ");
170                    }
171                    updateFeatureModel();
172                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
173                    if (instanceModel == null) {
174                            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
175                    }
176                    FeatureVector fv = instanceModel.predictExtract(singleDecision);
177                    if (decisionIndex+1 < decision.numberOfDecisions()) {
178                            if (singleDecision.continueWithNextDecision()) {
179                                    if (children == null) {
180                                            children = new HashMap<Integer,DecisionModel>();
181                                    }
182                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
183                                    if (child == null) {
184                                            child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
185                                                            branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
186                                            children.put(singleDecision.getDecisionCode(), child);
187                                    }
188                                    child.predictExtract(decision);
189                            }
190                    }
191    
192                    return fv;
193            }
194            
195            public FeatureVector extract() throws MaltChainedException {
196                    updateFeatureModel();
197                    return instanceModel.extract(); // TODO handle many feature vectors
198            }
199            
200            public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
201                    if (decision instanceof SingleDecision) {
202                            throw new GuideException("A branched decision model expect more than one decisions. ");
203                    }
204                    
205                    boolean success = false;
206                    final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
207                    if (decisionIndex+1 < decision.numberOfDecisions()) {
208                            if (singleDecision.continueWithNextDecision()) {
209                                    if (children == null) {
210                                            children = new HashMap<Integer,DecisionModel>();
211                                    }
212                                    DecisionModel child = children.get(singleDecision.getDecisionCode());
213                                    if (child != null) {
214                                            success = child.predictFromKBestList(decision);
215                                    }
216                                    
217                            }
218                    }
219                    if (!success) {
220                            success = singleDecision.updateFromKBestList();
221                            if (decisionIndex+1 < decision.numberOfDecisions()) {
222                                    if (singleDecision.continueWithNextDecision()) {
223                                            if (children == null) {
224                                                    children = new HashMap<Integer,DecisionModel>();
225                                            }
226                                            DecisionModel child = children.get(singleDecision.getDecisionCode());
227                                            if (child == null) {
228                                                    child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1), 
229                                                                    branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
230                                                    children.put(singleDecision.getDecisionCode(), child);
231                                            }
232                                            child.predict(decision);
233                                    }
234                            }
235                    }
236                    return success;
237            }
238            
239    
240            public ClassifierGuide getGuide() {
241                    return guide;
242            }
243    
244            public String getModelName() {
245                    return modelName;
246            }
247            
248            public FeatureModel getFeatureModel() {
249                    return featureModel;
250            }
251    
252            public int getDecisionIndex() {
253                    return decisionIndex;
254            }
255    
256            public DecisionModel getParentDecisionModel() {
257                    return parentDecisionModel;
258            }
259    
260            private void setFeatureModel(FeatureModel featureModel) {
261                    this.featureModel = featureModel;
262            }
263            
264            private void setDecisionIndex(int decisionIndex) {
265                    this.decisionIndex = decisionIndex;
266            }
267            
268            private void setParentDecisionModel(DecisionModel parentDecisionModel) {
269                    this.parentDecisionModel = parentDecisionModel;
270            }
271    
272            private void setModelName(String modelName) {
273                    this.modelName = modelName;
274            }
275            
276            private void setGuide(ClassifierGuide guide) {
277                    this.guide = guide;
278            }
279            
280            
281            private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
282                    Class<?> decisionModelClass = null;
283                    if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
284                            decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
285                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
286                            decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
287                    } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
288                            decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
289                    }
290    
291                    if (decisionModelClass == null) {
292                            throw new GuideException("Could not find an appropriate decision model for the relation to the next decision"); 
293                    }
294                    
295                    try {
296                            Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class, 
297                                                    java.lang.String.class };
298                            Object[] arguments = new Object[3];
299                            arguments[0] = getGuide();
300                            arguments[1] = this;
301                            arguments[2] = branchedDecisionSymbol;
302                            Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
303                            return (DecisionModel)constructor.newInstance(arguments);
304                    } catch (NoSuchMethodException e) {
305                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
306                    } catch (InstantiationException e) {
307                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
308                    } catch (IllegalAccessException e) {
309                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
310                    } catch (InvocationTargetException e) {
311                            throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
312                    }
313            }
314            
315            private void initInstanceModel(String subModelName) throws MaltChainedException {
316                    FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
317                    if (fv == null) {
318                            fv = featureModel.getFeatureVector(subModelName);
319                    }
320                    if (fv == null) {
321                            fv = featureModel.getMainFeatureVector();
322                    }
323                    
324                    DependencyParserConfig c = guide.getConfiguration();
325                    
326    //              if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") ||
327    //                              (c.getOptionValue("guide", "tree_split_columns")!=null &&
328    //                      c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) ||
329    //                      (c.getOptionValue("guide", "tree_split_structures")!=null &&
330    //                      c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) {
331    //                      instanceModel = new DecisionTreeModel(fv, this); 
332    //              }else 
333                    if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
334                            instanceModel = new AtomicModel(-1, fv, this);
335                    } else {
336                            instanceModel = new FeatureDivideModel(fv, this);
337                    }
338            }
339            
340            public String toString() {
341                    final StringBuilder sb = new StringBuilder();
342                    sb.append(modelName + ", ");
343                    for (DecisionModel model : children.values()) {
344                            sb.append(model.toString() + ", ");
345                    }
346                    return sb.toString();
347            }
348    }