1 | package felix.operator; |
2 | |
3 | import java.io.BufferedWriter; |
4 | import java.io.File; |
5 | import java.io.FileInputStream; |
6 | import java.io.FileOutputStream; |
7 | import java.io.OutputStreamWriter; |
8 | import java.sql.SQLException; |
9 | import java.util.ArrayList; |
10 | import java.util.Arrays; |
11 | import java.util.Collections; |
12 | import java.util.HashMap; |
13 | import java.util.HashSet; |
14 | import java.util.List; |
15 | |
16 | import org.postgresql.PGConnection; |
17 | |
18 | |
19 | import tuffy.db.RDB; |
20 | import tuffy.mln.Predicate; |
21 | import tuffy.mln.Type; |
22 | import tuffy.ra.ConjunctiveQuery; |
23 | import tuffy.ra.Expression; |
24 | import tuffy.ra.Function; |
25 | import tuffy.util.Config; |
26 | import tuffy.util.ExceptionMan; |
27 | import tuffy.util.StringMan; |
28 | import tuffy.util.Timer; |
29 | import tuffy.util.UIMan; |
30 | import tuffy.util.UnionFind; |
31 | import felix.dstruct.DataMovementOperator; |
32 | import felix.dstruct.FelixPredicate; |
33 | import felix.dstruct.FelixQuery; |
34 | import felix.dstruct.StatOperator; |
35 | import felix.dstruct.FelixPredicate.FPProperty; |
36 | import felix.parser.FelixCommandOptions; |
37 | import felix.util.FelixConfig; |
38 | import felix.util.FelixUIMan; |
39 | |
40 | /** |
41 | * A COREF operator in Felix. |
42 | * @author Ce Zhang |
43 | * |
44 | */ |
45 | public class COREFOperator extends StatOperator{ |
46 | |
/**
 * DMOs producing soft positive (should-link) edges.
 */
ArrayList<DataMovementOperator> softPosDMOs = new ArrayList<DataMovementOperator>();

/**
 * DMOs for soft negative edges.
 * NOTE(review): never populated in this class — prepareDMO() currently routes
 * every soft rule into {@link #softPosDMOs}; confirm whether this is intended.
 */
ArrayList<DataMovementOperator> softNegDMOs = new ArrayList<DataMovementOperator>();

/**
 * DMOs producing hard positive (must-link) edges.
 */
ArrayList<DataMovementOperator> hardPosDMOs = new ArrayList<DataMovementOperator>();

/**
 * DMOs producing hard negative (cannot-link) edges.
 */
ArrayList<DataMovementOperator> hardNegDMOs = new ArrayList<DataMovementOperator>();

/**
 * DMO for the NODE_CLASS special clustering rule (maps a node to its class);
 * results are loaded into {@link #nodeClass}.
 */
DataMovementOperator nodeClassRule = null;

/**
 * DMO for the CLASS_TAGS special clustering rule (maps a class to its tags);
 * results are loaded into {@link #classTags}.
 */
DataMovementOperator classTagsRule = null;

/**
 * DMO for retrieving the node's domain.
 */
DataMovementOperator nodeListDMO = null;

/**
 * Union of all hard-neg DMOs, built with binding pattern "100";
 * queried by retrieveHardNegEdges.
 */
DataMovementOperator hardNegDMO1 = null;

/**
 * Union of all hard-neg DMOs, built with binding pattern "010";
 * queried by retrieveHardNegEdges.
 */
DataMovementOperator hardNegDMO2 = null;

/**
 * The DataMovementOperator which is the union of all hard-pos
 * DataMovementOperators (binding pattern "000", i.e. fully unbound).
 */
DataMovementOperator hardPosDMO = null;

/**
 * Union of all soft-pos DMOs, built with binding pattern "100";
 * queried by retrieveNeighbors.
 */
DataMovementOperator softPosDMO1 = null;

/**
 * Union of all soft-pos DMOs, built with binding pattern "010";
 * queried by retrieveNeighbors.
 */
DataMovementOperator softPosDMO2 = null;

/**
 * Whether to represent clustering results using the pairwise representation.
 * Note that setting this parameter to true will cause a quadratic number of
 * result tuples.
 */
public boolean usePairwiseRepresentation = false;

/**
 * Target predicate of this Coref operator.
 */
FelixPredicate corefHead;

/**
 * Node constant -> class ID, loaded from the NODE_CLASS special rule by
 * procSepcialRules().
 */
HashMap<Integer, Integer> nodeClass = null;

/**
 * Class ID -> set of tags, loaded from the CLASS_TAGS special rule by
 * procSepcialRules().
 */
HashMap<Integer, HashSet<Integer>> classTags = null;
128 | |
/**
 * Constructs a COREF operator over the given target predicates.
 *
 * @param _fq Felix query.
 * @param _goalPredicates target predicates of this coref operator.
 * @param _opt command line options of this Felix run.
 */
public COREFOperator(FelixQuery _fq, HashSet<FelixPredicate> _goalPredicates,
        FelixCommandOptions _opt) {
    super(_fq, _goalPredicates, _opt);
    this.type = OPType.COREF;
    this.precedence = 10;
    // Every predicate this operator resolves is a coref predicate.
    for (FelixPredicate goal : _goalPredicates) {
        goal.isCorefPredicate = true;
    }
}
144 | |
145 | /** |
146 | * Get the size of domain on which clustering is conducted. |
147 | * @param p clustering predicate |
148 | * @return |
149 | */ |
150 | public int getDomainSize(Predicate p){ |
151 | return p.getTypeAt(0).size()/this.partitionedInto; |
152 | } |
153 | |
154 | |
/**
 * Generates the Data Movement Operators (DMOs) used by this Coref operator:
 * one DMO per edge-producing conjunctive query (classified as hard-pos,
 * hard-neg, or soft), a DMO for the node domain, union DMOs over each edge
 * category, and DMOs for the special NODE_CLASS / CLASS_TAGS rules.
 *
 * @param rules rules defining this operator.
 */
public void prepareDMO(HashSet<ConjunctiveQuery> rules){

	// build linear representation table for this operator
	// (<head>_map stores the linear node -> representative encoding that
	// dumpAnswerToDBTable fills in)
	if(fq.getPredByName(corefHead.getName() + "_map") == null){
		FelixPredicate pmap = new FelixPredicate(corefHead.getName() + "_map", true);
		pmap.appendArgument(corefHead.getTypeAt(0));
		pmap.appendArgument(corefHead.getTypeAt(1));
		pmap.prepareDB(db);
		fq.addFelixPredicate(pmap);
	}


	// generate Data Movement Operator according to conjunctive queries
	for(ConjunctiveQuery cq : rules){

		/*
		if(cq.sourceClause.hasEmbeddedWeight()){
			Expression e = new Expression(Function.GreaterThan);
			Expression tmpe1 = Expression.exprVariableBinding(cq.sourceClause.getVarWeight());
			Expression tmpe = Expression.exprConstInteger(0);
			e.addArgument(tmpe1);
			e.addArgument(tmpe);
			e.changeName = false;

			cq.addConstraint(e);
		}*/

		DataMovementOperator dmo = new DataMovementOperator(db, this);
		dmo.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
				new ArrayList<String>(Arrays.asList("weight")) );
		dmo.whichToBound.add(cq.head.getTerms().get(0).toString());

		// hard rule
		if(cq.sourceClause.isHardClause() && cq.getWeight() > 0){
			dmo.predictedBB = 0;
			dmo.PredictedFF = 1;
			dmo.PredictedBF = 0;
			hardPosDMOs.add(dmo);
			allDMOs.add(dmo);
		}
		// hard negative rule
		else if(cq.sourceClause.isHardClause() && cq.getWeight() < 0){
			dmo.predictedBB = 0;
			dmo.PredictedFF = 1;
			dmo.PredictedBF = this.getDomainSize(corefHead);
			hardNegDMOs.add(dmo);
			allDMOs.add(dmo);
		}
		// soft incomplete rule
		//TODO
		// else if( (!cq.sourceClause.isHardClause() && cq.getWeight() > 0) ||
		//		cq.sourceClause.hasEmbeddedWeight()
		//		){
		// NOTE(review): every soft rule (regardless of weight sign) is put in
		// softPosDMOs; softNegDMOs stays empty — confirm this is intended.
		else if( (!cq.sourceClause.isHardClause() )){
			dmo.predictedBB = 0;
			dmo.PredictedFF = 0;
			dmo.PredictedBF = this.getDomainSize(corefHead);
			softPosDMOs.add(dmo);
			allDMOs.add(dmo);
		}else{
			// only reachable for a hard clause whose weight is exactly 0
			UIMan.warn("The following rule is ignored in the COREFOperator!\n" + cq);
		}

	}

	// generate Data Movement Operator used in this operator

	//first, the DMO used to fetch the node domain
	DataMovementOperator dmo = new DataMovementOperator(db, this);
	dmo.predictedBB = 0; dmo.PredictedFF = 1; dmo.PredictedBF = 0;
	dmo.logicQueryPlan.addQuery(db.getPrepareStatement(
			"SELECT DISTINCT constantID FROM " + corefHead.getTypeAt(0).getRelName()),
			new ArrayList<String>(Arrays.asList("constantID")), new ArrayList<String>());
	dmo.allowOptimization = false;
	nodeListDMO = dmo;
	allDMOs.add(dmo);

	//second, the DMO for the union of all hard-pos DMO
	if(this.hardPosDMOs.size() > 0){
		this.hardPosDMO = DataMovementOperator.UnionAll(db, this,
				this.hardPosDMOs, "000", new ArrayList<Integer>());
		this.hardPosDMO.isIntermediaDMO = true;
		allDMOs.add(hardPosDMO);
	}

	//third, the DMO for the union of all soft-pos DMO
	// (two unions with binding patterns "100" and "010" — presumably bound on
	// the first vs. the second edge endpoint; confirm against
	// DataMovementOperator.UnionAll)
	if(this.softPosDMOs.size() > 0){

		this.softPosDMO1 = DataMovementOperator.UnionAll(db, this,
				this.softPosDMOs, "100", new ArrayList<Integer>());
		this.softPosDMO1.isIntermediaDMO = true;
		allDMOs.add(softPosDMO1);

		/*
		DataMovementOperator groupedSoftPosDMO1 = new DataMovementOperator(db, this);
		groupedSoftPosDMO1.predictedBB = 0; groupedSoftPosDMO1.PredictedFF = 1; groupedSoftPosDMO1.PredictedBF = 0;
		groupedSoftPosDMO1.logicQueryPlan.addQuery(db.getPrepareStatement(
				"SELECT " + softPosDMO1.finalSelList.get(0) + ","
				+ softPosDMO1.finalSelList.get(1) + ","
				+ "sum(" + softPosDMO1.finalSelList.get(0) + ") as sumweight "
				+ "FROM " + softPosDMO1.getAllFreeViewName() + " "
				+ "GROUP BY " + softPosDMO1.finalSelList.get(0) + ","
				+ softPosDMO1.finalSelList.get(1) + " "
				+ "WHERE " + "sumweight > 0"
				+ " AND " + softPosDMO1.finalSelList.get(0) + " = ?"),
				softPosDMO1.finalSelList, new ArrayList<String>());
		groupedSoftPosDMO1.allowOptimization = false;
		allDMOs.add(groupedSoftPosDMO1);
		*/


		this.softPosDMO2 = DataMovementOperator.UnionAll(db, this,
				this.softPosDMOs, "010", new ArrayList<Integer>());
		this.softPosDMO2.isIntermediaDMO = true;
		allDMOs.add(softPosDMO2);

		/*
		DataMovementOperator groupedSoftPosDMO2 = new DataMovementOperator(db, this);
		groupedSoftPosDMO2.predictedBB = 0; groupedSoftPosDMO2.PredictedFF = 1; groupedSoftPosDMO2.PredictedBF = 0;
		groupedSoftPosDMO2.logicQueryPlan.addQuery(db.getPrepareStatement(
				"SELECT " + softPosDMO2.finalSelList.get(0) + ","
				+ softPosDMO2.finalSelList.get(1) + ","
				+ "sum(" + softPosDMO2.finalSelList.get(0) + ") as sumweight "
				+ "FROM " + softPosDMO2.getAllFreeViewName() + " "
				+ "GROUP BY " + softPosDMO2.finalSelList.get(0) + ","
				+ softPosDMO2.finalSelList.get(1) + " "
				+ "WHERE " + "sumweight > 0"
				+ " AND " + softPosDMO2.finalSelList.get(0) + " = ?"),
				softPosDMO1.finalSelList, new ArrayList<String>());
		groupedSoftPosDMO2.allowOptimization = false;
		allDMOs.add(groupedSoftPosDMO2);
		*/
	}

	//forth, the DMO for the union of all hard-negative DMO
	if(this.hardNegDMOs.size() > 0){
		this.hardNegDMO1 = DataMovementOperator.UnionAll(db, this,
				this.hardNegDMOs, "100", new ArrayList<Integer>());
		this.hardNegDMO1.isIntermediaDMO = true;
		allDMOs.add(hardNegDMO1);

		this.hardNegDMO2 = DataMovementOperator.UnionAll(db, this,
				this.hardNegDMOs, "010", new ArrayList<Integer>());
		this.hardNegDMO2.isIntermediaDMO = true;
		allDMOs.add(hardNegDMO2);
	}

	//then process special rules (syntax sugar for node classes and class tags)
	for(ConjunctiveQuery cq : fq.getSpecialClusteringRules(this.corefHead.getName())){

		if(cq.type == ConjunctiveQuery.CLUSTERING_RULE_TYPE.NODE_CLASS){

			nodeClassRule = new DataMovementOperator(db, this);
			nodeClassRule.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
					new ArrayList<String>() );
			nodeClassRule.whichToBound.add(cq.head.getTerms().get(0).toString());
			nodeClassRule.allowOptimization = false;
			allDMOs.add(nodeClassRule);

		}else if(cq.type == ConjunctiveQuery.CLUSTERING_RULE_TYPE.CLASS_TAGS){
			classTagsRule = new DataMovementOperator(db, this);
			classTagsRule.logicQueryPlan.addQuery(cq, cq.head.getPred().getArgs(),
					new ArrayList<String>() );
			classTagsRule.whichToBound.add(cq.head.getTerms().get(0).toString());
			classTagsRule.allowOptimization = false;
			allDMOs.add(classTagsRule);

		}else{
			ExceptionMan.die("No special rules other than NODE_CLASS " +
					"and CLASS_TAGS are supported in Felix!");
		}

	}


}
335 | |
// Whether prepare() has completed at least once.
// NOTE(review): currently write-only — the guard that read it is commented out
// inside prepare() below.
boolean prepared = false;

