def loadFile(self, lines):
    sent = ''
    d = {}
    indsForQuestions = defaultdict(lambda: set())

    for line in lines.split('\n'):
        line = line.strip()
        if not line:
            continue
        data = line.split('\t')
        if len(data) == 1:
            if sent:
                for ex in d[sent]:
                    ex.indsForQuestions = dict(indsForQuestions)
            sent = line
            d[sent] = []
            indsForQuestions = defaultdict(lambda: set())
        else:
            pred = data[0]
            pred_index = data[1]
            cur = Extraction((pred, all_index(sent, pred, matchCase=False)),
                             sent,
                             confidence=1.0,
                             question_dist=self.question_dist)
            for q, a in zip(data[2::2], data[3::2]):
                indices = all_index(sent, a, matchCase=False)
                cur.addArg((a, indices), q)
                indsForQuestions[q] = indsForQuestions[q].union(indices)

            if sent:
                if cur.noPronounArgs():
                    d[sent].append(cur)

    return d
def read(self, fn):
    # d = defaultdict(lambda: [])
    d = dict()
    with open(fn) as fin:
        for line_ind, line in enumerate(fin):
            # print(line)
            data = line.strip().split('\t')
            # print(data)
            text, rel = data[:2]
            args = data[2:]
            confidence = 1
            curExtraction = Extraction(pred=rel.strip(),
                                       head_pred_index=None,
                                       sent=text.strip(),
                                       confidence=float(confidence),
                                       index=line_ind)
            for arg in args:
                if "C: " in arg:
                    continue
                curExtraction.addArg(arg.strip())
            if text.strip() not in d:
                d[text.strip()] = []
            d[text.strip()].append(curExtraction)
    self.oie = d
def read(self, fn): """ Read a tabbed format line Each line consists of: sent, prob, pred, arg1, arg2, ... """ d = {} ex_index = 0 with open(fn) as fin: for line in fin: if not line.strip(): continue data = line.strip().split('\t') text, confidence, rel = data[:3] curExtraction = Extraction( pred=rel, head_pred_index=None, sent=text, confidence=float(confidence), question_dist= "./question_distributions/dist_wh_sbj_obj1.json", index=ex_index) ex_index += 1 for arg in data[3:]: curExtraction.addArg(arg) d[text] = d.get(text, []) + [curExtraction] self.oie = d
def read(self, fn): """ Read a tabbed format line Each line consists of: sent, prob, pred, arg1, arg2, ... """ d = {} d_list = [] ex_index = 0 with open(fn) as fin: for line in fin: if not line.strip(): continue data = line.strip().split('\t') text, confidence, rel = data[:3] rel = rel.rsplit( '##') # split from right to avoid # symbol in tokens pred_pos = int(rel[1]) if len(rel) == 2 else None head_pred_index = pred_pos # TODO: head_pred_index is not necessarily the first predicate index # rel is a tuple, where the first element is str and the second element is a list of index rel = (rel[0], [pred_pos + i for i, w in enumerate(rel[0].split(' ')) ]) if len(rel) == 2 else rel[0] curExtraction = Extraction( pred=rel, pred_pos=pred_pos, head_pred_index=head_pred_index, sent=text, confidence=float(confidence), question_dist= "./question_distributions/dist_wh_sbj_obj1.json", index=ex_index, raw=line.strip()) ex_index += 1 for arg in data[3:]: arg = arg.rsplit( '##') # split from right to avoid # symbol in tokens arg_pos = int(arg[1]) if len(arg) == 2 else None # arg is a tuple, where the first element is str and the second element is a list of index arg = (arg[0], [ arg_pos + i for i, w in enumerate(arg[0].split(' ')) ]) if len(arg) == 2 else arg[0] curExtraction.addArg(arg, arg_pos) if text not in d: d[text] = [] d_list.append([]) d[text].append(curExtraction) d_list[-1].append(curExtraction) self.oie = d self.oie_list = d_list
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            arg1, rel, arg2 = data[2:5]
            confidence = data[11]
            text = data[12]
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
            curExtraction.addArg(arg1)
            curExtraction.addArg(arg2)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
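# Hedged illustration of the column layout the reader above relies on: it only
# touches columns 2-4 (arg1, rel, arg2), 11 (confidence) and 12 (sentence). The
# row below is fabricated, with empty placeholders for the unused columns.
row = ['doc1', '1', 'John', 'was born in', 'Paris', '', '', '', '', '', '',
       '0.93', 'John was born in Paris .']
arg1, rel, arg2 = row[2:5]
confidence, text = row[11], row[12]
print(arg1, '|', rel, '|', arg2, '|', confidence, '|', text)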
def read(self, fn):
    d = {}
    with open(fn) as fin:
        fin.readline()  # remove header
        for line in fin:
            data = line.strip().split('\t')
            confidence, arg1, rel, arg2, enabler, attribution, text = data[:7]
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
            curExtraction.addArg(arg1)
            curExtraction.addArg(arg2)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
def read(self, fn):
    d = defaultdict(lambda: [])
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            text, rel = data[:2]
            args = data[2:]
            confidence = 1
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
            for arg in args:
                curExtraction.addArg(arg)
            d[text].append(curExtraction)
    self.oie = d
def process_extraction(extraction, sentence, score):
    # rel, arg1, arg2, loc, time = [], [], [], [], []
    rel, arg1, arg2, loc_time, args = [], [], [], [], []
    tag_mode = 'none'
    rel_case = 0
    for i, token in enumerate(sentence):
        if '[unused' in token:
            if extraction[i].item() == 2:
                rel_case = int(re.search('\[unused(.*)\]', token).group(1))
            continue
        if extraction[i] == 1:
            arg1.append(token)
        if extraction[i] == 2:
            rel.append(token)
        if extraction[i] == 3:
            arg2.append(token)
        if extraction[i] == 4:
            loc_time.append(token)

    rel = ' '.join(rel).strip()
    if rel_case == 1:
        rel = 'is ' + rel
    elif rel_case == 2:
        rel = 'is ' + rel + ' of'
    elif rel_case == 3:
        rel = 'is ' + rel + ' from'

    arg1 = ' '.join(arg1).strip()
    arg2 = ' '.join(arg2).strip()
    args = ' '.join(args).strip()
    loc_time = ' '.join(loc_time).strip()
    if not self.hparams.no_lt:
        arg2 = (arg2 + ' ' + loc_time + ' ' + args).strip()
    sentence_str = ' '.join(sentence).strip()

    extraction = Extraction(pred=rel,
                            head_pred_index=None,
                            sent=sentence_str,
                            confidence=score,
                            index=0)
    extraction.addArg(arg1)
    extraction.addArg(arg2)
    return extraction
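# Minimal sketch of the tag scheme assumed by process_extraction above
# (1 = arg1, 2 = relation, 3 = arg2, 4 = location/time), applied to a made-up
# sentence and label sequence without the surrounding model machinery.
tokens = ['John', 'was', 'born', 'in', 'Paris']
labels = [1, 2, 2, 2, 3]
arg1 = ' '.join(t for t, l in zip(tokens, labels) if l == 1)
rel = ' '.join(t for t, l in zip(tokens, labels) if l == 2)
arg2 = ' '.join(t for t, l in zip(tokens, labels) if l == 3)
print(arg1, '|', rel, '|', arg2)  # John | was born in | Paris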
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            if not line.strip():
                continue
            data = line.strip().split('\t')
            confidence, text, rel = data[:3]
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
            for arg in data[4::2]:
                curExtraction.addArg(arg)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
    self.normalizeConfidence()
def _process_allenlp_format(self, lines):
    assert self._all_predictions == {}
    for line in lines:
        extr = line.split('\t')
        sentence = extr[0]
        confidence = float(extr[2])
        arg1 = re.findall("<arg1>.*</arg1>", extr[1])[0].strip('<arg1>').strip('</arg1>').strip()
        rel = re.findall("<rel>.*</rel>", extr[1])[0].strip('<rel>').strip('</rel>').strip()
        arg2 = re.findall("<arg2>.*</arg2>", extr[1])[0].strip('<arg2>').strip('</arg2>').strip()

        extraction = Extraction(pred=rel,
                                head_pred_index=None,
                                sent=sentence,
                                confidence=confidence,
                                index=0)
        extraction.addArg(arg1)
        extraction.addArg(arg2)
        if sentence not in self._all_predictions:
            self._all_predictions[sentence] = []
        # append rather than overwrite, so a sentence keeps all of its extractions
        self._all_predictions[sentence].append(extraction)
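# Hedged, self-contained sketch of the tag extraction used above, run on a
# fabricated extraction string in the <arg1>/<rel>/<arg2> tagged format the
# reader expects.
import re

ext = '<arg1> John </arg1> <rel> was born in </rel> <arg2> Paris </arg2>'
arg1 = re.findall("<arg1>.*</arg1>", ext)[0].strip('<arg1>').strip('</arg1>').strip()
rel = re.findall("<rel>.*</rel>", ext)[0].strip('<rel>').strip('</rel>').strip()
arg2 = re.findall("<arg2>.*</arg2>", ext)[0].strip('<arg2>').strip('</arg2>').strip()
print(arg1, '|', rel, '|', arg2)  # John | was born in | Paris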
def read(self, fn):
    d = defaultdict(lambda: [])
    with open(fn, encoding='utf-8') as fin:
        for line_ind, line in enumerate(fin):
            data = line.strip().split('\t')
            text, rel = data[:2]
            args = data[2:]
            confidence = 1
            curExtraction = Extraction(pred=rel,
                                       head_pred_index=None,
                                       sent=text,
                                       confidence=float(confidence),
                                       index=line_ind)
            for arg in args:
                curExtraction.addArg(arg)
            d[text].append(curExtraction)
    self.oie = d
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            confidence = data[0]
            if not all(data[2:5]):
                continue
            arg1, rel, arg2 = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:5]]
            text = data[5]
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
            curExtraction.addArg(arg1)
            curExtraction.addArg(arg2)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
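# Hedged example of the field parsing used above: the reader assumes fields of
# the rough shape "Wrapper(text,List(...))" and keeps the text between the first
# '(' and ',List('. The wrapper name and field value below are fabricated.
field = 'SimpleArgument(Barack Obama,List([0, 11)))'
text = field[field.index('(') + 1:field.index(',List(')]
print(text)  # Barack Obama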
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            if len(data) == 1:
                text = data[0]
            elif len(data) == 5:
                arg1, rel, arg2 = [s[1:-1] for s in data[1:4]]
                confidence = data[4]
                curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
    self.normalizeConfidence()
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            # confidence = data[0]
            # if not all(data[1]):
            #     continue
            text = data[0]
            rel = data[1]
            args = [s[s.index('::') + 2:] for s in data[2:]]
            # args = data[4].strip().split(');')
            ar1 = True
            arg1 = ''
            args = []
            for s in data[2:]:
                if s[:s.index('::')] == 'V':
                    ar1 = False
                    continue
                if ar1:
                    arg1 = arg1 + ' ' + s[s.index('::') + 2:]
                else:
                    args.append(s[s.index('::') + 2:])
            # print arg1, rel, args
            curExtraction = Extraction(pred=rel, sent=text, confidence=float(0.1))
            curExtraction.addArg(arg1)
            for arg in args:
                curExtraction.addArg(arg)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
def read(self, fn):
    d = {}
    with open(fn) as fin:
        for line in fin:
            data = line.strip().split('\t')
            confidence = data[0]
            if not all(data[2:5]):
                continue
            arg1, rel = [s[s.index('(') + 1:s.index(',List(')] for s in data[2:4]]
            # args = data[4].strip().split(');')
            # print arg2s
            args = [s[s.index('(') + 1:s.index(',List(')] for s in data[4].strip().split(');')]
            # if arg1 == "the younger La Flesche":
            #     print len(args)
            text = data[5]
            if data[1]:
                # print arg1, rel
                s = data[1]
                if not (arg1 + ' ' + rel).startswith(s[s.index('(') + 1:s.index(',List(')]):
                    # print "##########Not adding context"
                    arg1 = s[s.index('(') + 1:s.index(',List(')] + ' ' + arg1
                # print arg1 + rel, ",,,,, ", s[s.index('(') + 1:s.index(',List(')]
            # curExtraction = Extraction(pred = rel, sent = text, confidence = float(confidence))
            curExtraction = Extraction(pred=rel, head_pred_index=-1, sent=text, confidence=float(confidence))
            curExtraction.addArg(arg1)
            for arg in args:
                curExtraction.addArg(arg)
            d[text] = d.get(text, []) + [curExtraction]
    self.oie = d
def get_predictions(fp):
    old_sentence = ''
    all_predictions = dict()
    for line in open(fp, 'r'):
        line = line.strip('\n')
        sentence, extraction, confidence = line.split('\t')
        if old_sentence == '':
            old_sentence = sentence
            all_predictions[sentence] = []
        if old_sentence != sentence:
            all_predictions[sentence] = []
            old_sentence = sentence
        try:
            arg1 = line[line.index('<arg1>') + 6:line.index('</arg1>')]
        except:
            arg1 = ""
        try:
            rel = line[line.index('<rel>') + 5:line.index('</rel>')]
        except:
            rel = ""
        try:
            arg2 = line[line.index('<arg2>') + 6:line.index('</arg2>')]
        except:
            arg2 = ""
        if not (arg1 or arg2 or rel):
            continue
        curExtraction = Extraction(pred=rel,
                                   head_pred_index=-1,
                                   sent=sentence,
                                   confidence=float(confidence))
        curExtraction.addArg(arg1)
        curExtraction.addArg(arg2)
        all_predictions[sentence].append(curExtraction)
    return all_predictions
def loadFile(self, lines):
    sent = ''
    d = {}
    indsForQuestions = defaultdict(lambda: set())

    for line in lines.split('\n'):
        line = line.strip()
        if not line:
            continue
        data = line.split('\t')
        if len(data) == 1:
            if sent:
                for ex in d[sent]:
                    ex.indsForQuestions = dict(indsForQuestions)
            sent = line
            d[sent] = []
            indsForQuestions = defaultdict(lambda: set())
        else:
            pred = self.preproc(data[0])
            pred_indices = map(int, eval(data[1]))
            head_pred_index = int(data[2])
            cur = Extraction((pred, [pred_indices]),
                             head_pred_index,
                             sent,
                             confidence=1.0)
            for q, a in zip(data[3::2], data[4::2]):
                preproc_arg = self.preproc(a)
                if not preproc_arg:
                    logging.warn("Argument reduced to None: {}".format(a))
                indices = fuzzy_match_phrase(preproc_arg.split(" "), sent.split(" "))
                cur.addArg((preproc_arg, indices), q)
                indsForQuestions[q] = indsForQuestions[q].union(flatten(indices))

            if sent:
                if cur.noPronounArgs():
                    cur.resolveAmbiguity()
                    d[sent].append(cur)

    return d
def read(self):
    ### What to do about generalized questions that are not yet in this distribution set? ###
    ### Use analyze.py ###
    question_dist = dict([(q, dict([(int(loc), cnt) for (loc, cnt) in dist.iteritems()]))
                          for (q, dist) in json.load(open(self.dist_file)).iteritems()]) \
        if self.dist_file \
        else {}

    ## pull sentence ##
    ## pull predicate ##
    ## pull qa pairs with 5/6 or more validations ##
    ## possibly preprocess at this step ##
    # load json lines data into list
    qa_path = self.qa_path
    data = []
    with codecs.open(qa_path, 'r', encoding='utf8') as f:
        for line in f:
            data.append(json.loads(line))

    f_out = open(self.output_file, "w")
    jsonl_out = open('science_eval_sent.jsonl', "w")
    eval_out = open('science_eval.oie', "w")
    verb_types = []

    # parse qa data
    for item in data:
        # for item in data[(len(data)-100):(len(data) - 1)]:
        sent_id = item["sentenceId"].encode('utf-8')
        # remove science
        if sent_id.split(':')[0] != 'TQA':
            continue
        sentence_tokens = item["sentenceTokens"]
        sentence = ' '.join(sentence_tokens)
        sentence = sentence.encode('utf-8')
        if output_eval:
            jsonl_out.write("{" + '"' + "sentence" + '"' + ": " + '"' + sentence + '"' + "}" + '\n')
        for _, verb_entry in item["verbEntries"].items():
            verb_index = verb_entry["verbIndex"]
            verb_inflected_forms = verb_entry["verbInflectedForms"]
            base_pred = sentence_tokens[verb_index]
            surfacePred = base_pred
            answer_list = []
            questions = []
            for _, question_label in verb_entry["questionLabels"].items():
                # print(question_label["answerJudgments"])
                answers = len(question_label["answerJudgments"])
                valid_answers = len([ans for ans in question_label["answerJudgments"]
                                     if ans["isValid"]])
                if valid_answers / (answers * 1.0) < self.min_correct:
                    # do not log this question set
                    continue
                q_string = question_label['questionString']
                ans_spans = []
                for ans in question_label["answerJudgments"]:
                    if ans["isValid"]:
                        for span in ans["spans"]:
                            ans_spans.append(span)
                # add long/short flag here
                consolidated_spans = consolidate_answers(ans_spans, self.length)
                # look up answers in sentence tokens
                lookup_ans = lambda ans, sentence: ' '.join(sentence[ans[0]:ans[1]])
                consolidated_ans = map(lookup_ans, consolidated_spans,
                                       [sentence_tokens] * len(consolidated_spans))
                # here we can acquire the question slots
                wh = question_label["questionSlots"]["wh"].split()
                wh = '_'.join(wh)
                aux = question_label["questionSlots"]["aux"].split()
                aux = '_'.join(aux)
                subj = question_label["questionSlots"]["subj"].split()
                subj = '_'.join(subj)
                # iterate through and check verb types for len > 2
                verb_type = question_label['questionSlots']['verb']
                inflected_verb = verb_inflected_forms[verb_type.split()[-1]]
                if len(verb_type.split()) == 1:
                    trg = inflected_verb
                else:
                    trg = verb_type.split()[:-1]
                    trg.append(inflected_verb)
                    trg = "_".join(trg)
                obj1 = question_label["questionSlots"]["obj"].split()
                obj1 = '_'.join(obj1)
                pp = question_label["questionSlots"]["prep"].split()
                pp = '_'.join(pp)
                obj2 = question_label["questionSlots"]["obj2"].split()
                obj2 = '_'.join(obj2)
                slotted_q = " ".join((wh, aux, subj, trg, obj1, pp, obj2, "?"))
                curSurfacePred = augment_pred_with_question(base_pred, slotted_q)
                if len(curSurfacePred) > len(surfacePred):
                    surfacePred = curSurfacePred
                questions.append(slotted_q)
                answer_list.append(consolidated_ans)
                # print wh, subj, obj1
                # for ans in consolidated_spans:
                #     question_answer_pairs.append((slotted_q, ' '.join(sentence_tokens[ans[0]:ans[1]])))

            #### this needs to be more sophisticated
            ### for each predicate - create a list of qa pairs, s.t. each unique combination of questions and answers appears
            ### e.g. 2 questions each with 2 answers leads to four qa pairs ((q1,a1),(q2,a1)), ((q1,a1),(q2,a2)), etc.
            ### each one of these sets will lead to an extraction
            # now we have the augmented Pred with aux
            # might want to revisit this methodology
            # augment verb with aux
            # =============================================================================
            #             if aux in QA_SRL_AUX_MODIFIERS:
            #
            #                 if len(verb_type.split()) == 1:
            #                     verb = aux + " " + inflected_verb
            #
            #                 else:
            #                     # add the first modifier in verb type
            #                     # may need to revisit - in previous approach, it looks like only the surface verb and aux were sent
            #                     verb = aux + " " + verb_type.split()[0] + " " + inflected_verb
            #
            #             else:
            #                 if len(verb_type.split()) == 1:
            #                     verb = inflected_verb
            #
            #                 else:
            #                     verb = verb_type.split()[0] + " " + inflected_verb
            #
            # =============================================================================
            ## now we have sentence tokens, verb index, valid question, valid answer spans
            ## need question blanks for augment pred with question
            ### for each predicate - create a list of qa pairs, s.t. each unique combination of questions and answers appears
            ### e.g. 2 questions each with 2 answers leads to four qa pairs ((q1,a1),(q2,a1)), ((q1,a1),(q2,a2)), etc.
            ### each one of these sets will lead to an extraction
            ## noticing many instances where the rare answer doesn't make sense
            ## e.g. Clouds that form on the ground are called fog
            ## what about questions that encode a similar argument? e.g. what for and why
            ## These organisms need the oxygen plants release to get energy out of the food .
            # [(u'what _ _ needs something _ _ ?', u'organisms'), (u'why does something need something _ _ ?', u'to get energy out of the food'), (u'what does something need _ _ _ ?', u'oxygen'), (u'what does someone need something for _ ?', u'to get energy out of the food')]
            # need need organisms oxygen for to get energy out of the food to get energy out of the food
            # Considering the following edits - for each argument, only take the first question that appears for it
            # Considering the following edits - only consider an answer span if it appears by more than one annotator. Rare answers tend to be misleading.

            surfacePred = surfacePred.encode('utf-8')
            base_pred = base_pred.encode('utf-8')
            # pred_indices = all_index(sentence, base_pred, matchCase = False)
            augmented_pred_indices = fuzzy_match_phrase(surfacePred.split(" "),
                                                        sentence.split(" "))
            # print augmented_pred_indices
            if not augmented_pred_indices:
                # find equivalent of pred_index
                head_pred_index = [verb_index]
            else:
                head_pred_index = augmented_pred_indices[0]

            for ans_set in list(itertools.product(*answer_list)):
                cur = Extraction((surfacePred, [head_pred_index]),
                                 verb_index,
                                 sentence,
                                 confidence=1.0,
                                 question_dist=self.dist_file)
                # print 'Extraction', (surfacePred, [head_pred_index]), verb_index, sentence
                q_as = zip(questions, ans_set)
                if len(q_as) == 0:
                    continue
                for q_a in q_as:
                    q = q_a[0].encode('utf-8')
                    a = q_a[1].encode('utf-8')
                    preproc_arg = self.preproc(a)
                    if not preproc_arg:
                        logging.warn("Argument reduced to None: {}".format(a))
                    indices = fuzzy_match_phrase(preproc_arg.split(" "), sentence.split(" "))
                    # print 'q', q, 'preproc arg', preproc_arg, 'indices ', indices
                    cur.addArg((preproc_arg, indices), q)
                if cur.noPronounArgs():
                    # print 'arguments', (preproc_arg, indices), q
                    cur.resolveAmbiguity()
                    if self.write:
                        # print sentence
                        # print q_as
                        if self.sort:
                            cur.getSortedArgs()
                        # print(cur.conll(external_feats = [1,2]))
                        f_out.write(cur.conll(external_feats=[1, 2]))
                        f_out.write('\n')
                    ### now to get the ordering down
                    ### seems like now and from before, the arguments are in the order they appear in the qa file...
                    ### get sent and word ID
                    ### generating an output file for downstream evaluation on the OIE-2016
                    ### evaluation framework
                    if self.output_eval:
                        if self.sort:
                            cur.getSortedArgs()
                        eval_out.write(sentence + ' \t' + cur.__str__() + '\n')
                    self.extractions.append(cur)
def splitpredict(hparams, checkpoint_callback, meta_data_vocab, train_dataloader,
                 val_dataloader, test_dataloader, all_sentences):
    mapping, conj_word_mapping = {}, {}
    hparams.write_allennlp = True
    if hparams.split_fp == '':
        hparams.task = 'conj'
        hparams.checkpoint = hparams.conj_model
        hparams.model_str = 'bert-base-cased'
        hparams.mode = 'predict'
        model = predict(hparams, None, meta_data_vocab, None, None, test_dataloader,
                        all_sentences)
        conj_predictions = model.all_predictions_conj
        sentences_indices = model.all_sentence_indices_conj
        # conj_predictions = model.predictions
        # sentences_indices = model.all_sentence_indices
        assert len(conj_predictions) == len(sentences_indices)
        all_conj_words = model.all_conjunct_words_conj

        sentences, orig_sentences = [], []
        for i, sentences_str in enumerate(conj_predictions):
            list_sentences = sentences_str.strip('\n').split('\n')
            conj_words = all_conj_words[i]
            if len(list_sentences) == 1:
                orig_sentences.append(list_sentences[0] + ' [unused1] [unused2] [unused3]')
                mapping[list_sentences[0]] = list_sentences[0]
                conj_word_mapping[list_sentences[0]] = conj_words
                sentences.append(list_sentences[0] + ' [unused1] [unused2] [unused3]')
            elif len(list_sentences) > 1:
                orig_sentences.append(list_sentences[0] + ' [unused1] [unused2] [unused3]')
                conj_word_mapping[list_sentences[0]] = conj_words
                for sent in list_sentences[1:]:
                    mapping[sent] = list_sentences[0]
                    sentences.append(sent + ' [unused1] [unused2] [unused3]')
            else:
                assert False
        sentences.append('\n')

        count = 0
        for sentence_indices in sentences_indices:
            if len(sentence_indices) == 0:
                count += 1
            else:
                count += len(sentence_indices)
        assert count == len(sentences) - 1
    else:
        with open(hparams.predict_fp, 'r') as f:
            lines = f.read()
        lines = lines.replace("\\", "")

        sentences = []
        orig_sentences = []
        extra_str = " [unused1] [unused2] [unused3]"
        for line in lines.split('\n\n'):
            if len(line) > 0:
                list_sentences = line.strip().split('\n')
                if len(list_sentences) == 1:
                    mapping[list_sentences[0]] = list_sentences[0]
                    sentences.append(list_sentences[0] + extra_str)
                    orig_sentences.append(list_sentences[0] + extra_str)
                elif len(list_sentences) > 1:
                    orig_sentences.append(list_sentences[0] + extra_str)
                    for sent in list_sentences[1:]:
                        mapping[sent] = list_sentences[0]
                        sentences.append(sent + extra_str)
                else:
                    assert False

    hparams.task = 'oie'
    hparams.checkpoint = hparams.oie_model
    hparams.model_str = 'bert-base-cased'
    _, _, split_test_dataset, meta_data_vocab, _ = data.process_data(hparams, sentences)
    split_test_dataloader = DataLoader(split_test_dataset,
                                       batch_size=hparams.batch_size,
                                       collate_fn=data.pad_data,
                                       num_workers=1)

    model = predict(hparams, None, meta_data_vocab, None, None, split_test_dataloader,
                    mapping=mapping, conj_word_mapping=conj_word_mapping,
                    all_sentences=all_sentences)

    if 'labels' in hparams.type:
        label_lines = get_labels(hparams, model, sentences, orig_sentences, sentences_indices)
        f = open(hparams.out + '.labels', 'w')
        f.write('\n'.join(label_lines))
        f.close()

    if hparams.rescoring:
        print()
        print("Starting re-scoring ...")
        print()

        sentence_line_nums, prev_line_num, no_extractions = set(), 0, dict()
        for sentence_str in model.all_predictions_oie:
            sentence_str = sentence_str.strip('\n')
            num_extrs = len(sentence_str.split('\n')) - 1
            if num_extrs == 0:
                if curr_line_num not in no_extractions:
                    no_extractions[curr_line_num] = []
                no_extractions[curr_line_num].append(sentence_str)
                continue
            curr_line_num = prev_line_num + num_extrs
            sentence_line_nums.add(curr_line_num)  # check extra empty lines, example with no extractions
            prev_line_num = curr_line_num

        # testing rescoring
        inp_fp = model.predictions_f_allennlp
        rescored = rescore(inp_fp, model_dir=hparams.rescore_model, batch_size=256)

        all_predictions, sentence_str = [], ''
        for line_i, line in enumerate(rescored):
            fields = line.split('\t')
            sentence = fields[0]
            confidence = float(fields[2])

            if line_i == 0:
                sentence_str = f'{sentence}\n'
                exts = []
            if line_i in sentence_line_nums:
                exts = sorted(exts, reverse=True, key=lambda x: float(x.split()[0][:-1]))
                exts = exts[:hparams.num_extractions]
                all_predictions.append(sentence_str + ''.join(exts))
                sentence_str = f'{sentence}\n'
                exts = []
            if line_i in no_extractions:
                for no_extraction_sentence in no_extractions[line_i]:
                    all_predictions.append(f'{no_extraction_sentence}\n')

            arg1 = re.findall("<arg1>.*</arg1>", fields[1])[0].strip('<arg1>').strip('</arg1>').strip()
            rel = re.findall("<rel>.*</rel>", fields[1])[0].strip('<rel>').strip('</rel>').strip()
            arg2 = re.findall("<arg2>.*</arg2>", fields[1])[0].strip('<arg2>').strip('</arg2>').strip()

            extraction = Extraction(pred=rel,
                                    head_pred_index=None,
                                    sent=sentence,
                                    confidence=math.exp(confidence),
                                    index=0)
            extraction.addArg(arg1)
            extraction.addArg(arg2)
            if hparams.type == 'sentences':
                ext_str = data.ext_to_sentence(extraction) + '\n'
            else:
                ext_str = data.ext_to_string(extraction) + '\n'
            exts.append(ext_str)

        exts = sorted(exts, reverse=True, key=lambda x: float(x.split()[0][:-1]))
        exts = exts[:hparams.num_extractions]
        all_predictions.append(sentence_str + ''.join(exts))

        if line_i + 1 in no_extractions:
            for no_extraction_sentence in no_extractions[line_i + 1]:
                all_predictions.append(f'{no_extraction_sentence}\n')

        if hparams.out != None:
            print('Predictions written to ', hparams.out)
            predictions_f = open(hparams.out, 'w')
            predictions_f.write('\n'.join(all_predictions) + '\n')
            predictions_f.close()
    return