001    package org.maltparser.parser.algorithm.stack;
002    
003    import java.util.ArrayList;
004    import java.util.Stack;
005    
006    import org.maltparser.core.exception.MaltChainedException;
007    import org.maltparser.core.syntaxgraph.DependencyStructure;
008    import org.maltparser.core.syntaxgraph.node.DependencyNode;
009    import org.maltparser.parser.DependencyParserConfig;
010    import org.maltparser.parser.Oracle;
011    import org.maltparser.parser.ParserConfiguration;
012    import org.maltparser.parser.history.GuideUserHistory;
013    import org.maltparser.parser.history.action.GuideUserAction;
014    /**
015     * @author Johan Hall
016     *
017     */
018    public class SwapEagerOracle extends Oracle {
019            private ArrayList<Integer> swapArray;
020            private boolean swapArrayActive = false;
021            
022            public SwapEagerOracle(DependencyParserConfig manager, GuideUserHistory history) throws MaltChainedException {
023                    super(manager, history);
024                    setGuideName("swapeager");
025                    swapArray = new ArrayList<Integer>();
026            }
027            
028            public GuideUserAction predict(DependencyStructure gold, ParserConfiguration configuration) throws MaltChainedException {
029                    StackConfig config = (StackConfig)configuration;
030                    Stack<DependencyNode> stack = config.getStack();
031    
032                    if (!swapArrayActive) {
033                            createSwapArray(gold);
034                            swapArrayActive = true;
035                    }
036                    GuideUserAction action = null;
037                    if (stack.size() < 2) {
038                            action = updateActionContainers(NonProjective.SHIFT, null);
039                    } else {
040                            DependencyNode left = stack.get(stack.size()-2);
041                            int leftIndex = left.getIndex();
042                            int rightIndex = stack.get(stack.size()-1).getIndex();
043                            if (swapArray.get(leftIndex) > swapArray.get(rightIndex)) {
044                                    action =  updateActionContainers(NonProjective.SWAP, null);
045                            } else if (!left.isRoot() && gold.getTokenNode(leftIndex).getHead().getIndex() == rightIndex
046                                            && nodeComplete(gold, config.getDependencyGraph(), leftIndex)) {
047                                    action = updateActionContainers(NonProjective.LEFTARC, gold.getTokenNode(leftIndex).getHeadEdge().getLabelSet());
048                            } else if (gold.getTokenNode(rightIndex).getHead().getIndex() == leftIndex
049                                            && nodeComplete(gold, config.getDependencyGraph(), rightIndex)) {
050                                    action = updateActionContainers(NonProjective.RIGHTARC, gold.getTokenNode(rightIndex).getHeadEdge().getLabelSet());
051                            } else {
052                                    action = updateActionContainers(NonProjective.SHIFT, null);
053                            }
054                    }
055                    return action;
056            }
057            
058            private boolean nodeComplete(DependencyStructure gold, DependencyStructure parseDependencyGraph, int nodeIndex) {
059                    if (gold.getTokenNode(nodeIndex).hasLeftDependent()) {
060                            if (!parseDependencyGraph.getTokenNode(nodeIndex).hasLeftDependent()) {
061                                    return false;
062                            } else if (gold.getTokenNode(nodeIndex).getLeftmostDependent().getIndex() != parseDependencyGraph.getTokenNode(nodeIndex).getLeftmostDependent().getIndex()) {
063                                    return false;
064                            }
065                    }
066                    if (gold.getTokenNode(nodeIndex).hasRightDependent()) {
067                            if (!parseDependencyGraph.getTokenNode(nodeIndex).hasRightDependent()) {
068                                    return false;
069                            } else if (gold.getTokenNode(nodeIndex).getRightmostDependent().getIndex() != parseDependencyGraph.getTokenNode(nodeIndex).getRightmostDependent().getIndex()) {
070                                    return false;
071                            }
072                    }
073                    return true;
074            }
075            
076    //      private boolean checkRightDependent(DependencyStructure gold, DependencyStructure parseDependencyGraph, int index) throws MaltChainedException {
077    //              if (gold.getTokenNode(index).getRightmostDependent() == null) {
078    //                      return true;
079    //              } else if (parseDependencyGraph.getTokenNode(index).getRightmostDependent() != null) {
080    //                      if (gold.getTokenNode(index).getRightmostDependent().getIndex() == parseDependencyGraph.getTokenNode(index).getRightmostDependent().getIndex()) {
081    //                              return true;
082    //                      }
083    //              }
084    //              return false;
085    //      }
086            
087            public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
088                    swapArrayActive = false;
089            }
090            
091            public void terminate() throws MaltChainedException {
092            }
093            
094            private void createSwapArray(DependencyStructure goldDependencyGraph) throws MaltChainedException {
095                    swapArray.clear();
096                    for (int i = 0; i <= goldDependencyGraph.getHighestDependencyNodeIndex(); i++) {
097                            swapArray.add(new Integer(i));
098                    }
099                    createSwapArray(goldDependencyGraph.getDependencyRoot(), 0);
100            }
101            
102            private int createSwapArray(DependencyNode n, int order) {
103                    int o = order; 
104                    if (n != null) {
105                            for (int i=0; i < n.getLeftDependentCount(); i++) {
106                                    o = createSwapArray(n.getLeftDependent(i), o);
107                            }
108                            swapArray.set(n.getIndex(), o++);
109                            for (int i=n.getRightDependentCount(); i >= 0; i--) {
110                                    o = createSwapArray(n.getRightDependent(i), o);
111                            }
112                    }
113                    return o;
114            }
115    }