Code example #1
File: qa_to_oie.py  Project: yyNoMoon/oie-benchmark
    def loadFile(self, lines):
        """Parse a tab-separated QA-SRL style block into {sentence: [Extraction, ...]}."""
        sent = ''
        d = {}

        indsForQuestions = defaultdict(lambda: set())

        for line in lines.split('\n'):
            line = line.strip()
            if not line:
                continue
            data = line.split('\t')
            if len(data) == 1:
                # A single-column line starts a new sentence; attach the
                # accumulated question indices to the previous sentence first.
                if sent:
                    for ex in d[sent]:
                        ex.indsForQuestions = dict(indsForQuestions)
                sent = line
                d[sent] = []
                indsForQuestions = defaultdict(lambda: set())

            else:
                # Multi-column lines hold a predicate and its index, followed
                # by alternating question / answer fields.
                pred = data[0]
                pred_index = data[1]  # predicate index (kept for reference; unused here)
                cur = Extraction((pred, all_index(sent, pred, matchCase = False)),
                                 sent,
                                 confidence = 1.0,
                                 question_dist = self.question_dist)
                for q, a in zip(data[2::2], data[3::2]):
                    indices = all_index(sent, a, matchCase = False)
                    cur.addArg((a, indices), q)
                    indsForQuestions[q] = indsForQuestions[q].union(indices)

                if sent:
                    if cur.noPronounArgs():
                        d[sent].append(cur)

        # Flush the question indices for the final sentence in the block.
        if sent:
            for ex in d[sent]:
                ex.indsForQuestions = dict(indsForQuestions)
        return d
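For reference, a minimal driver sketch for the block above; the sentence, predicate, and question/answer strings are invented, and "reader" stands for an instance of the surrounding class:

    # Hypothetical input: a single-column sentence line, then tab-separated
    # predicate / index / question / answer lines.
    raw = '\n'.join([
        'John met Mary in Paris',
        '\t'.join(['met', '1',
                   'Who met someone?', 'John',
                   'Where did someone meet someone?', 'in Paris']),
    ])
    extractions_per_sentence = reader.loadFile(raw)  # {sentence: [Extraction, ...]}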
Code example #2
    def read(self, fn):
        d = dict()
        with open(fn) as fin:
            for line_ind, line in enumerate(fin):
                data = line.strip().split('\t')
                text, rel = data[:2]
                args = data[2:]
                confidence = 1  # gold file: every extraction gets full confidence

                curExtraction = Extraction(pred=rel.strip(),
                                           head_pred_index=None,
                                           sent=text.strip(),
                                           confidence=float(confidence),
                                           index=line_ind)
                for arg in args:
                    if "C: " in arg:
                        # skip context arguments (marked with a "C: " prefix)
                        continue
                    curExtraction.addArg(arg.strip())

                if text.strip() not in d:
                    d[text.strip()] = []
                d[text.strip()].append(curExtraction)
        self.oie = d
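The reader above consumes a gold file with one extraction per line. A hypothetical line (tab-separated: sentence, predicate, then arguments; values invented):

    sample_line = "John met Mary in Paris\tmet\tJohn\tMary\tin Paris"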
Code example #3
    def read(self, fn):
        """
        Read a tab-separated file.
        Each line consists of:
        sent, prob, pred, arg1, arg2, ...
        """
        d = {}
        ex_index = 0
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                text, confidence, rel = data[:3]
                curExtraction = Extraction(
                    pred=rel,
                    head_pred_index=None,
                    sent=text,
                    confidence=float(confidence),
                    question_dist="./question_distributions/dist_wh_sbj_obj1.json",
                    index=ex_index)
                ex_index += 1

                # remaining columns are the arguments
                for arg in data[3:]:
                    curExtraction.addArg(arg)

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
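A short usage sketch for this reader, assuming an instance named reader and a tab-separated file at the hypothetical path predictions.tsv:

    reader.read("predictions.tsv")         # hypothetical path
    for sent, extractions in reader.oie.items():
        for ex in extractions:
            # assumes Extraction stores its constructor kwargs as attributes
            print(ex.confidence, ex.pred)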
Code example #4
    def read(self, fn):
        """
        Read a tab-separated file.
        Each line consists of:
        sent, prob, pred, arg1, arg2, ...
        where pred and each arg may carry a "##<position>" suffix.
        """
        d = {}
        d_list = []
        ex_index = 0
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                text, confidence, rel = data[:3]
                # Split once from the right so a literal '#' inside a token
                # does not break the "text##position" convention.
                rel = rel.rsplit('##', 1)
                pred_pos = int(rel[1]) if len(rel) == 2 else None
                head_pred_index = pred_pos  # TODO: head_pred_index is not necessarily the first predicate index
                # rel becomes a (string, token-index list) tuple when a
                # position was given, otherwise stays a plain string.
                rel = (rel[0],
                       [pred_pos + i for i, w in enumerate(rel[0].split(' '))]
                       ) if len(rel) == 2 else rel[0]
                curExtraction = Extraction(
                    pred=rel,
                    pred_pos=pred_pos,
                    head_pred_index=head_pred_index,
                    sent=text,
                    confidence=float(confidence),
                    question_dist="./question_distributions/dist_wh_sbj_obj1.json",
                    index=ex_index,
                    raw=line.strip())
                ex_index += 1

                for arg in data[3:]:
                    arg = arg.rsplit('##', 1)  # same "text##position" convention as the predicate
                    arg_pos = int(arg[1]) if len(arg) == 2 else None
                    # arg becomes a (string, index list) tuple when a position was given
                    arg = (arg[0], [
                        arg_pos + i for i, w in enumerate(arg[0].split(' '))
                    ]) if len(arg) == 2 else arg[0]
                    curExtraction.addArg(arg, arg_pos)

                if text not in d:
                    d[text] = []
                    d_list.append([])
                # Note: d_list assumes lines for the same sentence are
                # contiguous; it appends to the most recently opened group.
                d[text].append(curExtraction)
                d_list[-1].append(curExtraction)
        self.oie = d
        self.oie_list = d_list
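To make the "text##position" convention concrete, here is what read() above builds from a hypothetical predicate field:

    field = "was born##2"                  # invented example field
    text, pos = field.rsplit('##', 1)      # -> ('was born', '2')
    # read() expands this to ('was born', [2, 3]): one index per predicate token.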
Code example #5
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                # ReVerb-style columns: arg1/rel/arg2 in columns 2-4,
                # confidence in column 11, sentence text in column 12.
                arg1, rel, arg2 = data[2:5]
                confidence = data[11]
                text = data[12]
                curExtraction = Extraction(pred=rel, sent=text, confidence=float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
Code example #6
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            fin.readline()  # skip the header row
            for line in fin:
                data = line.strip().split('\t')
                confidence, arg1, rel, arg2, enabler, attribution, text = data[:7]
                curExtraction = Extraction(pred=rel,
                                           sent=text,
                                           confidence=float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
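A hypothetical data row for the seven-column format above (the header row is skipped by readline(); values invented):

    # columns: confidence, arg1, rel, arg2, enabler, attribution, sentence
    sample_row = "0.87\tJohn\tmet\tMary\t\t\tJohn met Mary in Paris"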
Code example #7
File: goldReader.py  Project: BIU-NLP/oie-benchmark
    def read(self, fn):
        d = defaultdict(lambda: [])
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                text, rel = data[:2]
                args = data[2:]
                confidence = 1  # gold extractions carry full confidence

                curExtraction = Extraction(pred=rel,
                                           sent=text,
                                           confidence=float(confidence))
                for arg in args:
                    curExtraction.addArg(arg)

                d[text].append(curExtraction)
        self.oie = d
Code example #8
        def process_extraction(extraction, sentence, score):
            # Collect tokens per label id: 1 = arg1, 2 = relation, 3 = arg2, 4 = loc/time.
            rel, arg1, arg2, loc_time, args = [], [], [], [], []
            rel_case = 0
            for i, token in enumerate(sentence):
                if '[unused' in token:
                    # [unusedN] markers encode how the relation should be rendered.
                    if extraction[i].item() == 2:
                        rel_case = int(
                            re.search(r'\[unused(.*)\]', token).group(1))
                    continue
                if extraction[i] == 1:
                    arg1.append(token)
                if extraction[i] == 2:
                    rel.append(token)
                if extraction[i] == 3:
                    arg2.append(token)
                if extraction[i] == 4:
                    loc_time.append(token)

            rel = ' '.join(rel).strip()
            if rel_case == 1:
                rel = 'is ' + rel
            elif rel_case == 2:
                rel = 'is ' + rel + ' of'
            elif rel_case == 3:
                rel = 'is ' + rel + ' from'

            arg1 = ' '.join(arg1).strip()
            arg2 = ' '.join(arg2).strip()
            args = ' '.join(args).strip()  # 'args' is never populated above; kept for the arg2 join below
            loc_time = ' '.join(loc_time).strip()
            if not self.hparams.no_lt:
                # fold location/time (and extra args) into arg2 unless disabled
                arg2 = (arg2 + ' ' + loc_time + ' ' + args).strip()
            sentence_str = ' '.join(sentence).strip()

            extraction = Extraction(pred=rel,
                                    head_pred_index=None,
                                    sent=sentence_str,
                                    confidence=score,
                                    index=0)
            extraction.addArg(arg1)
            extraction.addArg(arg2)

            return extraction
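A toy invocation sketch for process_extraction, assuming PyTorch tensors as in the surrounding model code; the tokens and labels are invented:

    import torch

    tokens = ['John', 'met', 'Mary', '[unused1]']
    labels = torch.tensor([1, 2, 3, 0])  # 1 = arg1, 2 = rel, 3 = arg2
    # process_extraction(labels, tokens, score=0.9) would yield an Extraction
    # with pred='met', arg1='John', and arg2='Mary'.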
Code example #9
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                if not line.strip():
                    continue
                data = line.strip().split('\t')
                confidence, text, rel = data[:3]
                curExtraction = Extraction(pred=rel,
                                           sent=text,
                                           confidence=float(confidence))

                # argument strings occupy every other field starting at column 4
                for arg in data[4::2]:
                    curExtraction.addArg(arg)

                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
        self.normalizeConfidence()
Code example #10
File: metric.py  Project: yyht/openie6
    def _process_allenlp_format(self, lines):
        assert self._all_predictions == {}
        for line in lines:
            extr = line.split('\t')
            sentence = extr[0]
            confidence = float(extr[2])

            # Capture groups avoid str.strip('<arg1>'), which treats its
            # argument as a character set and would also eat characters
            # from the span text itself.
            arg1 = re.findall("<arg1>(.*)</arg1>", extr[1])[0].strip()
            rel = re.findall("<rel>(.*)</rel>", extr[1])[0].strip()
            arg2 = re.findall("<arg2>(.*)</arg2>", extr[1])[0].strip()

            extraction = Extraction(pred=rel, head_pred_index=None, sent=sentence, confidence=confidence, index=0)
            extraction.addArg(arg1)
            extraction.addArg(arg2)

            if sentence not in self._all_predictions:
                self._all_predictions[sentence] = []
            self._all_predictions[sentence].append(extraction)
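A hypothetical line in the tagged format parsed above (tab-separated: sentence, tagged extraction, then the confidence column at index 2; values invented):

    sample = "John met Mary\t<arg1> John </arg1> <rel> met </rel> <arg2> Mary </arg2>\t0.95"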
Code example #11
File: goldReader.py  Project: piekey1994/IDEF
    def read(self, fn):
        d = defaultdict(lambda: [])
        with open(fn, encoding='utf-8') as fin:
            for line_ind, line in enumerate(fin):
                data = line.strip().split('\t')
                text, rel = data[:2]
                args = data[2:]
                confidence = 1

                curExtraction = Extraction(pred=rel,
                                           head_pred_index=None,
                                           sent=text,
                                           confidence=float(confidence),
                                           index=line_ind)
                for arg in args:
                    curExtraction.addArg(arg)

                d[text].append(curExtraction)
        self.oie = d
Code example #12
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                confidence = data[0]
                if not all(data[2:5]):
                    continue
                # Each field looks like Name(text,List(...)); keep only the
                # text between the opening '(' and the ',List(' suffix.
                arg1, rel, arg2 = [
                    s[s.index('(') + 1:s.index(',List(')] for s in data[2:5]
                ]
                text = data[5]
                curExtraction = Extraction(pred=rel,
                                           sent=text,
                                           confidence=float(confidence))
                curExtraction.addArg(arg1)
                curExtraction.addArg(arg2)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
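To make the slicing concrete, a hypothetical serialized field and the span text the comprehension keeps:

    field = "SimpleArgument(John,List([0, 4)))"               # invented example
    text = field[field.index('(') + 1:field.index(',List(')]  # -> 'John'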
Code example #13
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                if len(data) == 1:
                    # a single-column line carries the sentence for the
                    # extraction lines that follow it
                    text = data[0]
                elif len(data) == 5:
                    # five columns: arg1 / rel / arg2 (with their surrounding
                    # quote characters stripped), then the confidence
                    arg1, rel, arg2 = [s[1:-1] for s in data[1:4]]
                    confidence = data[4]

                    curExtraction = Extraction(pred=rel,
                                               sent=text,
                                               confidence=float(confidence))
                    curExtraction.addArg(arg1)
                    curExtraction.addArg(arg2)
                    d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
        self.normalizeConfidence()
Code example #14
    def read(self, fn):
        d = {}
        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                text = data[0]
                rel = data[1]
                # Fields after the relation look like "LABEL:: text".
                # Everything before the V:: (verb) field joins into arg1;
                # everything after it becomes a separate argument.
                before_verb = True
                arg1 = ''
                args = []
                for s in data[2:]:
                    if s[:s.index('::')] == 'V':
                        before_verb = False
                        continue
                    if before_verb:
                        arg1 = arg1 + ' ' + s[s.index('::') + 2:]
                    else:
                        args.append(s[s.index('::') + 2:])

                # the format carries no confidence; use a fixed placeholder
                curExtraction = Extraction(pred=rel, sent=text, confidence=0.1)
                curExtraction.addArg(arg1)
                for arg in args:
                    curExtraction.addArg(arg)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
Code example #15
    def read(self, fn):
        d = {}

        def span_text(s):
            # fields look like Name(text,List(...)); keep only the text part
            return s[s.index('(') + 1:s.index(',List(')]

        with open(fn) as fin:
            for line in fin:
                data = line.strip().split('\t')
                confidence = data[0]

                if not all(data[2:5]):
                    continue
                arg1, rel = [span_text(s) for s in data[2:4]]
                # the arg2 column may hold several ');'-separated arguments
                args = [span_text(s) for s in data[4].strip().split(');')]
                text = data[5]
                if data[1]:
                    # Column 1 optionally carries a context phrase; prepend it
                    # to arg1 unless "arg1 rel" already starts with it.
                    context = span_text(data[1])
                    if not (arg1 + ' ' + rel).startswith(context):
                        arg1 = context + ' ' + arg1
                curExtraction = Extraction(pred=rel, head_pred_index=-1, sent=text, confidence=float(confidence))
                curExtraction.addArg(arg1)
                for arg in args:
                    curExtraction.addArg(arg)
                d[text] = d.get(text, []) + [curExtraction]
        self.oie = d
Code example #16
File: single.py  Project: yyht/openie6
def get_predictions(fp):
    old_sentence = ''
    all_predictions = dict()
    with open(fp, 'r') as fin:
        for line in fin:
            line = line.strip('\n')
            sentence, extraction, confidence = line.split('\t')
            # start a fresh extraction list whenever the sentence changes
            if sentence != old_sentence:
                all_predictions[sentence] = []
                old_sentence = sentence
            # pull the tagged spans; any span may be missing entirely
            try:
                arg1 = line[line.index('<arg1>') + 6:line.index('</arg1>')]
            except ValueError:
                arg1 = ""
            try:
                rel = line[line.index('<rel>') + 5:line.index('</rel>')]
            except ValueError:
                rel = ""
            try:
                arg2 = line[line.index('<arg2>') + 6:line.index('</arg2>')]
            except ValueError:
                arg2 = ""

            if not (arg1 or arg2 or rel):
                continue
            curExtraction = Extraction(pred=rel,
                                       head_pred_index=-1,
                                       sent=sentence,
                                       confidence=float(confidence))
            curExtraction.addArg(arg1)
            curExtraction.addArg(arg2)
            all_predictions[sentence].append(curExtraction)

    return all_predictions
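A short usage sketch, assuming predictions live at the hypothetical path predictions.tsv:

    all_predictions = get_predictions("predictions.tsv")  # hypothetical path
    for sent, extractions in all_predictions.items():
        print(sent, len(extractions))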
Code example #17
File: qa_to_oie.py  Project: Joshua-Samjaya/elmo-oie
    def loadFile(self, lines):
        sent = ''
        d = {}

        indsForQuestions = defaultdict(lambda: set())

        for line in lines.split('\n'):
            line = line.strip()
            if not line:
                continue
            data = line.split('\t')
            if len(data) == 1:
                # A single-column line starts a new sentence; attach the
                # accumulated question indices to the previous one first.
                if sent:
                    for ex in d[sent]:
                        ex.indsForQuestions = dict(indsForQuestions)
                sent = line
                d[sent] = []
                indsForQuestions = defaultdict(lambda: set())

            else:
                pred = self.preproc(data[0])
                # data[1] holds a serialized list of predicate indices;
                # materialize the map so the list survives indexing (Python 3).
                pred_indices = list(map(int, eval(data[1])))
                head_pred_index = int(data[2])
                cur = Extraction((pred, [pred_indices]),
                                 head_pred_index,
                                 sent,
                                 confidence=1.0)

                for q, a in zip(data[3::2], data[4::2]):
                    preproc_arg = self.preproc(a)
                    if not preproc_arg:
                        logging.warning("Argument reduced to None: {}".format(a))
                    indices = fuzzy_match_phrase(preproc_arg.split(" "),
                                                 sent.split(" "))
                    cur.addArg((preproc_arg, indices), q)
                    indsForQuestions[q] = indsForQuestions[q].union(
                        flatten(indices))

                if sent:
                    if cur.noPronounArgs():
                        cur.resolveAmbiguity()
                        d[sent].append(cur)

        return d
Code example #18
    def read(self):
        ### What to do about generalized questions that are not yet in this distribution set?
        ### Use analyze.py
        # Map each question template to {position: count}; empty when no
        # distribution file was given (Python 3: iteritems() -> items()).
        question_dist = dict([(q, dict([(int(loc), cnt)
                                        for (loc, cnt) in dist.items()]))
                              for (q, dist)
                              in json.load(open(self.dist_file)).items()]) \
            if self.dist_file else {}
        ## pull sentence, predicate, and QA pairs with 5/6 or more validations;
        ## possibly preprocess at this step
        # load JSON-lines data into a list
        qa_path = self.qa_path
        data = []
        with codecs.open(qa_path, 'r', encoding='utf8') as f:
            for line in f:
                data.append(json.loads(line))

        f_out = open(self.output_file, "w")
        jsonl_out = open('science_eval_sent.jsonl', "w")
        eval_out = open('science_eval.oie', "w")
        verb_types = []
        #parse qa data
        for item in data:
            sent_id = item["sentenceId"]  # Python 3: keep str, no encode()
            # keep only TQA (science) sentences
            if sent_id.split(':')[0] != 'TQA':
                continue
            sentence_tokens = item["sentenceTokens"]
            sentence = ' '.join(sentence_tokens)
            if self.output_eval:
                # json.dumps handles quoting and escaping correctly
                jsonl_out.write(json.dumps({"sentence": sentence}) + '\n')

            for _, verb_entry in item["verbEntries"].items():
                verb_index = verb_entry["verbIndex"]
                verb_inflected_forms = verb_entry["verbInflectedForms"]
                base_pred = sentence_tokens[verb_index]
                surfacePred = base_pred
                answer_list = []
                questions = []

                for _, question_label in verb_entry["questionLabels"].items():
                    answers = len(question_label["answerJudgments"])
                    valid_answers = len([
                        ans for ans in question_label["answerJudgments"]
                        if ans["isValid"]
                    ])
                    if valid_answers / (answers * 1.0) < self.min_correct:
                        #do not log this question set
                        continue
                    q_string = question_label['questionString']
                    ans_spans = []
                    for ans in question_label["answerJudgments"]:
                        if ans["isValid"]:
                            for span in ans["spans"]:
                                ans_spans.append(span)
                    #add long/short flag here
                    consolidated_spans = consolidate_answers(
                        ans_spans, self.length)
                    # look up each consolidated span's text in the sentence tokens
                    lookup_ans = lambda ans, sentence: ' '.join(
                        sentence[ans[0]:ans[1]])
                    consolidated_ans = list(map(lookup_ans, consolidated_spans,
                                                [sentence_tokens] *
                                                len(consolidated_spans)))
                    # here we can read off the question slots
                    wh = question_label["questionSlots"]["wh"].split()
                    wh = '_'.join(wh)
                    aux = question_label["questionSlots"]["aux"].split()
                    aux = '_'.join(aux)
                    subj = question_label["questionSlots"]["subj"].split()
                    subj = '_'.join(subj)
                    #iterate through and check verb types for len > 2
                    verb_type = question_label['questionSlots']['verb']
                    inflected_verb = verb_inflected_forms[
                        verb_type.split()[-1]]
                    if len(verb_type.split()) == 1:
                        trg = inflected_verb
                    else:
                        trg = verb_type.split()[:-1]
                        trg.append(inflected_verb)
                        trg = "_".join(trg)
                    obj1 = question_label["questionSlots"]["obj"].split()
                    obj1 = '_'.join(obj1)
                    pp = question_label["questionSlots"]["prep"].split()
                    pp = '_'.join(pp)
                    obj2 = question_label["questionSlots"]["obj2"].split()
                    obj2 = '_'.join(obj2)

                    slotted_q = " ".join(
                        (wh, aux, subj, trg, obj1, pp, obj2, "?"))

                    curSurfacePred = augment_pred_with_question(
                        base_pred, slotted_q)
                    if len(curSurfacePred) > len(surfacePred):
                        surfacePred = curSurfacePred

                    questions.append(slotted_q)
                    answer_list.append(consolidated_ans)
                    ### For each predicate, build QA-pair sets so that every
                    ### unique combination of questions and answers appears;
                    ### e.g. 2 questions with 2 answers each yields four sets
                    ### ((q1,a1),(q2,a1)), ((q1,a1),(q2,a2)), etc., and each
                    ### set leads to one extraction (see itertools.product below).
                ## Now we have sentence tokens, the verb index, valid questions,
                ## and valid answer spans. Observations: rare answers often don't
                ## make sense (e.g. "Clouds that form on the ground are called fog"),
                ## and some questions encode a similar argument (e.g. "what for" and
                ## "why"). Possible edits: for each argument, keep only the first
                ## question that appears for it, and only keep answer spans given
                ## by more than one annotator, since rare answers tend to mislead.
                # fuzzy-match the (possibly augmented) predicate back into the
                # sentence (Python 3: strings stay str, so no encode() needed)
                augmented_pred_indices = fuzzy_match_phrase(
                    surfacePred.split(" "), sentence.split(" "))
                if not augmented_pred_indices:
                    #find equivalent of pred_index
                    head_pred_index = [verb_index]

                else:
                    head_pred_index = augmented_pred_indices[0]
                for ans_set in list(itertools.product(*answer_list)):
                    cur = Extraction((surfacePred, [head_pred_index]),
                                     verb_index,
                                     sentence,
                                     confidence=1.0,
                                     question_dist=self.dist_file)
                    q_as = list(zip(questions, ans_set))  # materialize: Python 3 zip has no len()
                    if len(q_as) == 0:
                        continue
                    for q, a in q_as:
                        preproc_arg = self.preproc(a)
                        if not preproc_arg:
                            logging.warning(
                                "Argument reduced to None: {}".format(a))
                        indices = fuzzy_match_phrase(preproc_arg.split(" "),
                                                     sentence.split(" "))
                        cur.addArg((preproc_arg, indices), q)

                    if cur.noPronounArgs():
                        cur.resolveAmbiguity()
                        if self.write:
                            if self.sort:
                                cur.getSortedArgs()
                            f_out.write(cur.conll(external_feats=[1, 2]))
                            f_out.write('\n')
                        # Arguments stay in the order they appear in the QA file;
                        # also emit an output file for downstream evaluation with
                        # the OIE-2016 framework.
                        if self.output_eval:
                            if self.sort:
                                cur.getSortedArgs()
                            eval_out.write(sentence + ' \t' + str(cur) + '\n')

                        self.extractions.append(cur)
Code example #19
File: run.py  Project: yyht/openie6
def splitpredict(hparams, checkpoint_callback, meta_data_vocab,
                 train_dataloader, val_dataloader, test_dataloader,
                 all_sentences):
    mapping, conj_word_mapping = {}, {}
    hparams.write_allennlp = True
    if hparams.split_fp == '':
        hparams.task = 'conj'
        hparams.checkpoint = hparams.conj_model
        hparams.model_str = 'bert-base-cased'
        hparams.mode = 'predict'
        model = predict(hparams, None, meta_data_vocab, None, None,
                        test_dataloader, all_sentences)
        conj_predictions = model.all_predictions_conj
        sentences_indices = model.all_sentence_indices_conj
        assert len(conj_predictions) == len(sentences_indices)
        all_conj_words = model.all_conjunct_words_conj

        sentences, orig_sentences = [], []
        for i, sentences_str in enumerate(conj_predictions):
            list_sentences = sentences_str.strip('\n').split('\n')
            conj_words = all_conj_words[i]
            if len(list_sentences) == 1:
                orig_sentences.append(list_sentences[0] +
                                      ' [unused1] [unused2] [unused3]')
                mapping[list_sentences[0]] = list_sentences[0]
                conj_word_mapping[list_sentences[0]] = conj_words
                sentences.append(list_sentences[0] +
                                 ' [unused1] [unused2] [unused3]')
            elif len(list_sentences) > 1:
                orig_sentences.append(list_sentences[0] +
                                      ' [unused1] [unused2] [unused3]')
                conj_word_mapping[list_sentences[0]] = conj_words
                for sent in list_sentences[1:]:
                    mapping[sent] = list_sentences[0]
                    sentences.append(sent + ' [unused1] [unused2] [unused3]')
            else:
                assert False
        sentences.append('\n')

        count = 0
        for sentence_indices in sentences_indices:
            if len(sentence_indices) == 0:
                count += 1
            else:
                count += len(sentence_indices)
        assert count == len(sentences) - 1

    else:
        with open(hparams.predict_fp, 'r') as f:
            lines = f.read()
            lines = lines.replace("\\", "")

        sentences = []
        orig_sentences = []
        extra_str = " [unused1] [unused2] [unused3]"
        for line in lines.split('\n\n'):
            if len(line) > 0:
                list_sentences = line.strip().split('\n')
                if len(list_sentences) == 1:
                    mapping[list_sentences[0]] = list_sentences[0]
                    sentences.append(list_sentences[0] + extra_str)
                    orig_sentences.append(list_sentences[0] + extra_str)
                elif len(list_sentences) > 1:
                    orig_sentences.append(list_sentences[0] + extra_str)
                    for sent in list_sentences[1:]:
                        mapping[sent] = list_sentences[0]
                        sentences.append(sent + extra_str)
                else:
                    assert False

    hparams.task = 'oie'
    hparams.checkpoint = hparams.oie_model
    hparams.model_str = 'bert-base-cased'
    _, _, split_test_dataset, meta_data_vocab, _ = data.process_data(
        hparams, sentences)
    split_test_dataloader = DataLoader(split_test_dataset,
                                       batch_size=hparams.batch_size,
                                       collate_fn=data.pad_data,
                                       num_workers=1)

    model = predict(hparams,
                    None,
                    meta_data_vocab,
                    None,
                    None,
                    split_test_dataloader,
                    mapping=mapping,
                    conj_word_mapping=conj_word_mapping,
                    all_sentences=all_sentences)

    if 'labels' in hparams.type:
        # note: sentences_indices is only populated on the split_fp == '' path
        label_lines = get_labels(hparams, model, sentences, orig_sentences,
                                 sentences_indices)
        with open(hparams.out + '.labels', 'w') as f:
            f.write('\n'.join(label_lines))

    if hparams.rescoring:
        print()
        print("Starting re-scoring ...")
        print()

        sentence_line_nums, prev_line_num, no_extractions = set(), 0, dict()
        curr_line_num = 0  # defined up front in case the first sentence has no extractions
        for sentence_str in model.all_predictions_oie:
            sentence_str = sentence_str.strip('\n')
            num_extrs = len(sentence_str.split('\n')) - 1
            if num_extrs == 0:
                # remember sentences with no extractions at the current line number
                if curr_line_num not in no_extractions:
                    no_extractions[curr_line_num] = []
                no_extractions[curr_line_num].append(sentence_str)
                continue
            curr_line_num = prev_line_num + num_extrs
            sentence_line_nums.add(
                curr_line_num
            )  # line numbers where a sentence block ends in the rescored file
            prev_line_num = curr_line_num

        # testing rescoring
        inp_fp = model.predictions_f_allennlp
        rescored = rescore(inp_fp,
                           model_dir=hparams.rescore_model,
                           batch_size=256)

        all_predictions, sentence_str = [], ''
        for line_i, line in enumerate(rescored):
            fields = line.split('\t')
            sentence = fields[0]
            confidence = float(fields[2])

            if line_i == 0:
                sentence_str = f'{sentence}\n'
                exts = []
            if line_i in sentence_line_nums:
                exts = sorted(exts,
                              reverse=True,
                              key=lambda x: float(x.split()[0][:-1]))
                exts = exts[:hparams.num_extractions]
                all_predictions.append(sentence_str + ''.join(exts))
                sentence_str = f'{sentence}\n'
                exts = []
            if line_i in no_extractions:
                for no_extraction_sentence in no_extractions[line_i]:
                    all_predictions.append(f'{no_extraction_sentence}\n')

            # capture groups avoid str.strip('<arg1>'), which would also eat
            # characters of the span text itself
            arg1 = re.findall("<arg1>(.*)</arg1>", fields[1])[0].strip()
            rel = re.findall("<rel>(.*)</rel>", fields[1])[0].strip()
            arg2 = re.findall("<arg2>(.*)</arg2>", fields[1])[0].strip()
            extraction = Extraction(pred=rel,
                                    head_pred_index=None,
                                    sent=sentence,
                                    confidence=math.exp(confidence),
                                    index=0)
            extraction.addArg(arg1)
            extraction.addArg(arg2)
            if hparams.type == 'sentences':
                ext_str = data.ext_to_sentence(extraction) + '\n'
            else:
                ext_str = data.ext_to_string(extraction) + '\n'
            exts.append(ext_str)

        exts = sorted(exts,
                      reverse=True,
                      key=lambda x: float(x.split()[0][:-1]))
        exts = exts[:hparams.num_extractions]
        all_predictions.append(sentence_str + ''.join(exts))

        if line_i + 1 in no_extractions:
            for no_extraction_sentence in no_extractions[line_i + 1]:
                all_predictions.append(f'{no_extraction_sentence}\n')

        if hparams.out is not None:
            print('Predictions written to ', hparams.out)
            with open(hparams.out, 'w') as predictions_f:
                predictions_f.write('\n'.join(all_predictions) + '\n')
        return