/**
 * Prepares the operator for execution: clears all DMO lists, (re)opens the
 * database connection, resolves the single target predicate, translates the
 * Felix clauses into factor-graph edge queries, and builds the DMOs.
 */
@Override
public void prepare() {

	softPosDMOs.clear();

	softNegDMOs.clear();

	hardPosDMOs.clear();

	hardNegDMOs.clear();

	allDMOs.clear();
	//if(!prepared){

	db = RDB.getRDBbyConfig(Config.db_schema);

	// A COREF operator handles exactly one target predicate.
	corefHead = this.getTargetPredicateIfHasOnlyOne();
	HashSet<ConjunctiveQuery> rules =
		this.translateFelixClasesIntoFactorGraphEdgeQueries(corefHead, true,
				this.inputPredicateScope,
				FPProperty.NON_RECUR,
				FPProperty.CHAIN_RECUR,
				FPProperty.OTHER_RECUR);

	this.prepareDMO(rules);
	prepared = true;
	//}
}
369 | |
/**
 * Executes the operator: runs the clustering procedure inside one
 * transaction, reports timing, and (unless dual decomposition is in use)
 * hands control to the next operator in this bucket.
 */
@Override
public void run() {
	UIMan.print(">>> Start Running " + this);

	try{

		// Marginal vs. MAP mode is decided by the bucket this operator belongs to.
		this.isMarginal = belongsToBucket.isMarginal();

		Timer.start("Coref-Op" + this.getId());

		if(corefHead == null){
			throw new Exception("The head of this Coref operator is NULL.");
		}

		// Batch the clustering writes into a single transaction.
		db.disableAutoCommitForNow();

		cluster();

		db.setAutoCommit(true);

		//this.oriMLN.dumpMapAnswerForPredicate(options.fout+"_coref_" + headPredicate.getName() + "_op" + id,
		//		headPredicate, false);

		FelixUIMan.println(0,0, "\n>>> {" + this + "} uses " + Timer.elapsed("Coref-Op" + this.getId()));

		if(!options.useDualDecomposition){
			this.belongsToBucket.runNextOperatorInBucket();
		}

		db.commit();
		//db.close();

		// TODO!!!!!!!!!!!!!!!
		// db.close();

	} catch (Exception e) {
		// NOTE(review): failures are only printed, never propagated, so the
		// pipeline continues after a failed clustering run — confirm intended.
		e.printStackTrace();
	}

}
413 | |
/**
 * Explanation of this operator's plan. Not implemented for COREF; always
 * returns {@code null}.
 */
