Exemplo n.º 1
0
def eval_frame_linking(docs):
	"""Print how often each frame's predicted i / c / o spans hit the gold spans.

	For every frame in every doc, the span in ``doc.labels['BERT_ev']`` whose
	offsets equal ``frame.ev`` is looked up (asserted to be exactly one).
	Three hit flags are then recorded for that evidence span:

	* i -- ``ev.pred_i`` overlaps some span in ``GOLD_<frame.i.label>``;
	* c -- ``ev.pred_c`` overlaps some span in ``GOLD_<frame.c.label>``;
	* o -- at least one span in ``ev.pred_os`` overlaps ``GOLD_<frame.o.label>``.

	The mean of each flag list is printed with two decimals. When there are no
	frames at all, ``np.mean`` of an empty list yields nan.
	"""
	i_hits, c_hits, o_hits = [], [], []
	for doc in docs:
		for frame in doc.frames:
			candidates = [
				span for span in doc.labels['BERT_ev']
				if span.i == frame.ev.i and span.f == frame.ev.f
			]
			assert len(candidates) == 1
			ev_span = candidates[0]

			gold_i = doc.labels['GOLD_{}'.format(frame.i.label)]
			i_hits.append(any(utils.s_overlaps(ev_span.pred_i, gold_i)))

			gold_c = doc.labels['GOLD_{}'.format(frame.c.label)]
			c_hits.append(any(utils.s_overlaps(ev_span.pred_c, gold_c)))

			gold_o = doc.labels['GOLD_{}'.format(frame.o.label)]
			# a hit as soon as any one predicted O span overlaps a gold O span
			o_hits.append(any(
				any(utils.s_overlaps(pred_o, gold_o)) for pred_o in ev_span.pred_os
			))

	report = (('i: {:.2f}', i_hits), ('c: {:.2f}', c_hits), ('o: {:.2f}', o_hits))
	for template, hits in report:
		print(template.format(np.mean(hits)))
Exemplo n.º 2
0
def ner_entity_score(docs, true_prefix, pred_prefix, e_type):
	"""Entity-level precision, recall and F1 for entities of type ``e_type``.

	Spans are fetched per doc with ``get_doc_spans(doc, prefix)`` for both
	``true_prefix`` (gold) and ``pred_prefix`` (predicted).

	* TP: an entity of type ``e_type`` with at least one mention overlapping a
	  predicted span.
	* FN: an entity of type ``e_type`` with no mention overlapping a predicted
	  span.
	* FP: per doc, the number of distinct predicted span texts for which no
	  span with that text overlaps a true span. Predicted spans are treated as
	  the same "entity" only when their text is identical, which is
	  pessimistic (it can overcount FPs).

	FPs are computed from all spans returned for ``pred_prefix``; no
	``e_type`` filtering is applied to them here.

	Returns:
		(precision, recall, f1). Any ratio whose denominator is zero is
		reported as 0.0 instead of raising ZeroDivisionError.
	"""
	tp = 0
	fn = 0
	fp = 0
	for doc in docs:
		pred_spans = get_doc_spans(doc, pred_prefix)
		true_spans = get_doc_spans(doc, true_prefix)
		for e in doc.entities:
			if e.type != e_type:
				continue
			# an entity is a TP if any one of its mentions is tagged, a FN otherwise
			if any(utils.s_overlaps(m, pred_spans) for m in e.mentions):
				tp += 1
			else:
				fn += 1
		# distinct predicted "entities" (by text), minus those with any span
		# overlapping a true mention
		pred_texts = {s.text for s in pred_spans}
		matched_texts = {s.text for s in pred_spans if utils.s_overlaps(s, true_spans)}
		fp += len(pred_texts - matched_texts)
	# guard every division: no predictions / no gold entities / no hits
	p = tp / (tp + fp) if tp + fp else 0.0
	r = tp / (tp + fn) if tp + fn else 0.0
	f1 = 2*(p * r)/(p + r) if p + r else 0.0
	return p, r, f1
