001    package org.maltparser.parser.guide.instance;
002    
003    import java.io.IOException;
004    import java.lang.reflect.Constructor;
005    import java.lang.reflect.InvocationTargetException;
006    import java.util.ArrayList;
007    import java.util.Formatter;
008    
009    import org.maltparser.core.exception.MaltChainedException;
010    import org.maltparser.core.feature.FeatureVector;
011    import org.maltparser.core.feature.function.FeatureFunction;
012    import org.maltparser.core.feature.function.Modifiable;
013    import org.maltparser.core.syntaxgraph.DependencyStructure;
014    import org.maltparser.ml.LearningMethod;
015    import org.maltparser.parser.guide.ClassifierGuide;
016    import org.maltparser.parser.guide.GuideException;
017    import org.maltparser.parser.guide.Model;
018    import org.maltparser.parser.history.action.SingleDecision;
019    
020    
021    /**
022    
023    @author Johan Hall
024    @since 1.0
025    */
026    public class AtomicModel implements InstanceModel {
027            private Model parent;
028            private String modelName;
029            private FeatureVector featureVector;
030            private int index;
031            private int frequency = 0;
032            private LearningMethod method;
033    
034            
035            /**
036             * Constructs an atomic model.
037             * 
038             * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
039             * or the master divide model) and n is number of divide models.
040             * @param features the feature vector used by the atomic model.
041             * @param parent the parent guide model.
042             * @throws MaltChainedException
043             */
044            public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045                    setParent(parent);
046                    setIndex(index);
047                    if (index == -1) {
048                            setModelName(parent.getModelName()+".");
049                    } else {
050                            setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051                    }
052                    setFeatures(features);
053                    setFrequency(0);
054                    initMethod();
055                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056                            try {
057                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058                                    getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059                            } catch (IOException e) {
060                                    throw new GuideException("Could not write learner settings to the information file. ", e);
061                            }
062                    }
063            }
064            
065            public void addInstance(SingleDecision decision) throws MaltChainedException {
066                    try {
067                            method.addInstance(decision, featureVector);
068                    } catch (NullPointerException e) {
069                            throw new GuideException("The learner cannot be found. ", e);
070                    }
071            }
072    
073            
074            public void noMoreInstances() throws MaltChainedException {
075                    try {
076                            method.noMoreInstances();
077                    } catch (NullPointerException e) {
078                            throw new GuideException("The learner cannot be found. ", e);
079                    }
080            }
081            
082            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083                    try {
084                            method.finalizeSentence(dependencyGraph);
085                    } catch (NullPointerException e) {
086                            throw new GuideException("The learner cannot be found. ", e);
087                    }
088            }
089    
090            public boolean predict(SingleDecision decision) throws MaltChainedException {
091                    try {
092                            if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093                                    throw new GuideException("Cannot predict during batch training. ");
094                            }
095                            return method.predict(featureVector, decision);
096                    } catch (NullPointerException e) {
097                            throw new GuideException("The learner cannot be found. ", e);
098                    }
099            }
100    
101            public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102                    try {
103                            if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104                                    throw new GuideException("Cannot predict during batch training. ");
105                            }
106                            if (method.predict(featureVector, decision)) {
107                                    return featureVector;
108                            }
109                            return null;
110                    } catch (NullPointerException e) {
111                            throw new GuideException("The learner cannot be found. ", e);
112                    }
113            }
114            
115            public FeatureVector extract() throws MaltChainedException {
116                    return featureVector;
117            }
118            
119            public void terminate() throws MaltChainedException {
120                    if (method != null) {
121                            method.terminate();
122                            method = null;
123                    }
124                    featureVector = null;
125                    parent = null;
126            }
127            
128            /**
129             * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130             * This method is used by the feature divide model to sum up all model below a certain threshold.
131             * 
132             * @param model the destination atomic model 
133             * @param divideFeature the divide feature
134             * @param divideFeatureIndexVector the divide feature index vector
135             * @throws MaltChainedException
136             */
137            public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138                    if (method == null) {
139                            throw new GuideException("The learner cannot be found. ");
140                    } else if (model == null) {
141                            throw new GuideException("The guide model cannot be found. ");
142                    } else if (divideFeature == null) {
143                            throw new GuideException("The divide feature cannot be found. ");
144                    } else if (divideFeatureIndexVector == null) {
145                            throw new GuideException("The divide feature index vector cannot be found. ");
146                    }
147                    ((Modifiable)divideFeature).setFeatureValue(index);
148                    method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149                    method.terminate();
150                    method = null;
151            }
152            
153            /**
154             * Invokes the train() of the learning method 
155             * 
156             * @throws MaltChainedException
157             */
158            public void train() throws MaltChainedException {
159                    try {
160                            method.train(featureVector);
161                            method.terminate();
162                            method = null;
163                    } catch (NullPointerException e) {      
164                            throw new GuideException("The learner cannot be found. ", e);
165                    }
166            }
167            
168            /**
169             * Initialize the learning method according to the option --learner-method.
170             * 
171             * @throws MaltChainedException
172             */
173            public void initMethod() throws MaltChainedException {
174                    Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
175    //              if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) {
176    //                      try {
177    //                              clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04");
178    //                      } catch (ClassNotFoundException e) {
179    //                              throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e);
180    //                      }
181    //              }
182                    Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
183                    Object[] arguments = new Object[2];
184                    arguments[0] = this;
185                    if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
186                            arguments[1] = LearningMethod.CLASSIFY;
187                    } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
188                            arguments[1] = LearningMethod.BATCH;
189                    } 
190    
191                    try {   
192                            Constructor<?> constructor = clazz.getConstructor(argTypes);
193                            this.method = (LearningMethod)constructor.newInstance(arguments);
194                    } catch (NoSuchMethodException e) {
195                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
196                    } catch (InstantiationException e) {
197                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
198                    } catch (IllegalAccessException e) {
199                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
200                    } catch (InvocationTargetException e) {
201                            throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
202                    }
203            }
204            
205            
206            
207            /**
208             * Returns the parent guide model
209             * 
210             * @return the parent guide model
211             */
212            public Model getParent() throws MaltChainedException {
213                    if (parent == null) {
214                            throw new GuideException("The atomic model can only be used by a parent model. ");
215                    }
216                    return parent;
217            }
218    
219            /**
220             * Sets the parent guide model
221             * 
222             * @param parent the parent guide model
223             */
224            protected void setParent(Model parent) {
225                    this.parent = parent;
226            }
227    
228            public String getModelName() {
229                    return modelName;
230            }
231    
232            /**
233             * Sets the name of the atomic model
234             * 
235             * @param modelName the name of the atomic model
236             */
237            protected void setModelName(String modelName) {
238                    this.modelName = modelName;
239            }
240    
241            /**
242             * Returns the feature vector used by this atomic model
243             * 
244             * @return a feature vector object
245             */
246            public FeatureVector getFeatures() {
247                    return featureVector;
248            }
249    
250            /**
251             * Sets the feature vector used by the atomic model.
252             * 
253             * @param features a feature vector object
254             */
255            protected void setFeatures(FeatureVector features) {
256                    this.featureVector = features;
257            }
258    
259            public ClassifierGuide getGuide() {
260                    return parent.getGuide();
261            }
262            
263            /**
264             * Returns the index of the atomic model
265             * 
266             * @return the index of the atomic model
267             */
268            public int getIndex() {
269                    return index;
270            }
271    
272            /**
273             * Sets the index of the model (-1..n), where -1 is a special value.
274             * 
275             * @param index index value (-1..n) of the atomic model
276             */
277            protected void setIndex(int index) {
278                    this.index = index;
279            }
280    
281            /**
282             * Returns the frequency (number of instances)
283             * 
284             * @return the frequency (number of instances)
285             */
286            public int getFrequency() {
287                    return frequency;
288            }
289            
290            /**
291             * Increase the frequency by 1
292             */
293            public void increaseFrequency() {
294                    if (parent instanceof InstanceModel) {
295                            ((InstanceModel)parent).increaseFrequency();
296                    }
297                    frequency++;
298            }
299            
300            public void decreaseFrequency() {
301                    if (parent instanceof InstanceModel) {
302                            ((InstanceModel)parent).decreaseFrequency();
303                    }
304                    frequency--;
305            }
306            /**
307             * Sets the frequency (number of instances)
308             * 
309             * @param frequency (number of instances)
310             */
311            protected void setFrequency(int frequency) {
312                    this.frequency = frequency;
313            }
314            
315            /**
316             * Returns a learner object
317             * 
318             * @return a learner object
319             */
320            public LearningMethod getMethod() {
321                    return method;
322            }
323            
324            
325            /* (non-Javadoc)
326             * @see java.lang.Object#toString()
327             */
328            public String toString() {
329                    final StringBuilder sb = new StringBuilder();
330                    sb.append(method.toString());
331                    return sb.toString();
332            }
333    }