@Override
public String explain() {
	// TODO Auto-generated method stub
	return null;
}
419 | |
420 | /** |
421 | * Clustering worker. |
422 | */ |
423 | public void cluster() throws Exception{ |
424 | |
425 | procSepcialRules(); |
426 | |
427 | // get domain for the clustering predicate |
428 | // We assume P(type,type) |
429 | ArrayList<Integer> nodes = new ArrayList<Integer>(); |
430 | this.nodeListDMO.execute(null, new ArrayList<Integer>()); |
431 | while(this.nodeListDMO.next()){ |
432 | nodes.add(this.nodeListDMO.getNext(1)); |
433 | } |
434 | |
435 | FelixUIMan.println(2,0,"#nodes = " + nodes.size()); |
436 | |
437 | Collections.shuffle(nodes); |
438 | HashMap<Integer, Integer> rankMap = new HashMap<Integer, Integer>(); |
439 | |
440 | ArrayList<Integer> ranks = new ArrayList<Integer>(); |
441 | for(int i=0; i<nodes.size(); i++){ |
442 | rankMap.put(nodes.get(i), i); |
443 | ranks.add(i); |
444 | } |
445 | |
446 | UnionFind<Integer> clusters = new UnionFind<Integer>(); |
447 | clusters.makeUnionFind(ranks); |
448 | ranks = null; |
449 | |
450 | Timer.start("clustering"); |
451 | int ct = 0; |
452 | int edges = 0; |
453 | |
454 | HashMap<Integer, HashSet<Integer>> hardClusters = new HashMap<Integer, HashSet<Integer>>(); |
455 | |
456 | for(int i=0; i<nodes.size(); i++){ |
457 | HashSet<Integer> s = new HashSet<Integer>(); |
458 | s.add(nodes.get(i)); |
459 | hardClusters.put(i, s); |
460 | } |
461 | |
462 | if(this.hardPosDMO != null){ |
463 | //UIMan.println(">>> Processing hard positive edges..."); |
464 | Timer.start("hardpos"); |
465 | db.disableAutoCommitForNow(); |
466 | |
467 | this.hardPosDMO.execute(null, new ArrayList<Integer>()); |
468 | int cnt = 0; |
469 | while(this.hardPosDMO.next()){ |
470 | cnt ++; |
471 | if(cnt % 100000000 == 0){ |
472 | // UIMan.print("*"); |
473 | FelixUIMan.println(2,0, "# hard edges: " + cnt); |
474 | } |
475 | Integer i = rankMap.get(this.hardPosDMO.getNext(1)); |
476 | Integer j = rankMap.get(this.hardPosDMO.getNext(2)); |
477 | i = clusters.getRoot(i); |
478 | j = clusters.getRoot(j); |
479 | if(i == j) continue; |
480 | if(i < j){ |
481 | hardClusters.get(i).addAll(hardClusters.get(j)); |
482 | hardClusters.remove(j); |
483 | }else{ |
484 | hardClusters.get(j).addAll(hardClusters.get(i)); |
485 | hardClusters.remove(i); |
486 | } |
487 | clusters.unionByValue(i, j); |
488 | } |
489 | |
490 | int x = 0; |
491 | for(int y : hardClusters.keySet()){ |
492 | x += hardClusters.get(y).size(); |
493 | } |
494 | |
495 | db.restoreAutoCommitState(); |
496 | //Timer.printElapsed("hardpos"); |
497 | } |
498 | |
499 | |
500 | for(int i=0; i<nodes.size(); i++){ |
501 | |
502 | int s = nodes.get(i); |
503 | if(ct % 100000 == 0){ |
504 | //UIMan.print("."); |
505 | FelixUIMan.println(2,0,ct + "/" + nodes.size() +" : " + |
506 | Timer.elapsed("clustering") + " edges : " + edges); |
507 | int nc = clusters.getNumClusters(); |
508 | FelixUIMan.println(2,0,"#clusters = " + nc); |
509 | } |
510 | ct ++; |
511 | |
512 | Integer root = clusters.getRoot(i); |
513 | |
514 | if(root != i) continue; |
515 | HashSet<Integer> classesInCluster = null; |
516 | |
517 | if(nodeClassRule != null && classTagsRule != null){ |
518 | classesInCluster = new HashSet<Integer>(); |
519 | for(int n : hardClusters.get(root)){ |
520 | classesInCluster.add(nodeClass.get(n)); |
521 | } |
522 | } |
523 | |
524 | List<Integer> wl = this.retrieveNeighbors(s); |
525 | HashSet<Integer> brokenByHardNeg = new HashSet<Integer>(); |
526 | brokenByHardNeg.addAll(this.retrieveHardNegEdges(s)); |
527 | |
528 | edges += wl.size(); |
529 | |
530 | //if(couldLinkPairwiseDMO != null){ |
531 | // group.addAll(hardClusters.get(i)); |
532 | //} |
533 | HashSet<Integer> merged = new HashSet<Integer>(); |
534 | merged.add(root); |
535 | |
536 | for(int t : wl){ |
537 | |
538 | int j = rankMap.get(t); |
539 | Integer r2 = clusters.getRoot(j); |
540 | if(j <= i || j != r2 || r2 == root) continue; |
541 | if(brokenByHardNeg.contains(t)) continue; |
542 | |
543 | if(nodeClassRule != null && classTagsRule != null){ |
544 | HashSet<Integer> classesInOtherCluster = new HashSet<Integer>(); |
545 | for(int n : hardClusters.get(r2)){ |
546 | classesInOtherCluster.add(nodeClass.get(n)); |
547 | } |
548 | classesInOtherCluster.removeAll(classesInCluster); |
549 | boolean compatible = true; |
550 | if(!classesInOtherCluster.isEmpty()){ |
551 | labf: |
552 | for(int x : classesInCluster){ |
553 | for(int y : classesInOtherCluster){ |
554 | HashSet<Integer> xt = classTags.get(x); |
555 | HashSet<Integer> yt = classTags.get(y); |
556 | xt.retainAll(yt); |
557 | if(xt.isEmpty()){ |
558 | compatible = false; |
559 | break labf; |
560 | } |
561 | } |
562 | } |
563 | } |
564 | if(compatible){ |
565 | classesInCluster.addAll(classesInOtherCluster); |
566 | }else{ |
567 | continue; |
568 | } |
569 | } |
570 | |
571 | clusters.unionByValue(root, r2); |
572 | brokenByHardNeg.addAll(this.retrieveHardNegEdges(t)); |
573 | merged.add(r2); |
574 | } |
575 | } |
576 | |
577 | int nc = clusters.getNumClusters(); |
578 | FelixUIMan.println(2,0,"# clusters = " + nc); |
579 | FelixUIMan.println(2,0,"# edges : " + edges); |
580 | |
581 | this.dumpAnswerToDBTable(corefHead, clusters, nodes); |
582 | } |
583 | |
584 | /** |
585 | * Get soft-pos neighbors of a given node. |
586 | * @param m1 |
587 | * @return |
588 | * @throws SQLException |
589 | */ |
590 | public List<Integer> retrieveNeighbors(Integer m1) throws SQLException{ |
591 | |
592 | ArrayList<Integer> ret = new ArrayList<Integer>(); |
593 | HashSet<Integer> ns = new HashSet<Integer>(); |
594 | |
595 | if(this.softPosDMO1 != null){ |
596 | |
597 | this.softPosDMO1.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
598 | |
599 | while(this.softPosDMO1.next()){ |
600 | Integer i = softPosDMO1.getNext(1) + softPosDMO1.getNext(2) - m1; |
601 | if(!ns.contains(i)){ |
602 | ns.add(i); |
603 | ret.add(i); |
604 | } |
605 | } |
606 | } |
607 | |
608 | if(this.softPosDMO2 != null){ |
609 | |
610 | this.softPosDMO2.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
611 | |
612 | while(this.softPosDMO2.next()){ |
613 | Integer i = softPosDMO2.getNext(1) + softPosDMO2.getNext(2) - m1; |
614 | if(!ns.contains(i)){ |
615 | ns.add(i); |
616 | ret.add(i); |
617 | } |
618 | } |
619 | } |
620 | |
621 | return ret; |
622 | } |
623 | |
624 | /** |
625 | * Get hard-neg neighbors of a given node. |
626 | * @param m1 |
627 | * @return |
628 | * @throws SQLException |
629 | */ |
630 | public List<Integer> retrieveHardNegEdges(Integer m1) throws SQLException{ |
631 | |
632 | ArrayList<Integer> ret = new ArrayList<Integer>(); |
633 | HashSet<Integer> ns = new HashSet<Integer>(); |
634 | |
635 | if(this.hardNegDMO1 != null){ |
636 | |
637 | this.hardNegDMO1.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
638 | |
639 | while(this.hardNegDMO1.next()){ |
640 | Integer i = hardNegDMO1.getNext(1) + hardNegDMO1.getNext(2) - m1; |
641 | if(!ns.contains(i)){ |
642 | ns.add(i); |
643 | ret.add(i); |
644 | } |
645 | } |
646 | } |
647 | |
648 | if(this.hardNegDMO2 != null){ |
649 | |
650 | this.hardNegDMO2.execute(null , new ArrayList<Integer>(Arrays.asList(m1))); |
651 | |
652 | while(this.hardNegDMO2.next()){ |
653 | Integer i = hardNegDMO2.getNext(1) + hardNegDMO2.getNext(2) - m1; |
654 | if(!ns.contains(i)){ |
655 | ns.add(i); |
656 | ret.add(i); |
657 | } |
658 | } |
659 | } |
660 | |
661 | return ret; |
662 | } |
663 | |
664 | /** |
665 | * Process Tag and Class rules. |
666 | */ |
667 | private void procSepcialRules(){ |
668 | //System.out.println(">>> Processing special rules..."); |
669 | if(nodeClassRule != null){ |
670 | nodeClass = new HashMap<Integer, Integer>(); |
671 | nodeClassRule.execute(null, new ArrayList<Integer>()); |
672 | while(nodeClassRule.next()){ |
673 | int a = nodeClassRule.getNext(1); |
674 | int b = nodeClassRule.getNext(2); |
675 | nodeClass.put(a, b); |
676 | } |
677 | } |
678 | |
679 | if(classTagsRule != null){ |
680 | classTags = new HashMap<Integer, HashSet<Integer>>(); |
681 | classTagsRule.execute(null, new ArrayList<Integer>()); |
682 | while(classTagsRule.next()){ |
683 | int a = classTagsRule.getNext(1); |
684 | int b = classTagsRule.getNext(2); |
685 | if(!classTags.containsKey(a)){ |
686 | classTags.put(a, new HashSet<Integer>()); |
687 | } |
688 | classTags.get(a).add(b); |
689 | } |
690 | |
691 | } |
692 | |
693 | } |
694 | |
/**
 * Dumps the clustering answer to the database: writes the linear
 * (node, representative) encoding to a staging CSV file, bulk-loads it via
 * PostgreSQL COPY, and then either materializes the pairwise relation
 * directly (when {@code usePairwiseRepresentation} is set) or creates a
 * self-join view over the linear {@code <pred>_map} table.
 *
 * @param p clustering predicate.
 * @param clusters clustering result (over ranks that index into {@code nodes}).
 * @param nodes domain on which clustering is conducted.
 */
