def evaluate_opt_model(combined_model, syntax_parses, annotations, match_parses, thresholds):
    """Measure precision/recall of greedy tree selection at several thresholds.

    For every sentence the annotated (positive) semantic trees are built from
    ``annotations``, candidate trees are generated from the rules produced by
    ``combined_model``, and ``TextGreedyOptModel.optimize`` is run once per
    threshold.  True-positive, false-positive and false-negative counts are
    accumulated over all sentences; progress and scores are printed to stdout.

    Args:
        combined_model: Object providing ``generate_tag_rules``,
            ``generate_unary_rules``, ``generate_binary_rules``,
            ``get_tree_score`` and a ``unary_model`` attribute.
        syntax_parses: Mapping ``pk -> {number: syntax_parse}``.
        annotations: Nested mapping; ``annotations[pk][number].values()``
            yields the annotations of one sentence.
        match_parses: Container indexed by the same ``pk`` keys.
        thresholds: Re-iterable collection of hashable threshold values, each
            passed as the second argument to ``optimize``.

    Returns:
        dict mapping each threshold to a ``(precision, recall)`` tuple.  A
        ratio whose denominator would be zero is reported as ``0.0``.
    """
    tps, fps, fns = defaultdict(int), defaultdict(int), defaultdict(int)

    # ``items()`` and single-argument ``print(...)`` behave the same on
    # Python 2 and Python 3, so the function runs unchanged on both.
    for pk, local_syntax_parses in syntax_parses.items():
        print("=" * 80)
        # Only the disabled FullGreedyOptModel alternative below needs this
        # value; the lookup is kept so a pk missing from match_parses still
        # fails here instead of being silently accepted.
        match_parse = match_parses[pk]
        for number, syntax_parse in local_syntax_parses.items():
            print("%s %s" % (pk, number))
            opt_model = TextGreedyOptModel(combined_model)
            # opt_model = FullGreedyOptModel(combined_model, match_parse)

            # Positive trees come from the annotations; trees whose content
            # signature id is 'CC' are excluded from the evaluation.
            pos_semantic_trees = set(
                annotation_to_semantic_tree(syntax_parse, annotation)
                for annotation in annotations[pk][number].values())
            pos_semantic_trees = set(
                t for t in pos_semantic_trees if t.content.signature.id != 'CC')

            tag_rules = combined_model.generate_tag_rules(syntax_parse)
            # tag_rules = set(itertools.chain(*[t.get_tag_rules() for t in pos_semantic_trees]))
            unary_rules = combined_model.generate_unary_rules(tag_rules)
            tag_rules = filter_tag_rules(combined_model.unary_model, tag_rules, unary_rules, 0.9)
            binary_rules = combined_model.generate_binary_rules(tag_rules)

            # Candidate pool: every "truth" or "is" tree of the forest whose
            # model score exceeds 0.01.
            semantic_forest = SemanticForest(tag_rules, unary_rules, binary_rules)
            semantic_trees = semantic_forest.get_semantic_trees_by_type("truth").union(
                semantic_forest.get_semantic_trees_by_type('is'))
            semantic_trees = set(
                t for t in semantic_trees if combined_model.get_tree_score(t) > 0.01)
            neg_semantic_trees = semantic_trees - pos_semantic_trees

            for pst in pos_semantic_trees:
                print("pos: %s %s" % (combined_model.get_tree_score(pst), pst))
            for nst in neg_semantic_trees:
                score = combined_model.get_tree_score(nst)
                if score > 0:
                    print("neg: %s %s" % (score, nst))
            print("")

            for th in thresholds:
                selected_trees = opt_model.optimize(semantic_trees, th)
                # Intersect with the positives rather than subtracting the
                # negatives: a selected tree that is neither annotated nor in
                # the candidate pool must count only as a false positive.
                tps[th] += len(selected_trees & pos_semantic_trees)
                fps[th] += len(selected_trees - pos_semantic_trees)
                fns[th] += len(pos_semantic_trees - selected_trees)
            print("-" * 80)

    prs = {}
    for th in thresholds:
        # max(1, ...) guards against division by zero when nothing was
        # selected (precision) or nothing was annotated (recall).
        p = float(tps[th]) / max(1, tps[th] + fps[th])
        r = float(tps[th]) / max(1, tps[th] + fns[th])
        prs[th] = p, r
    return prs