Exemplo n.º 3
0
def write_o_ev_data_pipeline(docs, fdir):
    """Write one TSV row per (evidence span, overlapping predicted O span).

    The output file is ``<fdir>/<group>.tsv`` where ``group`` is taken from
    ``docs[0].group``; it is opened with mode 'w' (overwritten). For every
    span in ``doc.labels['BERT_ev']`` and every span of
    ``doc.labels['NER_o']`` that overlaps it, the row is::

        0 <TAB> doc.id <TAB> o_span.i <TAB> o_span.f <TAB> o_text <TAB> ev_text

    with both texts passed through ``utils.clean_str``. The constant '0'
    first column is presumably a placeholder label for the downstream
    classifier -- TODO confirm.

    An empty ``docs`` writes nothing (there is no group to name a file after).

    Raises:
        AssertionError: if a doc's group is neither 'test' nor 'testtest'.
    """
    if not docs:
        return
    # `with` guarantees the rows are flushed and the handle closed, even if
    # the assertion below fires part-way through.
    with open('{}/{}.tsv'.format(fdir, docs[0].group), 'w') as fout:
        for doc in docs:
            assert doc.group == 'test' or doc.group == 'testtest'
            for ev_span in doc.labels['BERT_ev']:
                for o_span in utils.s_overlaps(ev_span, doc.labels['NER_o']):
                    fout.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                        '0', doc.id, o_span.i, o_span.f,
                        utils.clean_str(o_span.text),
                        utils.clean_str(ev_span.text)))
Exemplo n.º 4
0
def ner_span_score(docs, true_prefix, pred_prefix):
	"""Span-level precision, recall and F1 using overlap matching.

	Spans come from ``get_doc_spans(doc, prefix)`` for the gold
	(``true_prefix``) and predicted (``pred_prefix``) labels of each doc.

	* TP: a predicted span overlapping at least one true span.
	* FP: a predicted span overlapping no true span.
	* FN: a true span overlapping no predicted span.

	TP is counted over predicted spans, so one true span hit by several
	predictions contributes several TPs.

	Returns:
		(precision, recall, f1). Any ratio whose denominator is zero is
		reported as 0.0 instead of raising ZeroDivisionError.
	"""
	tp = 0
	fp = 0
	fn = 0
	for doc in docs:
		pred_spans = get_doc_spans(doc, pred_prefix)
		true_spans = get_doc_spans(doc, true_prefix)
		for pred in pred_spans:
			if utils.s_overlaps(pred, true_spans):
				tp += 1
			else:
				fp += 1
		# matched true spans are already represented by the TPs above;
		# only unmatched ones are false negatives
		fn += sum(1 for true in true_spans if not utils.s_overlaps(true, pred_spans))
	# guard every division: no predictions / no gold spans / no hits
	p = tp / (tp + fp) if tp + fp else 0.0
	r = tp / (tp + fn) if tp + fn else 0.0
	f1 = 2*(p * r)/(p + r) if p + r else 0.0
	return p, r, f1
Exemplo n.º 5
0
def write_o_ev_data(docs, fdir, add_i=False):
    """Write one ``<fdir>/<group>.tsv`` file per doc group, one row per frame.

    Docs are grouped by ``doc.group``; each group's file is opened with mode
    'w' (overwritten). For each frame the row is::

        frame.label + 1 <TAB> doc.id <TAB> o_text <TAB> ev_text

    where ``o_text`` is ``utils.clean_str(frame.o.text)``, prefixed with
    ``"<cleaned frame.i.text> effect on "`` when ``add_i`` is true, and
    ``ev_text`` is the cleaned document text from the start of the first to
    the end of the last sentence in ``doc.sents`` overlapping ``frame.ev``.
    """
    group_docs = defaultdict(list)
    for doc in docs:
        group_docs[doc.group].append(doc)
    for group, doc_list in group_docs.items():
        # `with` ensures each group's file is flushed and closed before the
        # next one is opened.
        with open('{}/{}.tsv'.format(fdir, group), 'w') as fout:
            for doc in doc_list:
                for frame in doc.frames:
                    # NOTE(review): assumes frame.ev overlaps at least one
                    # sentence; otherwise sents[0] raises IndexError.
                    sents = utils.s_overlaps(frame.ev, doc.sents)
                    ev_text = utils.clean_str(doc.text[sents[0].i:sents[-1].f])
                    o_text = utils.clean_str(frame.o.text)
                    if add_i:
                        o_text = '{} effect on {}'.format(
                            utils.clean_str(frame.i.text), o_text)
                    fout.write('{}\t{}\t{}\t{}\n'.format(frame.label + 1, doc.id,
                                                         o_text, ev_text))