public void dumpAnswerToDBTable(Predicate p, UnionFind<Integer> clusters, ArrayList<Integer> nodes){

	// Staging CSV file that is bulk-loaded with COPY below.
	File loadingFile = new File(Config.getLoadingDir(), "loading_cg_" + p.getRelName() + "_op" + this.getId());
	//p.nextTupleID = 0;

	// Companion relation holding the linear node -> representative encoding.
	Predicate pmap = fq.getPredByName(p.getName() + "_map");
	//pmap.nextTupleID = 0;

	String relLinear = pmap.getRelName();

	try {
		BufferedWriter loadingFileWriter = new BufferedWriter(new OutputStreamWriter
				(new FileOutputStream(loadingFile),"UTF8"));

		// Translate union-find ranks back to node constants:
		// label2mention maps a representative node to its cluster members.
		HashMap<Integer, Integer> map = clusters.getPartitionMap();
		HashMap<Integer, HashSet<Integer>> label2mention = new HashMap<Integer, HashSet<Integer>>();
		for(int k : map.keySet()){
			int c = map.get(k);
			k = nodes.get(k);
			c = nodes.get(c);
			HashSet<Integer> set = label2mention.get(c);
			if(set == null){
				set = new HashSet<Integer>();
				label2mention.put(c, set);
			}
			set.add(k);
		}

		if(this.usePairwiseRepresentation){

			// One CSV row per ordered (node1, node2) pair within each cluster
			// -- quadratic in cluster size.
			for(Integer clusterID : label2mention.keySet()){
				if(label2mention.get(clusterID).size() == 0){
					continue;
				}

				for(Integer node1 : label2mention.get(clusterID)){
					for(Integer node2 : label2mention.get(clusterID)){

						// Row layout: truth, club, node1, node2.
						ArrayList<String> parts = new ArrayList<String>();
						//parts.add(Integer.toString(p.nextTupleID(p.nextTupleID++)));
						parts.add("TRUE");
						if(options.useDualDecomposition){
							parts.add(Integer.toString(1));
						}else{
							parts.add(Integer.toString(2));
						}


						parts.add(Integer.toString(node1));
						parts.add(Integer.toString(node2));

						loadingFileWriter.append(StringMan.join(",", parts) + "\n");

					}
				}
			}
		}else{
			// One CSV row per node: (node, representative).
			for(Integer clusterID : label2mention.keySet()){
				if(label2mention.get(clusterID).size() == 0){
					continue;
				}

				// Use the smallest member as the canonical representative.
				int newClusterID = clusterID;
				for(Integer node1 : label2mention.get(clusterID)){
					if(node1 < newClusterID){
						newClusterID = node1;
					}
				}

				for(Integer node1 : label2mention.get(clusterID)){

					ArrayList<String> parts = new ArrayList<String>();
					//parts.add(Integer.toString(p.nextTupleID(pmap.nextTupleID++)));
					//parts.add("TRUE");

					// NOTE(review): both branches are identical -- likely a
					// leftover edit; confirm whether DD should write a
					// different club value (the pairwise branch writes 1).
					if(options.useDualDecomposition){
						parts.add("TRUE");
						parts.add(Integer.toString(2));
					}else{
						parts.add("TRUE");
						parts.add(Integer.toString(2));
					}

					parts.add(Integer.toString(node1));
					parts.add(Integer.toString(newClusterID));

					loadingFileWriter.append(StringMan.join(",", parts) + "\n");

				}
			}
		}
		loadingFileWriter.close();

		if(this.usePairwiseRepresentation){

			FileInputStream in = new FileInputStream(loadingFile);
			PGConnection con = (PGConnection)db.getConnection();

			String sql;
			//String sql = "DELETE FROM " + p.getRelName();
			//db.update(sql);
			//db.vacuum(p.getRelName());

			// Bulk-load the pairwise rows straight into the predicate's table.
			sql = "COPY " + p.getRelName() + "(truth, club, " + StringMan.commaList(p.getArgs()) + " ) FROM STDIN CSV";
			con.getCopyAPI().copyIn(sql, in);
			in.close();
			p.isCurrentlyView = false;

		}else{

			FileInputStream in = new FileInputStream(loadingFile);
			PGConnection con = (PGConnection)db.getConnection();

			String sql;
			db.dropView(p.getRelName());
			db.dropTable(p.getRelName());

			//sql = "DELETE FROM " + relLinear;
			//db.update(sql);
			//db.vacuum(relLinear);

			if(options.useDualDecomposition){

				if(FelixConfig.isFirstRunOfDD){
					sql = "COPY " + relLinear + "(truth, club," + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
					con.getCopyAPI().copyIn(sql, in);
					in.close();
					//p.setHasSoftEvidence(true);
				}
				// NOTE(review): when isFirstRunOfDD is false, `in` is never
				// closed on this path -- file handle leak.
			}else{
				sql = "COPY " + relLinear + "(truth, club, " + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
				con.getCopyAPI().copyIn(sql, in);
				in.close();
			}

			// Index the linear table on both columns to support the self-join
			// view created below.
			db.dropIndex(relLinear + "_label_idx");
			sql = "CREATE INDEX " + relLinear + "_label_idx on " +
					relLinear + "(" + pmap.getArgs().get(1) + ")";
			db.update(sql);

			db.dropIndex(relLinear + "_node_idx");
			sql = "CREATE INDEX " + relLinear + "_node_idx on " +
					relLinear + "(" + pmap.getArgs().get(0) + ")";
			db.update(sql);

			db.analyze(relLinear);

			db.dropSequence(p.getRelName()+"_seq");
			sql = "CREATE SEQUENCE " + p.getRelName()+"_seq" +
					" START WITH 1";
			db.update(sql);

			// Expose the pairwise relation as a view: self-join of the linear
			// table on the representative column.
			sql = "CREATE VIEW " + p.getRelName() + " AS SELECT nextval('" + p.getRelName()+"_seq')" +
					"::integer AS id," +
					"TRUE::boolean AS truth, NULL::float as prior, " +
					"2::integer as club, NULL::integer as atomID, " +
					"t1."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(0) +
					", t2."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(1) +
					" FROM " + relLinear + " t1, " + relLinear + " t2" +
					" WHERE t1."+pmap.getArgs().get(1)+"=t2."+pmap.getArgs().get(1)+"";
			p.isCurrentlyView = true;
			//p.hasSoftEvidence = false;
			p.setHasSoftEvidence(false);
			db.update(sql);

			((FelixPredicate)p).viewDef = sql + "";

			if(options.useDualDecomposition){
				// Replicate the result into every common-output table shared
				// with other operators under dual decomposition.
				for(FelixPredicate fp : this.dd_CommonOutput){
					if(!fp.getName().equals(this.corefHead.getName())
							&& !fp.getName().equals(this.corefHead.corefMAPPredicate.getName())){
						ExceptionMan.die("COREF 868: There must be something wrong with the parser!");
					}

					in = new FileInputStream(loadingFile);
					String tableName;
					String viewName;

					// tableName holds the linear (map) encoding; viewName (if
					// any) exposes the pairwise view over it.
					if(fp.isCorefMapPredicate){
						tableName = this.dd_commonOutputPredicate_2_tableName.get(fp);
					}else{
						tableName = this.dd_commonOutputPredicate_2_tableName.get(fp.corefMAPPredicate);
					}

					if(fp.isCorefMapPredicate){
						viewName = this.dd_commonOutputPredicate_2_tableName.get(fp.oriCorefPredicate);
					}else{
						viewName = this.dd_commonOutputPredicate_2_tableName.get(fp);
					}

					if(viewName != null)
						db.dropView(viewName);
					//db.dropTable(tableName);

					sql = "DELETE from " + tableName;
					db.execute(sql);

					sql = "COPY " + tableName + "(truth, club, " + StringMan.commaList(pmap.getArgs()) + " ) FROM STDIN CSV";
					con.getCopyAPI().copyIn(sql, in);
					in.close();

					db.execute("UPDATE " + tableName + " SET prior = 1");

					db.dropIndex(tableName + "_label_idx");
					sql = "CREATE INDEX " + tableName + "_label_idx on " +
							tableName + "(" + pmap.getArgs().get(1) + ")";
					db.update(sql);

					db.dropIndex(tableName + "_node_idx");
					sql = "CREATE INDEX " + tableName + "_node_idx on " +
							tableName + "(" + pmap.getArgs().get(0) + ")";
					db.update(sql);

					db.analyze(tableName);

					if(viewName != null){

						db.dropSequence(viewName+"_seq");
						sql = "CREATE SEQUENCE " + viewName+"_seq" +
								" START WITH 1";
						db.update(sql);

						sql = "CREATE VIEW " + viewName + " AS SELECT nextval('" + viewName+"_seq')" +
								"::integer AS id," +
								"TRUE::boolean AS truth, 1 as prior, " +
								"2::integer as club, NULL::integer as atomID, " +
								"t1."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(0) +
								", t2."+pmap.getArgs().get(0)+"::integer as " + p.getArgs().get(1) +
								" FROM " + tableName + " t1, " + tableName + " t2" +
								" WHERE t1."+pmap.getArgs().get(1)+"=t2."+pmap.getArgs().get(1)+"";
						db.update(sql);



					}

				}
			}

		}

	} catch (Exception e) {
		ExceptionMan.handle(e);
	}

}
947 | |
948 | |
/**
 * Weight learning is not supported by the COREF operator; intentionally a
 * no-op.
 */
@Override
public void learn() {

}
953 | |
954 | } |