1 | package felix.operator; |
2 | |
3 | |
4 | import java.io.BufferedWriter; |
5 | import java.io.FileInputStream; |
6 | import java.io.FileWriter; |
7 | import java.util.ArrayList; |
8 | import java.util.Arrays; |
9 | import java.util.HashMap; |
10 | import java.util.HashSet; |
11 | import java.util.Set; |
12 | |
13 | import org.postgresql.PGConnection; |
14 | |
15 | import tuffy.db.RDB; |
16 | import tuffy.mln.Predicate; |
17 | import tuffy.ra.ConjunctiveQuery; |
18 | import tuffy.util.Config; |
19 | import tuffy.util.ExceptionMan; |
20 | import tuffy.util.StringMan; |
21 | import tuffy.util.Timer; |
22 | import tuffy.util.UIMan; |
23 | import felix.dstruct.DataMovementOperator; |
24 | import felix.dstruct.FelixPredicate; |
25 | import felix.dstruct.FelixQuery; |
26 | import felix.dstruct.StatOperator; |
27 | import felix.dstruct.FelixPredicate.FPProperty; |
28 | import felix.parser.FelixCommandOptions; |
29 | import felix.util.FelixConfig; |
30 | import felix.util.FelixStringMan; |
31 | import felix.util.FelixUIMan; |
32 | |
33 | /** |
34 | * A CRF operator in Felix. |
35 | * @author Ce Zhang |
36 | * |
37 | */ |
38 | public class CRFOperator extends StatOperator{ |
39 | |
/**
 * Target predicate of this CRF operator; assigned in {@link #prepare()}.
 */
FelixPredicate crfHead;

/**
 * Mapping from a label to a dense label ID. The key is the {@link #array2str(String[])}
 * signature of the label's field values (the constant IDs read from the label domain);
 * the IDs start at 0 so that CRF inference can represent labels as array indices.
 */
HashMap<String, Integer> label2ID = new HashMap<String, Integer>();

/**
 * The inverse map of {@link CRFOperator#label2ID}: dense label ID to the label's field
 * values. The value is a String[] because a CRF label may span several fields.
 */
HashMap<Integer, String[]> id2Label = new HashMap<Integer, String[]>();

/**
 * One DataMovementOperator per LR (non-recursive) rule.
 */
ArrayList<DataMovementOperator> lrDMOs = new ArrayList<DataMovementOperator>();

/**
 * The DataMovementOperator used for the CRF chain rule.
 * NOTE(review): when several chain rules exist, prepareDMO keeps only the last one here.
 */
DataMovementOperator crfDMO = null;

/**
 * The DataMovementOperator that is the UNION ALL of all LR rules; null when there are none.
 */
DataMovementOperator lrDMO = null;

/**
 * The DataMovementOperator for all unigram features: the LR union grouped by the head
 * arguments with the weights summed. Null when there are no LR rules.
 */
DataMovementOperator unigramDMO = null;

/**
 * The DataMovementOperator representing the table/view for all bigram features.
 * NOTE(review): not assigned by prepareDMO(); possibly unused.
 */
DataMovementOperator bigramDMO = null;

/**
 * The DataMovementOperator enumerating all possible labels, i.e. the cross product of
 * the type tables of the label fields.
 */
DataMovementOperator labelDomainDMO = null;

/**
 * The DataMovementOperator enumerating the partitions: the distinct values of the CRF
 * partition fields, in order. A partition may contain one sequence or several sequences.
 */
DataMovementOperator getAllPossiblePartitioningDMO = null;

/**
 * The DataMovementOperator fetching all bigram features of one partition (bound by the
 * values of the partition fields).
 */
DataMovementOperator getBigramFeaturesForPartitioningDMO = null;

/**
 * The DataMovementOperator fetching all unigram features for a given sequence.
 * NOTE(review): not assigned by prepareDMO(); possibly unused.
 */
DataMovementOperator getUnigramFeaturesForPartitioningDMO = null;

/**
 * The DataMovementOperator fetching all bigram features.
 */
DataMovementOperator getAllBigramFeaturesDMO = null;

/**
 * The DataMovementOperator fetching all unigram features, ordered by the partition fields.
 */
DataMovementOperator getAllUnigramFeaturesDMO = null;

/**
 * Number of times {@link #run()} has been invoked (incremented at its start).
 */
public int nRuns = 0;
115 | |
/**
 * Creates a CRF operator and marks it as {@code OPType.CRF} with precedence 5.
 *
 * @param _fq Felix query.
 * @param _goalPredicates target predicates of this CRF operator.
 * @param _opt Command line options of this Felix run.
 */
public CRFOperator(FelixQuery _fq, HashSet<FelixPredicate> _goalPredicates,
        FelixCommandOptions _opt) {
    super(_fq, _goalPredicates, _opt);
    this.precedence = 5;
    this.type = OPType.CRF;
}
128 | |
129 | boolean prepared = false; |
130 | |
131 | /** |
132 | * Prepares operator for execution. |
133 | */ |
134 | @Override |
135 | public void prepare() { |
136 | |
137 | this.lrDMOs.clear(); |
138 | allDMOs.clear(); |
139 | |
140 | //if(!prepared){ |
141 | |
142 | db = RDB.getRDBbyConfig(Config.db_schema); |
143 | |
144 | crfHead = this.getTargetPredicateIfHasOnlyOne(); |
145 | HashSet<ConjunctiveQuery> chainQueries = |
146 | this.translateFelixClasesIntoFactorGraphEdgeQueries(crfHead, false, this.inputPredicateScope, FPProperty.CHAIN_RECUR); |
147 | HashSet<ConjunctiveQuery> lrQueries = |
148 | this.translateFelixClasesIntoFactorGraphEdgeQueries(crfHead, false, this.inputPredicateScope, FPProperty.NON_RECUR); |
149 | |
150 | this.prepareDMO(lrQueries, chainQueries); |
151 | |
152 | prepared = true; |
153 | //} |
154 | |
155 | } |
156 | |
157 | /** |
158 | * Return the signature of a string array as a string. |
159 | * @param _array |
160 | * @return |
161 | */ |
162 | String array2str(String[] _array){ |
163 | String ret = ""; |
164 | for(String s : _array){ |
165 | ret = ret + ":" + s; |
166 | } |
167 | return ret; |
168 | } |
169 | |
170 | /** |
171 | * Executes operator. |
172 | */ |
173 | @Override |
174 | public void run() { |
175 | |
176 | nRuns ++; |
177 | |
178 | if(!this.options.useDualDecomposition){ |
179 | crfHead.setHasSoftEvidence(true); |
180 | } |
181 | |
182 | UIMan.println(">>> Start Running " + this); |
183 | |
184 | Timer.start("CRF-Operator-" + crfHead.getName() + "-" + this.getId()); |
185 | |
186 | try{ |
187 | BufferedWriter bw = new BufferedWriter( |
188 | new FileWriter(Config.getLoadingDir() + "/_loading_crf_" + |
189 | crfHead.getName() + "_op" + this.getId())); |
190 | |
191 | this.labelDomainDMO.execute(null, new ArrayList<Integer>()); |
192 | int labelID = 0; |
193 | int nLabelFileds = crfHead.getLabelFieldsArgs().size(); |
194 | |
195 | while(this.labelDomainDMO.next()){ |
196 | String[] currLabel = new String[nLabelFileds]; |
197 | |
198 | for(int i=0;i<nLabelFileds;i++){ |
199 | currLabel[i] = this.labelDomainDMO.getNext(i+1).toString(); |
200 | } |
201 | |
202 | this.id2Label.put(labelID, currLabel); |
203 | this.label2ID.put(array2str(currLabel), labelID); |
204 | labelID++; |
205 | } |
206 | |
207 | if(id2Label.size() == 0){ |
208 | return; |
209 | } |
210 | |
211 | |
212 | if(crfHead.getCRFPartitionFields() != null){ |
213 | |
214 | db.disableAutoCommitForNow(); |
215 | this.fastInfer(bw); |
216 | db.commit(); |
217 | db.restoreAutoCommitState(); |
218 | |
219 | }else{ |
220 | UIMan.warn("Your CRF Rule Cannot be partitioned into different components..."); |
221 | |
222 | ExceptionMan.die("Rewritting your rules may be a better idea... " + |
223 | "Or simply use -noCRF option"); |
224 | |
225 | /** |
226 | db.disableAutoCommitForNow(); |
227 | this.slowInfer(bw); |
228 | db.commit(); |
229 | db.restoreAutoCommitState(); |
230 | **/ |
231 | } |
232 | |
233 | bw.close(); |
234 | |
235 | |
236 | FileInputStream in = new FileInputStream(Config.getLoadingDir() + |
237 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
238 | |
239 | PGConnection con = (PGConnection) db.getConnection(); |
240 | |
241 | String sql; |
242 | |
243 | if(options.useDualDecomposition){ |
244 | for(FelixPredicate fp : this.dd_CommonOutput){ |
245 | if(!fp.getName().equals(this.crfHead.getName())){ |
246 | ExceptionMan.die("ERROR: I am not fuzzy-LR/CRF/COREF!!! Contact us!!!"); |
247 | continue; |
248 | } |
249 | |
250 | |
251 | in = new FileInputStream(Config.getLoadingDir() + |
252 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
253 | String tableName = this.dd_commonOutputPredicate_2_tableName.get(fp); |
254 | |
255 | |
256 | sql = "COPY " + tableName + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
257 | con.getCopyAPI().copyIn(sql, in); |
258 | in.close(); |
259 | |
260 | } |
261 | |
262 | if(FelixConfig.isFirstRunOfDD){ |
263 | in = new FileInputStream(Config.getLoadingDir() + |
264 | "/_loading_crf_" + crfHead.getName() + "_op" + this.getId()); |
265 | |
266 | sql = "COPY " + crfHead.getRelName() + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
267 | con.getCopyAPI().copyIn(sql, in); |
268 | in.close(); |
269 | } |
270 | crfHead.isCurrentlyView = false; |
271 | |
272 | }else{ |
273 | |
274 | sql = "COPY " + crfHead.getRelName() + "(truth, prior, club, " + StringMan.commaList(crfHead.getArgs()) + " ) FROM STDIN CSV"; |
275 | con.getCopyAPI().copyIn(sql, in); |
276 | in.close(); |
277 | crfHead.isCurrentlyView = false; |
278 | |
279 | } |
280 | |
281 | FelixUIMan.println(0,0,"\n>>> {" + this + "} uses " + Timer.elapsed("CRF-Operator-" + crfHead.getName() + "-" + this.getId())); |
282 | |
283 | db.close(); |
284 | |
285 | if(!options.useDualDecomposition){ |
286 | this.belongsToBucket.runNextOperatorInBucket(); |
287 | } |
288 | |
289 | }catch(Exception e){ |
290 | e.printStackTrace(); |
291 | } |
292 | |
293 | } |
294 | |
/**
 * Returns a human-readable explanation of this operator.
 *
 * @return always {@code null}; no explanation is implemented for CRF operators yet.
 */
@Override
public String explain() {
    //TODO: implement an explanation for the CRF operator.
    return null;
}
300 | |
301 | /** |
302 | * Conduct CRF infer WITH knowledge about partitioning which is parsed statically |
303 | * from the input program. |
304 | * |
305 | * @param bw Buffered writer to dump results. |
306 | */ |
307 | public void fastInfer(BufferedWriter bw){ |
308 | |
309 | //get the whole sequence in db based on provided seqHead |
310 | |
311 | try { |
312 | |
313 | if(this.getAllUnigramFeaturesDMO != null){ |
314 | this.getAllUnigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
315 | this.getAllUnigramFeaturesDMO.next(); |
316 | }else{ |
317 | //TODO: |
318 | return; |
319 | } |
320 | |
321 | boolean reachUnigramEnd = false; |
322 | |
323 | this.getAllPossiblePartitioningDMO.execute(null, new ArrayList<Integer>()); |
324 | int ctt = 0; |
325 | while(this.getAllPossiblePartitioningDMO.next()){ |
326 | |
327 | //System.out.println(ctt++); |
328 | |
329 | String partSignature = ""; |
330 | ArrayList<Integer> bindings = new ArrayList<Integer>(); |
331 | ArrayList<String> toSig = new ArrayList<String>(); |
332 | for(String s : crfHead.getCRFPartitionFields()){ |
333 | toSig.add(this.getAllPossiblePartitioningDMO.getNext(s) + ""); |
334 | bindings.add(this.getAllPossiblePartitioningDMO.getNext(s)); |
335 | } |
336 | partSignature = StringMan.commaList(toSig); |
337 | |
338 | Sequence imSeq = new Sequence(crfHead, null, id2Label, label2ID); |
339 | |
340 | String[] currLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
341 | String[] prevLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
342 | |
343 | while(true){ |
344 | |
345 | if(reachUnigramEnd == true){ |
346 | break; |
347 | } |
348 | |
349 | String curPartSignature = ""; |
350 | toSig = new ArrayList<String>(); |
351 | for(String s : crfHead.getCRFPartitionFields()){ |
352 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(s) + ""); |
353 | } |
354 | curPartSignature = StringMan.commaList(toSig); |
355 | |
356 | if(curPartSignature.equals(partSignature)){ |
357 | |
358 | String currSignature = ""; |
359 | int lct = 0; |
360 | toSig = new ArrayList<String>(); |
361 | for(int i=0;i<crfHead.arity();i++){ |
362 | if(crfHead.getLabelPositions().contains(i)){ |
363 | currLabel[lct++] = this.getAllUnigramFeaturesDMO.getNext(i+1).toString(); |
364 | toSig.add("%s"); |
365 | }else{ |
366 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(i+1).toString()); |
367 | } |
368 | } |
369 | currSignature = StringMan.commaList(toSig); |
370 | |
371 | Double weight = this.getAllUnigramFeaturesDMO.getNextDouble(crfHead.arity()+1); |
372 | |
373 | imSeq.registerNodeIfNotExist(currSignature); |
374 | imSeq.registerUnigramFeatures(currSignature, this.label2ID.get(array2str(currLabel)), weight); |
375 | |
376 | if(this.getAllUnigramFeaturesDMO.next() == null){ |
377 | reachUnigramEnd = true; |
378 | break; |
379 | } |
380 | |
381 | }else{ |
382 | break; |
383 | } |
384 | } |
385 | |
386 | this.getBigramFeaturesForPartitioningDMO.execute(null, bindings); |
387 | |
388 | /* |
389 | System.err.println(); |
390 | System.err.println(this.crfDMO.getAllFreeViewName()); |
391 | System.err.println(this.crfDMO.physicalQueryPlan.objectConjunctiveQuery); |
392 | System.err.println(this.crfDMO.physicalQueryPlan.objectPreparedStatement); |
393 | System.err.println(this.getBigramFeaturesForPartitioningDMO.physicalQueryPlan.objectPreparedStatement); |
394 | */ |
395 | |
396 | while(getBigramFeaturesForPartitioningDMO.next()){ |
397 | |
398 | String prevSignature = ""; |
399 | |
400 | toSig = new ArrayList<String>(); |
401 | int lct = 0; |
402 | for(int i=0;i<crfHead.arity();i++){ |
403 | if(crfHead.getLabelPositions().contains(i)){ |
404 | prevLabel[lct++] = this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString(); |
405 | toSig.add("%s"); |
406 | }else{ |
407 | toSig.add(this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString()); |
408 | } |
409 | } |
410 | prevSignature = StringMan.commaList(toSig); |
411 | |
412 | String currSignature = ""; |
413 | lct = 0; |
414 | toSig = new ArrayList<String>(); |
415 | for(int i=crfHead.arity();i< 2* crfHead.arity();i++){ |
416 | if(crfHead.getLabelPositions().contains(i - crfHead.arity())){ |
417 | currLabel[lct++] = this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString(); |
418 | toSig.add("%s"); |
419 | }else{ |
420 | toSig.add(this.getBigramFeaturesForPartitioningDMO.getNext(i+1).toString()); |
421 | } |
422 | } |
423 | currSignature = StringMan.commaList(toSig); |
424 | |
425 | Double weight = getBigramFeaturesForPartitioningDMO.getNextDouble(2*crfHead.arity()+1); |
426 | |
427 | imSeq.registerNodeIfNotExist(prevSignature); |
428 | imSeq.registerNodeIfNotExist(currSignature); |
429 | imSeq.registerBigramFeatures(prevSignature, currSignature, label2ID.get(array2str(prevLabel)), |
430 | label2ID.get(array2str(currLabel)), weight); |
431 | } |
432 | |
433 | |
434 | |
435 | imSeq.infer(); |
436 | imSeq.dumpAnswers(bw); |
437 | |
438 | } |
439 | |
440 | } catch (Exception e) { |
441 | ExceptionMan.die("unconsistent value!"); |
442 | e.printStackTrace(); |
443 | } |
444 | } |
445 | |
446 | /** |
447 | * Conduct CRF infer WITHOUT knowledge about partitioning which is parsed statically |
448 | * from the input program. |
449 | * |
450 | * @deprecated |
451 | * |
452 | * @param bw Buffered writer to dump results. |
453 | */ |
454 | public void slowInfer(BufferedWriter bw){ |
455 | |
456 | //get the whole sequence in db based on provided seqHead |
457 | |
458 | try { |
459 | |
460 | if(this.getAllUnigramFeaturesDMO != null){ |
461 | this.getAllUnigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
462 | } |
463 | if(this.getAllBigramFeaturesDMO != null){ |
464 | this.getAllBigramFeaturesDMO.execute(null, new ArrayList<Integer>()); |
465 | } |
466 | |
467 | Sequence imSeq = new Sequence(crfHead, null, id2Label, label2ID); |
468 | String[] currLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
469 | String[] prevLabel = new String[crfHead.getLabelFieldsArgs().size()]; |
470 | ArrayList<String> toSig = new ArrayList<String>(); |
471 | |
472 | if(this.getAllUnigramFeaturesDMO != null){ |
473 | while(this.getAllUnigramFeaturesDMO.next()){ |
474 | |
475 | String currSignature = ""; |
476 | int lct = 0; |
477 | toSig = new ArrayList<String>(); |
478 | for(int i=0;i<crfHead.arity();i++){ |
479 | if(crfHead.getLabelPositions().contains(i)){ |
480 | currLabel[lct++] = this.getAllUnigramFeaturesDMO.getNext(i+1).toString(); |
481 | toSig.add("%s"); |
482 | }else{ |
483 | toSig.add(this.getAllUnigramFeaturesDMO.getNext(i+1).toString()); |
484 | } |
485 | } |
486 | currSignature = StringMan.commaList(toSig); |
487 | |
488 | Double weight = this.getAllUnigramFeaturesDMO.getNextDouble(crfHead.arity()+1); |
489 | |
490 | imSeq.registerNodeIfNotExist(currSignature); |
491 | imSeq.registerUnigramFeatures(currSignature, this.label2ID.get(array2str(currLabel)), weight); |
492 | } |
493 | } |
494 | |
495 | |
496 | |
497 | if(this.getAllBigramFeaturesDMO != null){ |
498 | while(this.getAllBigramFeaturesDMO.next()){ |
499 | |
500 | String prevSignature = ""; |
501 | |
502 | toSig = new ArrayList<String>(); |
503 | int lct = 0; |
504 | for(int i=0;i<crfHead.arity();i++){ |
505 | if(crfHead.getLabelPositions().contains(i)){ |
506 | prevLabel[lct++] = this.getAllBigramFeaturesDMO.getNext(i+1).toString(); |
507 | toSig.add("%s"); |
508 | }else{ |
509 | toSig.add(this.getAllBigramFeaturesDMO.getNext(i+1).toString()); |
510 | } |
511 | } |
512 | prevSignature = StringMan.commaList(toSig); |
513 | |
514 | String currSignature = ""; |
515 | lct = 0; |
516 | toSig = new ArrayList<String>(); |
517 | for(int i=crfHead.arity();i< 2* crfHead.arity();i++){ |
518 | if(crfHead.getLabelPositions().contains(i - crfHead.arity())){ |
519 | currLabel[lct++] = this.getAllBigramFeaturesDMO.getNext(i+1).toString(); |
520 | toSig.add("%s"); |
521 | }else{ |
522 | toSig.add(this.getAllBigramFeaturesDMO.getNext(i+1).toString()); |
523 | } |
524 | } |
525 | currSignature = StringMan.commaList(toSig); |
526 | |
527 | Double weight = getAllBigramFeaturesDMO.getNextDouble(2*crfHead.arity()+1); |
528 | |
529 | imSeq.registerNodeIfNotExist(prevSignature); |
530 | imSeq.registerNodeIfNotExist(currSignature); |
531 | imSeq.registerBigramFeatures(prevSignature, currSignature, label2ID.get(array2str(prevLabel)), |
532 | label2ID.get(array2str(currLabel)), weight); |
533 | |
534 | } |
535 | } |
536 | |
537 | imSeq.infer(); |
538 | imSeq.dumpAnswers(bw); |
539 | |
540 | |
541 | } catch (Exception e) { |
542 | e.printStackTrace(); |
543 | } |
544 | } |
545 | |
546 | /** |
547 | * Generate Data Movement Operator used by this CRF Operator. |
548 | * @param rules rules defining this operator. |
549 | */ |
550 | public void prepareDMO(HashSet<ConjunctiveQuery> lrQueries, HashSet<ConjunctiveQuery> chainQueries){ |
551 | |
552 | try { |
553 | |
554 | // DMO for LR rules |
555 | for(ConjunctiveQuery cq : lrQueries){ |
556 | |
557 | DataMovementOperator dmo = new DataMovementOperator(db, this); |
558 | dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(), |
559 | new ArrayList<String>(Arrays.asList("weight")) ); |
560 | |
561 | dmo.predictedBB = 0; |
562 | dmo.PredictedFF = 1; |
563 | dmo.PredictedBF = 0; |
564 | |
565 | dmo.allowOptimization = false; |
566 | |
567 | allDMOs.add(dmo); |
568 | lrDMOs.add(dmo); |
569 | } |
570 | |
571 | // DMOs for CRF chain rules |
572 | for(ConjunctiveQuery cq : chainQueries){ |
573 | |
574 | DataMovementOperator dmo = new DataMovementOperator(db, this); |
575 | dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(), |
576 | new ArrayList<String>(Arrays.asList("weight")) ); |
577 | |
578 | dmo.predictedBB = 0; |
579 | dmo.PredictedFF = 0; |
580 | dmo.PredictedBF = this.getPartitionSize(); |
581 | |
582 | dmo.allowOptimization = true; |
583 | |
584 | allDMOs.add(dmo); |
585 | crfDMO = dmo; |
586 | } |
587 | |
588 | |
589 | //the DMO for the union of all LR DMOs |
590 | if(this.lrDMOs.size() > 0){ |
591 | this.lrDMO = DataMovementOperator.UnionAll(db, this, |
592 | this.lrDMOs, StringMan.zeros(this.crfHead.arity() + 1), new ArrayList<Integer>()); |
593 | allDMOs.add(lrDMO); |
594 | |
595 | //the DMO for all unigram features. |
596 | this.unigramDMO = new DataMovementOperator(db, this); |
597 | this.unigramDMO.allowOptimization = false; |
598 | this.unigramDMO.asView = false; |
599 | this.unigramDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
600 | "SELECT " + StringMan.commaList(lrDMO.selListFromRule) + ", sum(weight) AS sumweight " + |
601 | " FROM " + lrDMO.getAllFreeViewName() + " GROUP BY " + StringMan.commaList(lrDMO.selListFromRule) + |
602 | (crfHead.getCRFPartitionFields() == null ? "" : |
603 | " ORDER BY " + StringMan.commaList(crfHead.getCRFPartitionFields())) ), |
604 | lrDMO.selListFromRule, new ArrayList<String>(Arrays.asList("sumweight"))); |
605 | allDMOs.add(this.unigramDMO); |
606 | } |
607 | |
608 | //the DMO for the label domain |
609 | this.labelDomainDMO = new DataMovementOperator(db, this); |
610 | this.labelDomainDMO.allowOptimization = false; |
611 | ArrayList<String> fields = crfHead.getLabelFieldsTypeTable(); |
612 | ArrayList<String> fieldsViewName = new ArrayList<String>(); |
613 | for(int i=0;i<fields.size();i++){ |
614 | fieldsViewName.add("t" + i + "." + "constantid AS l" + i); |
615 | fields.set(i, fields.get(i) + " t" + i); |
616 | } |
617 | this.labelDomainDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
618 | "SELECT " + StringMan.commaList(fieldsViewName) +" FROM " + FelixStringMan.commaList(fields)), |
619 | crfHead.getLabelFieldsArgs(), |
620 | new ArrayList<String>()); |
621 | allDMOs.add(this.labelDomainDMO); |
622 | |
623 | // check whether we can use partitioning, or we must use a slow |
624 | // version which detects the head for each sequence using recursive SQL. |
625 | if(crfHead.getCRFPartitionFields() == null){ |
626 | |
627 | }else{ |
628 | |
629 | if(this.unigramDMO != null){ |
630 | this.getAllPossiblePartitioningDMO = new DataMovementOperator(db, this); |
631 | this.getAllPossiblePartitioningDMO.allowOptimization = false; |
632 | this.getAllPossiblePartitioningDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
633 | "SELECT DISTINCT * FROM ( " + |
634 | //"SELECT " + StringMan.commaList(crfHead.getCRFPartitionFields()) |
635 | // + " FROM " + crfDMO.getAllFreeViewName() + " UNION " + |
636 | " SELECT " + FelixStringMan.commaList(crfHead.getCRFPartitionFields()) |
637 | + " FROM " + unigramDMO.getAllFreeViewName() + ") nt ORDER BY " + |
638 | FelixStringMan.commaList(crfHead.getCRFPartitionFields())), crfHead.getCRFPartitionFields(), |
639 | new ArrayList<String>()); |
640 | allDMOs.add(this.getAllPossiblePartitioningDMO); |
641 | }else{ |
642 | this.getAllPossiblePartitioningDMO = new DataMovementOperator(db, this); |
643 | this.getAllPossiblePartitioningDMO.allowOptimization = false; |
644 | this.getAllPossiblePartitioningDMO.logicQueryPlan.addQuery(db.getPrepareStatement( |
645 | "SELECT DISTINCT * FROM ( SELECT " + FelixStringMan.commaList(crfHead.getCRFPartitionFields()) |
646 | + " FROM " + crfDMO.getAllFreeViewName() + ") nt ORDER BY " + |
647 | FelixStringMan.commaList(crfHead.getCRFPartitionFields())), crfHead.getCRFPartitionFields(), |
648 | new ArrayList<String>()); |
649 | allDMOs.add(this.getAllPossiblePartitioningDMO); |
650 | } |
651 | |
652 | |
653 | this.getBigramFeaturesForPartitioningDMO = DataMovementOperator.Select(db, this, |
654 | this.crfDMO, crfHead.getCRFPartitionFields()); |
655 | this.getBigramFeaturesForPartitioningDMO.isIntermediaDMO = true; |
656 | this.getBigramFeaturesForPartitioningDMO.hasKnownFetchingOrder = true; |
657 | allDMOs.add(this.getBigramFeaturesForPartitioningDMO); |
658 | |
659 | } |
660 | |
661 | this.getAllBigramFeaturesDMO = DataMovementOperator.Select(db, this, |
662 | this.crfDMO, new ArrayList<String>()); |
663 | this.getAllBigramFeaturesDMO.isIntermediaDMO = true; |
664 | this.getAllBigramFeaturesDMO.hasKnownFetchingOrder = true; |
665 | allDMOs.add(this.getAllBigramFeaturesDMO); |
666 | |
667 | if(this.unigramDMO != null){ |
668 | this.getAllUnigramFeaturesDMO = DataMovementOperator.SelectOrderBy(db, this, |
669 | this.unigramDMO, new ArrayList<String>(), |
670 | " ORDER BY " + StringMan.commaList(crfHead.getCRFPartitionFields())); |
671 | this.getAllUnigramFeaturesDMO.isIntermediaDMO = true; |
672 | this.getAllUnigramFeaturesDMO.hasKnownFetchingOrder = true; |
673 | allDMOs.add(this.getAllUnigramFeaturesDMO); |
674 | } |
675 | |
676 | } catch (Exception e) { |
677 | e.printStackTrace(); |
678 | } |
679 | } |
680 | |
681 | /** |
682 | * Estimate the number of sequences. |
683 | * @return |
684 | */ |
685 | public int getPartitionSize(){ |
686 | if(crfHead.getCRFPartitionFields() == null){ |
687 | return 1; |
688 | } |
689 | |
690 | int maxSingleField = -1; |
691 | for(int i=0;i<crfHead.getArgs().size();i++){ |
692 | if(crfHead.getCRFPartitionFields().contains(crfHead.getArgs().get(i))){ |
693 | // just an estimate |
694 | if(maxSingleField < crfHead.getTypeAt(i).size()){ |
695 | maxSingleField = crfHead.getTypeAt(i).size(); |
696 | } |
697 | } |
698 | } |
699 | if(maxSingleField > 0){ |
700 | return maxSingleField/this.partitionedInto; |
701 | } |
702 | |
703 | return Integer.MAX_VALUE; |
704 | } |
705 | |
706 | |
707 | /** |
708 | * Returns sum of given log numbers. |
709 | * @param logX |
710 | * @param logY |
711 | * @return |
712 | */ |
713 | //ACKNOWLEDGE: FROM https://facwiki.cs.byu.edu/nlp/index.php/Log_Domain_Computations |
714 | //COPYRIGHT OF THIS FUNCTION BELONGS TO ITS ORIGINAL AUTHOR |
715 | public static double logAdd(double logX, double logY) { |
716 | |
717 | if (logY > logX) { |
718 | double temp = logX; |
719 | logX = logY; |
720 | logY = temp; |
721 | } |
722 | |
723 | if (logX == Double.NEGATIVE_INFINITY) { |
724 | return logX; |
725 | } |
726 | |
727 | double negDiff = logY - logX; |
728 | //if (negDiff < -1000000) { |
729 | // return logX; |
730 | //} |
731 | |
732 | return logX + java.lang.Math.log(1.0 + java.lang.Math.exp(negDiff)); |
733 | } |
734 | |
735 | |
/**
 * Class for a node in the {@link Sequence}: per-label scores of one element of a chain
 * plus the bookkeeping arrays used during inference. All arrays are indexed by dense
 * label ID.
 */
class Node{
    /** Bigram weights indexed [previous label][current label]; initially 0. */
    public double[][] prev2currBigram;
    /** Unigram weight of each label of this node; initially 0. */
    public double[] currUnigram;

    /** Forward (log-sum) score per label; initially -infinity. */
    public double[] forwardSum;
    /** Backward score per label; initially -infinity (presumably filled by the backward pass). */
    public double[] backwardSum;
    /** Best (Viterbi) score per label; initially -infinity. */
    public double[] currentMax;
    /** Label of the previous node on the best path ending in each label; -1 when unset. */
    public int[] prevArgMax;
    /** Label of the next node on the best path; -1 when unset. */
    public int[] nextArgMax;

    /**
     * Allocates all per-label arrays.
     *
     * @param nOfLabel number of labels in the label domain.
     */
    public Node(int nOfLabel){
        // New double arrays are zero-initialised by Java, which is the neutral weight.
        prev2currBigram = new double[nOfLabel][nOfLabel];
        currUnigram = new double[nOfLabel];

        forwardSum = new double[nOfLabel];
        backwardSum = new double[nOfLabel];
        currentMax = new double[nOfLabel];
        Arrays.fill(forwardSum, Double.NEGATIVE_INFINITY);
        Arrays.fill(backwardSum, Double.NEGATIVE_INFINITY);
        Arrays.fill(currentMax, Double.NEGATIVE_INFINITY);

        prevArgMax = new int[nOfLabel];
        nextArgMax = new int[nOfLabel];
        Arrays.fill(prevArgMax, -1);
        Arrays.fill(nextArgMax, -1);
    }
}
775 | |
776 | /** |
 * In-memory representation of a CRF chain (possibly several chains). This class
 * supports inference (both marginal and MAP) and dumps the results to a file.
779 | * |
780 | * @author Ce Zhang |
781 | * |
782 | */ |
783 | class Sequence{ |
784 | |
785 | // although I personally think using string instead of integer id will not |
786 | // be so slow (because the sequence is normally short), we may like to try a |
787 | // pure-integer version if we think current Viterbi is slow... |
788 | |
/**
 * The predicate to be labeled.
 */
Predicate pred = null;

/**
 * The signature of the root node currently being processed. It comes from the
 * constructor and may be null, in which case infer() searches for the roots itself.
 */
String rootSignature = "";

/**
 * The signature of the last node reached by the most recent forward pass of infer().
 */
String lastSignature = "";

/**
 * Map from node signature to its Node object.
 */
HashMap<String, Node> signature2Node = new HashMap<String, Node>();

/**
 * Signatures of all root nodes (nodes without a predecessor); a Sequence may hold
 * several chains.
 */
HashSet<String> roots = new HashSet<String>();

/**
 * Signatures of all last nodes (nodes without a successor) of this sequence.
 */
HashSet<String> lasts = new HashSet<String>();

/**
 * For each last node, the label with the highest Viterbi score, which is used in MAP
 * inference (presumably as the start of the back-trace).
 */
HashMap<String, Integer> last2maxArg = new HashMap<String, Integer>();

/**
 * Map from one node to the next node in the chain.
 */
HashMap<String, String> next = new HashMap<String, String>();

/**
 * Map from one node to the previous node in the chain.
 */
HashMap<String, String> prev = new HashMap<String, String>();

/**
 * Dense label ID to label field values. See {@link CRFOperator#id2Label}; the
 * constructor stores the operator's map itself, not a copy.
 */
HashMap<Integer, String[]> id2Label = new HashMap<Integer, String[]>();

/**
 * Label signature to dense label ID. See {@link CRFOperator#label2ID}; the constructor
 * stores the operator's map itself, not a copy.
 */
HashMap<String, Integer> label2ID = new HashMap<String, Integer>();
843 | |
/**
 * Creates a sequence over the given label maps.
 *
 * @param _p the predicate to be labeled.
 * @param _rootSignature signature of the root node, or null to let infer() find the
 *        roots; a non-null root gets its Node created immediately.
 * @param _id2Label dense label ID to label values (stored by reference).
 * @param _label2ID label signature to dense label ID (stored by reference).
 */
public Sequence(Predicate _p , String _rootSignature, HashMap<Integer, String[]> _id2Label, HashMap<String, Integer> _label2ID){
    this.pred = _p;
    this.id2Label = _id2Label;
    this.label2ID = _label2ID;
    this.rootSignature = _rootSignature;

    if (_rootSignature != null) {
        signature2Node.put(_rootSignature, new Node(_id2Label.size()));
    }
}
861 | |
862 | |
863 | /** |
864 | * Add a node in this sequence with a given signature. |
865 | * @param _signature |
866 | */ |
867 | public void registerNodeIfNotExist(String _signature){ |
868 | |
869 | if(signature2Node.containsKey(_signature)){ |
870 | return; |
871 | } |
872 | |
873 | signature2Node.put(_signature, new Node(this.id2Label.size())); |
874 | } |
875 | |
876 | /** |
877 | * Add a bigram feature for a node with signature _currSignature and label _currLabel. |
878 | * @param _prevSignature |
879 | * @param _currSignature |
880 | * @param _prevLabel |
881 | * @param _currLabel |
882 | * @param _weight |
883 | */ |
884 | public void registerBigramFeatures(String _prevSignature, String _currSignature, int _prevLabel, int _currLabel, Double _weight){ |
885 | |
886 | assert this.signature2Node.containsKey(_prevSignature); |
887 | assert this.signature2Node.containsKey(_currSignature); |
888 | |
889 | if(next.containsKey(_prevSignature)){ |
890 | assert next.get(_prevSignature).equals(_currSignature); |
891 | }else{ |
892 | next.put(_prevSignature, _currSignature); |
893 | } |
894 | |
895 | if(prev.containsKey(_currSignature)){ |
896 | assert prev.get(_currSignature).equals(_prevSignature); |
897 | }else{ |
898 | prev.put(_currSignature, _prevSignature); |
899 | } |
900 | |
901 | Node tmpNode = this.signature2Node.get(_currSignature); |
902 | tmpNode.prev2currBigram[_prevLabel][_currLabel] += _weight; |
903 | |
904 | } |
905 | |
906 | /** |
907 | * Add a unigram feature for a node with signature _currSignature and label _currLabel. |
908 | * @param _currSignature |
909 | * @param _currLabel |
910 | * @param _weight |
911 | */ |
912 | public void registerUnigramFeatures(String _currSignature, int _currLabel, Double _weight){ |
913 | assert this.signature2Node.containsKey(_currSignature); |
914 | |
915 | Node tmpNode = this.signature2Node.get(_currSignature); |
916 | tmpNode.currUnigram[_currLabel] = _weight; |
917 | } |
918 | |
919 | /** |
920 | * Infer on this sequence. |
921 | */ |
922 | public void infer(){ |
923 | |
924 | //find root |
925 | if(rootSignature == null){ |
926 | Set<String> nodes = new HashSet<String>(); |
927 | nodes.addAll(this.signature2Node.keySet()); |
928 | |
929 | while(true){ |
930 | if(nodes.size() == 0){ |
931 | break; |
932 | } |
933 | |
934 | String rootCandidate = nodes.iterator().next(); |
935 | while(true){ |
936 | if(this.next.get(rootCandidate) != null){ |
937 | rootCandidate = this.next.get(rootCandidate); |
938 | }else{ |
939 | break; |
940 | } |
941 | } |
942 | |
943 | while(true){ |
944 | nodes.remove(rootCandidate); |
945 | if(this.prev.get(rootCandidate) != null){ |
946 | rootCandidate = this.prev.get(rootCandidate); |
947 | }else{ |
948 | rootSignature = rootCandidate; |
949 | roots.add(rootSignature); |
950 | break; |
951 | } |
952 | } |
953 | |
954 | } |
955 | }else{ |
956 | roots.add(rootSignature); |
957 | } |
958 | |
959 | |
960 | //forward |
961 | for(String sssss : roots){ |
962 | this.rootSignature = sssss; |
963 | String current = this.rootSignature; |
964 | while(true){ |
965 | if(current.equals(rootSignature)){ |
966 | Node n = this.signature2Node.get(current); |
967 | // foreach label |
968 | for(int i = 0; i < n.forwardSum.length; i++){ |
969 | n.forwardSum[i] = n.currUnigram[i]; |
970 | n.currentMax[i] = n.currUnigram[i]; |
971 | } |
972 | }else{ |
973 | Node n = this.signature2Node.get(current); |
974 | Node p = this.signature2Node.get(this.prev.get(current)); |
975 | |
976 | for(int i = 0; i < n.forwardSum.length; i++){ |
977 | int maxArg = -1; |
978 | double maxValue = Double.NEGATIVE_INFINITY; |
979 | for(int j = 0; j< p.forwardSum.length; j++){ |
980 | n.forwardSum[i] = logAdd(n.prev2currBigram[j][i] + n.currUnigram[i] + p.forwardSum[j], n.forwardSum[i]); |
981 | double tmp = n.prev2currBigram[j][i] + n.currUnigram[i] + p.currentMax[j]; |
982 | if( tmp > maxValue ){ |
983 | maxValue = tmp; |
984 | maxArg = j; |
985 | } |
986 | } |
987 | n.prevArgMax[i] = maxArg; |
988 | n.currentMax[i] = maxValue; |
989 | } |
990 | } |
991 | if(!this.next.containsKey(current)){ |
992 | |
993 | Node n = this.signature2Node.get(current); |
994 | this.lastSignature = current; |
995 | |
996 | int maxArg = -1; |
997 | double maxValue = Double.NEGATIVE_INFINITY; |
998 | for(int i = 0; i < n.forwardSum.length; i++){ |
999 | if(n.currentMax[i] > maxValue){ |
1000 | maxArg = i; |
1001 | maxValue = n.currentMax[i]; |
1002 | } |
1003 | } |
1004 | |
1005 | this.lasts.add(current); |
1006 | this.last2maxArg.put(current, maxArg); |
1007 | break; |
1008 | } |
1009 | current = this.next.get(current); |
1010 | } |
1011 | |
1012 | //backward |
1013 | String last = current; |
1014 | while(true){ |
1015 | if(current.equals(last)){ |
1016 | Node n = this.signature2Node.get(current); |
1017 | // foreach label |
1018 | for(int i = 0; i < n.backwardSum.length; i++){ |
1019 | n.backwardSum[i] = n.currUnigram[i]; |
1020 | } |
1021 | }else{ |
1022 | Node p = this.signature2Node.get(current); |
1023 | Node n = this.signature2Node.get(this.next.get(current)); |
1024 | |
1025 | for(int i = 0; i < p.backwardSum.length; i++){ |
1026 | for(int j = 0; j< n.backwardSum.length; j++){ |
1027 | p.backwardSum[i] = logAdd(n.prev2currBigram[i][j] + n.currUnigram[j] + n.backwardSum[j], p.backwardSum[i]); |
1028 | } |
1029 | } |
1030 | } |
1031 | |
1032 | if(!this.prev.containsKey(current)){ |
1033 | break; |
1034 | } |
1035 | current = this.prev.get(current); |
1036 | } |
1037 | } |
1038 | } |
1039 | |
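/**
 * Write this sequence's answers to bw, one comma-separated row per atom, in a
 * form intended to be loaded into a PostgreSQL table with COPY:
 * {@code TRUE,<prob>,2,<atom>}, where {@code <atom>} is the node signature
 * formatted with the label's fields.
 * If isMarginal is set, or on the first dual-decomposition run when crfHead is in
 * dd_commonOutputPredicate_2_tableName, every label whose marginal probability
 * exceeds Config.soft_evidence_activation_threshold is written. Otherwise the
 * Viterbi (MAP) labeling of each chain is written with an empty probability field.
 * Any exception is printed rather than propagated; bw is neither flushed nor closed.
 * @param bw destination writer
 */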
1045 | public void dumpAnswers(BufferedWriter bw){ |
1046 | try{ |
1047 | |
1048 | if(isMarginal || (FelixConfig.isFirstRunOfDD && dd_commonOutputPredicate_2_tableName.containsKey(crfHead) )){ |
1049 | for(String sssss : roots){ |
1050 | this.rootSignature = sssss; |
1051 | String current = this.rootSignature; |
1052 | while(true){ |
1053 | |
1054 | if(current.equals(rootSignature)){ |
1055 | Node n = this.signature2Node.get(current); |
1056 | double sum = Double.NEGATIVE_INFINITY; |
1057 | for(int i = 0; i < n.forwardSum.length; i++){ |
1058 | sum = logAdd(n.currUnigram[i] + n.backwardSum[i], sum); |
1059 | } |
1060 | for(int i = 0; i < n.forwardSum.length; i++){ |
1061 | double prob = n.currUnigram[i] + n.backwardSum[i] - sum; |
1062 | prob = Math.exp(prob); |
1063 | if(prob > Config.soft_evidence_activation_threshold){ |
1064 | ArrayList<String> parts = new ArrayList<String>(); |
1065 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
1066 | parts.add("TRUE"); |
1067 | parts.add(Double.toString(prob)); |
1068 | |
1069 | if(options.useDualDecomposition){ |
1070 | parts.add(Integer.toString(2)); |
1071 | }else{ |
1072 | parts.add(Integer.toString(2)); |
1073 | } |
1074 | // parts.add("1");//this is for vote |
1075 | |
1076 | //String tmp = current.replaceAll("\\?", id2Label.get(i)); |
1077 | String tmp = String.format(current, (Object[]) id2Label.get(i)); |
1078 | |
1079 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
1080 | } |
1081 | } |
1082 | |
1083 | }else{ |
1084 | Node n = this.signature2Node.get(current); |
1085 | Node p = this.signature2Node.get(this.prev.get(current)); |
1086 | |
1087 | double sum = Double.NEGATIVE_INFINITY; |
1088 | for(int i = 0; i < n.forwardSum.length; i++){ |
1089 | for(int j = 0; j< p.forwardSum.length; j++){ |
1090 | sum = logAdd(p.forwardSum[j] + n.currUnigram[i] + n.prev2currBigram[j][i] |
1091 | + n.backwardSum[i], sum); |
1092 | } |
1093 | } |
1094 | |
1095 | for(int i = 0; i < n.forwardSum.length; i++){ |
1096 | double marginal = 0; |
1097 | |
1098 | for(int j = 0; j< p.forwardSum.length; j++){ |
1099 | marginal = logAdd( p.forwardSum[j] + n.currUnigram[i] + n.prev2currBigram[j][i] + n.backwardSum[i], |
1100 | marginal ); |
1101 | } |
1102 | |
1103 | double prob = Math.exp(marginal - sum); |
1104 | |
1105 | if(prob > Config.soft_evidence_activation_threshold){ |
1106 | ArrayList<String> parts = new ArrayList<String>(); |
1107 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
1108 | parts.add("TRUE"); |
1109 | parts.add(Double.toString(prob)); |
1110 | |
1111 | if(options.useDualDecomposition){ |
1112 | parts.add(Integer.toString(2)); |
1113 | }else{ |
1114 | parts.add(Integer.toString(2)); |
1115 | } |
1116 | // parts.add("1");//this is for vote |
1117 | |
1118 | String tmp = String.format(current, (Object[]) id2Label.get(i)); |
1119 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
1120 | } |
1121 | |
1122 | } |
1123 | |
1124 | } |
1125 | |
1126 | if(!this.next.containsKey(current)){ |
1127 | break; |
1128 | } |
1129 | current = this.next.get(current); |
1130 | } |
1131 | } |
1132 | }else{ |
1133 | |
1134 | for(String sssss : lasts){ |
1135 | |
1136 | String current = sssss; |
1137 | |
1138 | int toDump = this.last2maxArg.get(sssss); |
1139 | |
1140 | while(true){ |
1141 | |
1142 | Node n = this.signature2Node.get(current); |
1143 | |
1144 | ArrayList<String> parts = new ArrayList<String>(); |
1145 | // parts.add(Integer.toString(pred.nextTupleIDAndUpdate())); |
1146 | parts.add("TRUE"); |
1147 | parts.add(""); |
1148 | |
1149 | if(options.useDualDecomposition){ |
1150 | parts.add(Integer.toString(2)); |
1151 | }else{ |
1152 | parts.add(Integer.toString(2)); |
1153 | } |
1154 | // parts.add("1");//this is for vote |
1155 | |
1156 | String tmp = String.format(current, (Object[]) id2Label.get(toDump)); |
1157 | bw.append(FelixStringMan.commaListNoSpace(parts) + "," + tmp + "\n"); |
1158 | |
1159 | if(!this.prev.containsKey(current)){ |
1160 | break; |
1161 | } |
1162 | |
1163 | toDump = n.prevArgMax[toDump]; |
1164 | current = this.prev.get(current); |
1165 | } |
1166 | } |
1167 | |
1168 | } |
1169 | }catch(Exception e){ |
1170 | e.printStackTrace(); |
1171 | } |
1172 | |
1173 | } |
1174 | } |
1175 | |
1176 | @Override |
1177 | public void learn() { |
1178 | |
1179 | } |
1180 | |
1181 | } |