Exemplo n.º 6
0
def add_ic_ev_output(docs, group, fdir='../models/sentence_classifier/data/i_c_intro'):
	"""Attach sentence-classifier I/C predictions to the docs' 'BERT_ev' spans.

	Reads ``<fdir>/<group>.tsv`` (model input; each line has 8 tab-separated
	columns: unused, pmid, ev_i, ev_f, i_i, i_f, i_text, context) and
	``<fdir>/results/<group>_results.tsv`` (one line of tab-separated float
	class probabilities per input line). Candidates are grouped by pmid and
	by evidence offsets (ev_i, ev_f). Then, for each doc and each group of
	candidates keyed by ``doc.id``, the matching span in
	``doc.labels['BERT_ev']`` gets:

	* ``pred_i``  -- candidate with the highest probability at index 2;
	* ``pred_c``  -- candidate with the highest probability at index 1;
	* ``pred_os`` -- ``utils.s_overlaps(sent, doc.labels['NER_o'])``.

	A mismatch message is printed when a chosen candidate's text differs from
	the cleaned document text at its offsets.

	Raises:
		AssertionError: if the two files have different line counts, or if
			an (ev_i, ev_f) pair does not match exactly one 'BERT_ev' span.
	"""
	model_input = '{}/{}.tsv'.format(fdir, group)
	model_output = '{}/results/{}_results.tsv'.format(fdir, group)
	with open(model_input) as fin:
		inputs = [line.strip().split('\t') for line in fin]
	with open(model_output) as fin:
		outputs = [list(map(float, line.strip().split('\t'))) for line in fin]
	assert len(inputs) == len(outputs)

	# pmid -> (ev_i, ev_f) -> list of candidate results
	pmid_ev_map = defaultdict(lambda: defaultdict(list))
	for (_, pmid, ev_i, ev_f, i_i, i_f, i_text, _context), class_probs in zip(inputs, outputs):
		pmid_ev_map[pmid][(int(ev_i), int(ev_f))].append({
			'class_probs': class_probs,  # already a fresh list of floats
			'idx_i': int(i_i),
			'idx_f': int(i_f),
			'text': i_text,
		})

	for doc in docs:
		# .get avoids inserting an empty entry for docs with no candidates
		for (ev_i, ev_f), results in pmid_ev_map.get(doc.id, {}).items():
			sents = [s for s in doc.labels['BERT_ev'] if s.i == ev_i and s.f == ev_f]
			assert len(sents) == 1
			sent = sents[0]
			best_i = max(results, key=lambda r: r['class_probs'][2])
			best_c = max(results, key=lambda r: r['class_probs'][1])
			sent.pred_i = classes.Span(best_i['idx_i'], best_i['idx_f'], best_i['text'])
			sent.pred_c = classes.Span(best_c['idx_i'], best_c['idx_f'], best_c['text'])
			# explicit comparisons instead of `assert`, so the check survives `python -O`
			_report_span_mismatch(doc, sent.pred_i, 'Mismatch for I when loading IC results...')
			_report_span_mismatch(doc, sent.pred_c, 'Mismatch for C when loading IC results...')
			sent.pred_os = utils.s_overlaps(sent, doc.labels['NER_o'])


def _report_span_mismatch(doc, span, message):
	"""Print *message* and both texts if ``span.text`` differs from the cleaned doc text at its offsets."""
	doc_text = utils.clean_str(doc.text[span.i:span.f])
	if span.text != doc_text:
		print(message)
		print(span.text)
		print(doc_text)
Exemplo n.º 7
0
def get_overlapping_entities(doc, s):
	"""Return the names of entities in ``doc.entities`` with a mention overlapping span ``s``.

	Names are returned in ``doc.entities`` order; an entity is included when
	``utils.s_overlaps(s, entity.mentions)`` yields any truthy element.
	"""
	names = []
	for entity in doc.entities:
		if any(utils.s_overlaps(s, entity.mentions)):
			names.append(entity.name)
	return names
Exemplo n.º 8
0
def print_first_instance(doc, s):
	"""Debug helper: print the sentences of ``doc`` that overlap span ``s``.

	NOTE(review): despite the name, this prints the whole result of
	``utils.s_overlaps(s, doc.sents)`` (presumably every overlapping
	sentence), not just the first one.
	"""
	overlapping = utils.s_overlaps(s, doc.sents)
	print(overlapping)