001    package org.maltparser.parser.algorithm.nivre;
002    
003    import java.util.Stack;
004    
005    import org.maltparser.core.exception.MaltChainedException;
006    import org.maltparser.core.symbol.SymbolTable;
007    import org.maltparser.core.symbol.SymbolTableHandler;
008    import org.maltparser.core.syntaxgraph.DependencyGraph;
009    import org.maltparser.core.syntaxgraph.DependencyStructure;
010    import org.maltparser.core.syntaxgraph.edge.Edge;
011    import org.maltparser.core.syntaxgraph.node.DependencyNode;
012    import org.maltparser.parser.ParserConfiguration;
013    import org.maltparser.parser.ParsingException;
014    /**
015     * @author Johan Hall
016     *
017     */
018    public class NivreConfig extends ParserConfiguration {
019            // Root Handling
020            public static final int STRICT = 1; //root tokens unattached, Reduce not permissible
021            public static final int RELAXED = 2; //root tokens unattached, Reduce permissible
022            public static final int NORMAL = 3; //root tokens attached to Root with RightArc
023            
024            private Stack<DependencyNode> stack;
025            private Stack<DependencyNode> input;
026            private DependencyStructure dependencyGraph;
027            private int rootHandling;
028    
029            
030            public NivreConfig(SymbolTableHandler symbolTableHandler, String rootHandling) throws MaltChainedException {
031                    super();
032                    stack = new Stack<DependencyNode>();
033                    input = new Stack<DependencyNode>();
034                    dependencyGraph = new DependencyGraph(symbolTableHandler);
035                    setRootHandling(rootHandling);
036            }
037            
038            public Stack<DependencyNode> getStack() {
039                    return stack;
040            }
041            
042            public Stack<DependencyNode> getInput() {
043                    return input;
044            }
045            
046            public DependencyStructure getDependencyStructure() {
047                    return dependencyGraph;
048            }
049            
050            public boolean isTerminalState() {
051                    return input.isEmpty();
052            }
053            
054            public DependencyNode getStackNode(int index) throws MaltChainedException {
055                    if (index < 0) {
056                            throw new ParsingException("Stack index must be non-negative in feature specification. ");
057                    }
058                    if (stack.size()-index > 0) {
059                            return stack.get(stack.size()-1-index);
060                    }
061                    return null;
062            }
063            
064            public DependencyNode getInputNode(int index) throws MaltChainedException {
065                    if (index < 0) {
066                            throw new ParsingException("Input index must be non-negative in feature specification. ");
067                    }
068                    if (input.size()-index > 0) {
069                            return input.get(input.size()-1-index);
070                    }       
071                    return null;
072            }
073            
074            public void setDependencyGraph(DependencyStructure source) throws MaltChainedException {
075                    dependencyGraph.clear();
076                    for (int index : source.getTokenIndices()) {
077                            DependencyNode gnode = source.getTokenNode(index);
078                            DependencyNode pnode = dependencyGraph.addTokenNode(gnode.getIndex());
079                            for (SymbolTable table : gnode.getLabelTypes()) {
080                                    pnode.addLabel(table, gnode.getLabelSymbol(table));
081                            }
082                            
083                            if (gnode.hasHead()) {
084                                    Edge s = gnode.getHeadEdge();
085                                    Edge t = dependencyGraph.addDependencyEdge(s.getSource().getIndex(), s.getTarget().getIndex());
086                                    
087                                    for (SymbolTable table : s.getLabelTypes()) {
088                                            t.addLabel(table, s.getLabelSymbol(table));
089                                    }
090                            }
091                    }
092                    for (SymbolTable table : source.getDefaultRootEdgeLabels().keySet()) {
093                            dependencyGraph.setDefaultRootEdgeLabel(table, source.getDefaultRootEdgeLabelSymbol(table));
094                    }
095            }
096            
097            public DependencyStructure getDependencyGraph() {
098                    return dependencyGraph;
099            }
100            
101            public void initialize(ParserConfiguration parserConfiguration) throws MaltChainedException {
102                    if (parserConfiguration != null) {
103                            NivreConfig nivreConfig = (NivreConfig)parserConfiguration;
104                            Stack<DependencyNode> sourceStack = nivreConfig.getStack();
105                            Stack<DependencyNode> sourceInput = nivreConfig.getInput();
106                            setDependencyGraph(nivreConfig.getDependencyGraph());
107                            for (int i = 0, n = sourceStack.size(); i < n; i++) {
108                                    stack.add(dependencyGraph.getDependencyNode(sourceStack.get(i).getIndex()));
109                            }
110                            for (int i = 0, n = sourceInput.size(); i < n; i++) {
111                                    input.add(dependencyGraph.getDependencyNode(sourceInput.get(i).getIndex()));
112                            }
113                    } else {
114                            stack.push(dependencyGraph.getDependencyRoot());
115                            for (int i = dependencyGraph.getHighestTokenIndex(); i > 0; i--) {
116                                    final DependencyNode node = dependencyGraph.getDependencyNode(i);
117                                    if (node != null && !node.hasHead()) { // added !node.hasHead()
118                                            input.push(node);
119                                    }
120                            }
121                    }
122            }
123            
124            public int getRootHandling() {
125                    return rootHandling;
126            }
127    
128            public void setRootHandling(int rootHandling) {
129                    this.rootHandling = rootHandling;
130            }
131            
132            protected void setRootHandling(String rh) throws MaltChainedException {
133                    if (rh.equalsIgnoreCase("strict")) {
134                            rootHandling = STRICT;
135                    } else if (rh.equalsIgnoreCase("relaxed")) {
136                            rootHandling = RELAXED;
137                    } else if (rh.equalsIgnoreCase("normal")) {
138                            rootHandling = NORMAL;
139                    } else {
140                            throw new ParsingException("The root handling '"+rh+"' is unknown");
141                    }
142            }
143            
144            public void clear() throws MaltChainedException {
145                    dependencyGraph.clear();
146                    stack.clear();
147                    input.clear();
148                    historyNode = null;
149            }
150            
151            public boolean equals(Object obj) {
152                    if (this == obj)
153                            return true;
154                    if (obj == null)
155                            return false;
156                    if (getClass() != obj.getClass())
157                            return false;
158                    NivreConfig that = (NivreConfig)obj;
159                    
160                    if (stack.size() != that.getStack().size()) 
161                            return false;
162                    if (input.size() != that.getInput().size())
163                            return false;
164                    if (dependencyGraph.nEdges() != that.getDependencyGraph().nEdges())
165                            return false;
166                    for (int i = 0; i < stack.size(); i++) {
167                            if (stack.get(i).getIndex() != that.getStack().get(i).getIndex()) {
168                                    return false;
169                            }
170                    }
171                    for (int i = 0; i < input.size(); i++) {
172                            if (input.get(i).getIndex() != that.getInput().get(i).getIndex()) {
173                                    return false;
174                            }
175                    }               
176                    return dependencyGraph.getEdges().equals(that.getDependencyGraph().getEdges());
177            }
178            
179            public String toString() {
180                    final StringBuilder sb = new StringBuilder();
181                    sb.append(stack.size());
182                    sb.append(", ");
183                    sb.append(input.size());
184                    sb.append(", ");
185                    sb.append(dependencyGraph.nEdges());
186                    return sb.toString();
187            }
188    }