001    package org.maltparser.core.symbol.trie;
002    
003    import java.io.BufferedReader;
004    import java.io.BufferedWriter;
005    import java.io.IOException;
006    import java.util.Set;
007    import java.util.SortedMap;
008    import java.util.TreeMap;
009    
010    import org.apache.log4j.Logger;
011    import org.maltparser.core.exception.MaltChainedException;
012    import org.maltparser.core.io.dataformat.ColumnDescription;
013    import org.maltparser.core.symbol.SymbolException;
014    import org.maltparser.core.symbol.SymbolTable;
015    import org.maltparser.core.symbol.nullvalue.InputNullValues;
016    import org.maltparser.core.symbol.nullvalue.NullValues;
017    import org.maltparser.core.symbol.nullvalue.OutputNullValues;
018    import org.maltparser.core.symbol.nullvalue.NullValues.NullValueId;
019    /**
020    
021    @author Johan Hall
022    @since 1.0
023    */
024    public class TrieSymbolTable implements SymbolTable {
025            private final String name;
026            private final Trie trie;
027            private final SortedMap<Integer, TrieNode> codeTable;
028            private int columnCategory;
029            private NullValues nullValues;
030            private int valueCounter;
031        /** Cache the hash code for the symbol table */
032        private int cachedHash;
033        
034            public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy) throws MaltChainedException {
035                    this.name = name;
036                    this.trie = trie;
037                    this.columnCategory = columnCategory;
038                    codeTable = new TreeMap<Integer, TrieNode>();
039                    if (columnCategory == ColumnDescription.INPUT) {
040                            nullValues = new InputNullValues(nullValueStrategy, this);
041                    } else if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
042                            nullValues = new OutputNullValues(nullValueStrategy, this, null);
043                    } else {
044                            nullValues = new InputNullValues(nullValueStrategy, this);
045                    }
046                    valueCounter = nullValues.getNextCode();
047            }
048    
049            public TrieSymbolTable(String name,  Trie trie, int columnCategory, String nullValueStrategy, String rootLabel) throws MaltChainedException {
050                    this.name = name;
051                    this.trie = trie;
052                    this.columnCategory = columnCategory;
053                    codeTable = new TreeMap<Integer, TrieNode>();
054                    if (columnCategory == ColumnDescription.INPUT) {
055                            nullValues = new InputNullValues(nullValueStrategy, this);
056                    } else if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
057                            nullValues = new OutputNullValues(nullValueStrategy, this, rootLabel);
058                    }
059                    valueCounter = nullValues.getNextCode();
060            }
061            
062            public TrieSymbolTable(String name, Trie trie) {
063                    this.name = name;
064                    this.trie = trie;
065                    codeTable = new TreeMap<Integer, TrieNode>();
066                    nullValues = new InputNullValues("one", this);
067                    //nullValues = null;
068                    valueCounter = 1;
069            }
070            
071            public int addSymbol(String symbol) throws MaltChainedException {
072                    if (nullValues == null || !nullValues.isNullValue(symbol)) {
073                            final TrieNode node = trie.addValue(symbol, this, -1);
074                            final int code = node.getEntry(this).getCode();
075                            if (!codeTable.containsKey(code)) {
076                                    codeTable.put(code, node);
077                            }
078                            return code;
079                    } else {
080                            return nullValues.symbolToCode(symbol);
081                    }
082            }
083            
084            public int addSymbol(StringBuilder symbol) throws MaltChainedException {
085                    if (nullValues == null || !nullValues.isNullValue(symbol)) {
086                            final TrieNode node = trie.addValue(symbol, this, -1);
087                            final int code = node.getEntry(this).getCode();
088                            if (!codeTable.containsKey(code)) {
089                                    codeTable.put(code, node);
090                            }
091                            return code;
092                    } else {
093                            return nullValues.symbolToCode(symbol);
094                    }
095            }
096            
097            public String getSymbolCodeToString(int code) throws MaltChainedException {
098                    if (code >= 0) {
099                            if (nullValues == null || !nullValues.isNullValue(code)) {
100                                    if (trie == null) {
101                                            throw new SymbolException("The symbol table is corrupt. ");
102                                    }
103                                    return trie.getValue(codeTable.get(code), this);
104                            } else {
105                                    return nullValues.codeToSymbol(code);
106                            }
107                    } else {
108                            throw new SymbolException("The symbol code '"+code+"' cannot be found in the symbol table. ");
109                    }
110            }
111            
112            public int getSymbolStringToCode(String symbol) throws MaltChainedException {
113                    if (symbol != null) {
114                            if (nullValues == null || !nullValues.isNullValue(symbol)) {
115                                    if (trie == null) {
116                                            throw new SymbolException("The symbol table is corrupt. ");
117                                    } 
118                                    final TrieEntry entry = trie.getEntry(symbol, this);
119                                    if (entry == null) {
120                                            throw new SymbolException("Could not find the symbol '"+symbol+"' in the symbol table. ");
121                                    }
122                                    return entry.getCode();                         
123                            } else {
124                                    return nullValues.symbolToCode(symbol);
125                            }
126                    } else {
127                            throw new SymbolException("The symbol code '"+symbol+"' cannot be found in the symbol table. ");
128                    }
129            }
130    
131            public String getNullValueStrategy() {
132                    if (nullValues == null) {
133                            return null;
134                    }
135                    return nullValues.getNullValueStrategy();
136            }
137            
138            
139            public int getColumnCategory() {
140                    return columnCategory;
141            }
142    
143            public boolean getKnown(int code) {
144                    if (code >= 0) {
145                            if (nullValues == null || !nullValues.isNullValue(code)) {
146                                    return codeTable.get(code).getEntry(this).isKnown();
147                            } else {
148                                    return true;
149                            }
150                    } else {
151                            return false;
152                    }
153            }
154    
155            public boolean getKnown(String symbol) {
156                    if (nullValues == null || !nullValues.isNullValue(symbol)) {
157                            final TrieEntry entry = trie.getEntry(symbol, this);
158                            if (entry == null) {
159                                    return false;
160                            }
161                            return entry.isKnown();
162                    } else {
163                            return true;
164                    }
165            }
166            
167            public void makeKnown(int code) {
168                    if (code >= 0) {
169                            if (nullValues == null || !nullValues.isNullValue(code)) {
170                                    codeTable.get(code).getEntry(this).setKnown(true);
171                            } 
172                    }
173            }
174            
175            public void printSymbolTable(Logger logger) throws MaltChainedException {
176                    for (Integer code : codeTable.keySet()) {
177                            logger.info(code+"\t"+trie.getValue(codeTable.get(code), this)+"\n");
178                    }
179            }
180            
181            public void saveHeader(BufferedWriter out) throws MaltChainedException  {
182                    try {
183                            out.append('\t');
184                            out.append(getName());
185                            out.append('\t');
186                            out.append(Integer.toString(getColumnCategory()));
187                            out.append('\t');
188                            out.append(getNullValueStrategy());
189                            out.append('\t');
190                            if (nullValues instanceof OutputNullValues && ((OutputNullValues)nullValues).getRootLabel() != null) {
191                                    out.append(((OutputNullValues)nullValues).getRootLabel());
192                            } else {
193                                    out.append("#DUMMY#");
194                            }
195                            out.append('\n');
196                    } catch (IOException e) {
197                            throw new SymbolException("Could not save the symbol table. ", e);
198                    }
199            }
200            
201            public int size() {
202                    return codeTable.size();
203            }
204            
205            public void save(BufferedWriter out) throws MaltChainedException  {
206                    try {
207                            out.write(name);
208                            out.write('\n');
209                            for (Integer code : codeTable.keySet()) {
210                                    out.write(code+"");
211                                    out.write('\t');
212                                    out.write(trie.getValue(codeTable.get(code), this));
213                                    out.write('\n');
214                            }
215                            out.write('\n');
216                    } catch (IOException e) {
217                            throw new SymbolException("Could not save the symbol table. ", e);
218                    }
219            }
220            
221            public void load(BufferedReader in) throws MaltChainedException {
222                    int max = 0;
223                    int index = 0;
224                    String fileLine;
225                    try {
226                            while ((fileLine = in.readLine()) != null) {
227                                    if (fileLine.length() == 0 || (index = fileLine.indexOf('\t')) == -1) {
228                                            setValueCounter(max+1);
229                                            break;
230                                    }
231                                    int code = Integer.parseInt(fileLine.substring(0,index));
232                                    final String str = fileLine.substring(index+1);
233                                    final TrieNode node = trie.addValue(str, this, code);
234                                    codeTable.put(node.getEntry(this).getCode(), node);
235                                    if (max < code) {
236                                            max = code;
237                                    }
238                            }
239                    } catch (NumberFormatException e) {
240                            throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the first column. ", e);
241                    } catch (IOException e) {
242                            throw new SymbolException("Could not load the symbol table. ", e);
243                    }
244            }
245            
246            public String getName() {
247                    return name;
248            }
249    
250            public int getValueCounter() {
251                    return valueCounter;
252            }
253    
254            private void setValueCounter(int valueCounter) {
255                    this.valueCounter = valueCounter;
256            }
257            
258            protected void updateValueCounter(int code) {
259                    if (code > valueCounter) {
260                            valueCounter = code;
261                    }
262            }
263            
264            protected int increaseValueCounter() {
265                    return valueCounter++;
266            }
267            
268            public int getNullValueCode(NullValueId nullValueIdentifier) throws MaltChainedException {
269                    if (nullValues == null) {
270                            throw new SymbolException("The symbol table does not have any null-values. ");
271                    }
272                    return nullValues.nullvalueToCode(nullValueIdentifier);
273            }
274            
275            public String getNullValueSymbol(NullValueId nullValueIdentifier) throws MaltChainedException {
276                    if (nullValues == null) {
277                            throw new SymbolException("The symbol table does not have any null-values. ");
278                    }
279                    return nullValues.nullvalueToSymbol(nullValueIdentifier);
280            }
281            
282            public boolean isNullValue(String symbol) throws MaltChainedException {
283                    if (nullValues != null) {
284                            return nullValues.isNullValue(symbol);
285                    } 
286                    return false;
287            }
288            
289            public boolean isNullValue(int code) throws MaltChainedException {
290                    if (nullValues != null) {
291                            return nullValues.isNullValue(code);
292                    } 
293                    return false;
294            }
295            
296            public void copy(SymbolTable fromTable) throws MaltChainedException {
297                    final SortedMap<Integer, TrieNode> fromCodeTable =  ((TrieSymbolTable)fromTable).getCodeTable();
298                    int max = getValueCounter()-1;
299                    for (Integer code : fromCodeTable.keySet()) {
300                            final String str = trie.getValue(fromCodeTable.get(code), this);
301                            final TrieNode node = trie.addValue(str, this, code);
302                            codeTable.put(node.getEntry(this).getCode(), node);
303                            if (max < code) {
304                                    max = code;
305                            }
306                    }
307                    setValueCounter(max+1);
308            }
309    
310            public SortedMap<Integer, TrieNode> getCodeTable() {
311                    return codeTable;
312            }
313            
314            public Set<Integer> getCodes() {
315                    return codeTable.keySet();
316            }
317            
318            protected Trie getTrie() {
319                    return trie;
320            }
321            
322            public boolean equals(Object obj) {
323                    if (this == obj)
324                            return true;
325                    if (obj == null)
326                            return false;
327                    if (getClass() != obj.getClass())
328                            return false;
329                    return ((name == null) ? ((TrieSymbolTable)obj).name == null : name.equals(((TrieSymbolTable)obj).name));
330            }
331    
332            public int hashCode() {
333                    if (cachedHash == 0) {
334                            cachedHash = 31 * 7 + (null == name ? 0 : name.hashCode());
335                    }
336                    return cachedHash;
337            }
338            
339            public String toString() {
340                    final StringBuilder sb = new StringBuilder();
341                    sb.append(name);
342                    sb.append(" ");
343                    sb.append(valueCounter);
344                    return sb.toString();
345            }
346    }