def get_semantic_forest(self, syntax_parse):
    """Build the SemanticForest for one syntax parse.

    Tag rules and unary rules are generated from ``syntax_parse``; the tag
    rules are then filtered with ``filter_tag_rules`` (using
    ``self.unary_model`` and the constant 0.9), the unary rules are filtered
    against the surviving tag rules, and binary rules are generated from the
    surviving tag rules.

    Args:
        syntax_parse: Parse handed to ``self.generate_tag_rules``.

    Returns:
        ``SemanticForest`` built from the filtered tag rules, the filtered
        unary rules and the binary rules.
    """
    candidate_tags = self.generate_tag_rules(syntax_parse)
    candidate_unaries = self.generate_unary_rules(candidate_tags)

    # Unary rules are generated from the unfiltered tags and pruned afterwards.
    kept_tags = filter_tag_rules(self.unary_model, candidate_tags, candidate_unaries, 0.9)
    kept_unaries = filter_unary_rules(kept_tags, candidate_unaries)

    binaries = self.generate_binary_rules(kept_tags)
    return SemanticForest(kept_tags, kept_unaries, binaries)
def evaluate_rule_model(combined_model, syntax_parses, annotations, thresholds):
    """Collect positive/negative rules and trees and score each sub-model.

    For every sentence, rules and boolean semantic trees derived from the
    annotations are treated as positives and the remaining generated ones as
    negatives.  The pooled examples are handed to the ``get_prs`` method of
    the unary, core, is and cc models and to ``combined_model.get_tree_prs``.
    Tree scores are printed to stdout along the way.

    Args:
        combined_model: Object providing ``tag_model``, ``unary_model``,
            ``core_model``, ``is_model`` and ``cc_model`` attributes plus
            ``generate_unary_rules``, ``generate_binary_rules``,
            ``get_tree_score`` and ``get_tree_prs``.
        syntax_parses: Mapping ``pk -> {number: syntax_parse}``.
        annotations: Nested mapping; ``annotations[pk][number].values()``
            yields the annotations of one sentence.
        thresholds: Passed through unchanged to every ``get_prs`` /
            ``get_tree_prs`` call.

    Returns:
        Tuple ``(unary_prs, core_prs, is_prs, cc_prs, core_tree_prs)`` holding
        the values returned by the respective ``get_prs`` / ``get_tree_prs``
        calls.
    """
    all_pos_unary_rules = []
    all_pos_core_rules = []
    all_pos_is_rules = []
    all_pos_cc_rules = []
    all_neg_unary_rules = []
    all_neg_core_rules = []
    all_neg_is_rules = []
    all_neg_cc_rules = []
    all_pos_bool_semantic_trees = []
    all_neg_bool_semantic_trees = []

    # ``items()`` and single-argument ``print(...)`` behave the same on
    # Python 2 and Python 3, so the function runs unchanged on both.
    for pk, local_syntax_parses in syntax_parses.items():
        print("\n\n\n")
        print(pk)
        for number, syntax_parse in local_syntax_parses.items():
            # Positive examples: everything reachable from the annotations.
            pos_semantic_trees = set(
                annotation_to_semantic_tree(syntax_parse, annotation)
                for annotation in annotations[pk][number].values())
            pos_unary_rules = set(itertools.chain.from_iterable(
                semantic_tree.get_unary_rules() for semantic_tree in pos_semantic_trees))
            pos_binary_rules = set(itertools.chain.from_iterable(
                semantic_tree.get_binary_rules() for semantic_tree in pos_semantic_trees))

            # Generated examples: everything the models propose.
            tag_rules = combined_model.tag_model.generate_tag_rules(syntax_parse)
            # tag_rules = set(itertools.chain(*[t.get_tag_rules() for t in pos_semantic_trees]))
            unary_rules = combined_model.generate_unary_rules(tag_rules)
            tag_rules = filter_tag_rules(combined_model.unary_model, tag_rules, unary_rules, 0.9)
            binary_rules = combined_model.generate_binary_rules(tag_rules)
            core_rules = combined_model.core_model.generate_binary_rules(tag_rules)
            is_rules = combined_model.is_model.generate_binary_rules(tag_rules)
            cc_rules = combined_model.cc_model.generate_binary_rules(tag_rules)

            pos_core_rules, pos_is_rules, annotated_cc_rules = split_binary_rules(pos_binary_rules)
            span_pos_cc_rules = set(r.to_span_rule() for r in annotated_cc_rules)

            negative_unary_rules = unary_rules - pos_unary_rules
            neg_core_rules = core_rules - pos_core_rules
            neg_is_rules = is_rules - pos_is_rules

            # CC rules are compared through their span rule instead of by
            # direct set difference (neg_cc_rules = cc_rules - pos_cc_rules):
            # a generated rule is positive when its span rule matches the
            # span rule of an annotated CC rule.
            neg_cc_rules = set()
            pos_cc_rules = set()
            for r in cc_rules:
                if r.to_span_rule() in span_pos_cc_rules:
                    pos_cc_rules.add(r)
                else:
                    neg_cc_rules.add(r)

            all_pos_unary_rules.extend(pos_unary_rules)
            all_pos_core_rules.extend(pos_core_rules)
            all_pos_is_rules.extend(pos_is_rules)
            all_pos_cc_rules.extend(pos_cc_rules)
            all_neg_unary_rules.extend(negative_unary_rules)
            all_neg_core_rules.extend(neg_core_rules)
            all_neg_is_rules.extend(neg_is_rules)
            all_neg_cc_rules.extend(neg_cc_rules)

            # Tree-level examples: annotated trees whose content signature id
            # is not 'CC' versus the other "truth"/"is" trees of the forest.
            pos_bool_semantic_trees = set(
                t for t in pos_semantic_trees if t.content.signature.id != 'CC')
            semantic_forest = SemanticForest(tag_rules, unary_rules, binary_rules)
            bool_semantic_trees = semantic_forest.get_semantic_trees_by_type("truth").union(
                semantic_forest.get_semantic_trees_by_type('is'))
            neg_bool_semantic_trees = bool_semantic_trees - pos_bool_semantic_trees
            all_pos_bool_semantic_trees.extend(pos_bool_semantic_trees)
            all_neg_bool_semantic_trees.extend(neg_bool_semantic_trees)

            for pst in pos_bool_semantic_trees:
                print("pos: %s %s" % (combined_model.get_tree_score(pst), pst))
            print("")
            for nst in neg_bool_semantic_trees:
                score = combined_model.get_tree_score(nst)
                if score > 0:
                    # Reuse the score computed above instead of scoring the
                    # same tree a second time.
                    print("neg: %s %s" % (score, nst))

    unary_prs = combined_model.unary_model.get_prs(all_pos_unary_rules, all_neg_unary_rules, thresholds)
    core_prs = combined_model.core_model.get_prs(all_pos_core_rules, all_neg_core_rules, thresholds)
    is_prs = combined_model.is_model.get_prs(all_pos_is_rules, all_neg_is_rules, thresholds)
    cc_prs = combined_model.cc_model.get_prs(all_pos_cc_rules, all_neg_cc_rules, thresholds)
    core_tree_prs = combined_model.get_tree_prs(
        all_pos_bool_semantic_trees, all_neg_bool_semantic_trees, thresholds)

    return unary_prs, core_prs, is_prs, cc_prs, core_tree_prs