Example #1
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch from ReCoRD-style examples.

        :param batch: dict with batch["input"]["passage"|"question"|
                      "true_entity"|"false_entities"] and batch["output"]["lbl"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (orig_input_ids, masked_input_ids, prep_lbl, tgt) — the first
                 two and tgt are tensors moved to `device`, prep_lbl is a
                 numpy array of randomly drawn labels
        '''

        list_passage = batch["input"]["passage"]
        list_question = batch["input"]["question"]
        list_true_entity = batch["input"]["true_entity"]
        list_false_entities = batch["input"]["false_entities"]
        list_lbl = batch["output"]["lbl"]

        bs = len(list_passage)

        # Draw one random label per example; tgt marks where the draw
        # coincides with the gold label.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

        list_orig_input_ids = []
        list_masked_input_ids = []

        # Pattern is fixed for the whole batch; hoisted out of the loop.
        pattern = self.pet_patterns[self._pet_names.index(mode)]

        for p, q, te, fe, _lbl in zip(list_passage, list_question,
                                      list_true_entity, list_false_entities,
                                      list_lbl):
            # Longest label length (in tokens) over the true entity and all
            # false entities.
            # NOTE(review): max_num_lbl_tok is computed but never consumed
            # below — confirm get_lbl_num_lbl_tok is side-effect free before
            # deleting this entirely; the calls are preserved here in the
            # original order (true entity first, then each false entity).
            max_num_lbl_tok = max(
                [self.get_lbl_num_lbl_tok(te)] +
                [self.get_lbl_num_lbl_tok(wrong_enty) for wrong_enty in fe])

            txt_split_tuple = []
            txt_trim = -1

            for idx, txt_split in enumerate(pattern):
                txt_split_inp = txt_split.replace("[PASSAGE]", p).replace(
                    "[QUESTION]", q + " [SEP]").replace("@highlight", "-")
                txt_split_tuple.append(txt_split_inp)

                # Remember which segment carries the passage so the tokenizer
                # can trim it when the sequence is over length.
                if "[PASSAGE]" in txt_split:
                    txt_trim = idx

            orig_input_ids, masked_input_ids, _mask_idx = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, txt_split_tuple[0],
                txt_split_tuple[1], txt_split_tuple[2], txt_trim)
            list_orig_input_ids.append(orig_input_ids)
            list_masked_input_ids.append(masked_input_ids)

        return torch.tensor(list_orig_input_ids).to(device), torch.tensor(
            list_masked_input_ids).to(device), prep_lbl, tgt.to(device)
Example #2
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch from QA-style examples.

        :param batch: dict with batch["input"]["question"|"passage"|"answer"]
                      and batch["output"]["lbl"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (orig_input_ids, masked_input_ids, prep_lbl, tgt)
        '''

        list_question = batch["input"]["question"]
        list_passage = batch["input"]["passage"]
        list_answer = batch["input"]["answer"]

        bs = len(list_answer)

        # Draw a random label for every example; tgt flags the correct draws.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

        pattern, label = self.pet_pvps[self._pet_names.index(mode)]

        list_orig_input_ids = []
        list_masked_input_ids = []

        for psg, qst, ans, drawn in zip(list_passage, list_question,
                                        list_answer, prep_lbl):
            segments = []
            trim_seg = -1

            for seg_idx, raw_seg in enumerate(pattern):
                filled = raw_seg.replace("[PARAGRAPH]", psg) \
                                .replace("[QUESTION]", qst) \
                                .replace("[ANSWER]", ans) \
                                .replace("[MASK]", label[drawn])
                segments.append(filled)

                # The paragraph segment is the one eligible for trimming.
                if "[PARAGRAPH]" in raw_seg:
                    trim_seg = seg_idx

            orig_ids, masked_ids, _ = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, segments[0], segments[1],
                segments[2], trim_seg)
            list_orig_input_ids.append(orig_ids)
            list_masked_input_ids.append(masked_ids)

        return torch.tensor(list_orig_input_ids).to(device), torch.tensor(
            list_masked_input_ids).to(device), prep_lbl, tgt.to(device)
Example #3
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch from hypothesis/premise pairs.

        :param batch: dict with batch["input"]["hypothesis"|"premise"]
                      and batch["output"]["lbl"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (orig_input_ids, masked_input_ids, prep_lbl, tgt)
        '''

        hypotheses = batch["input"]["hypothesis"]
        premises = batch["input"]["premise"]

        bs = len(hypotheses)

        # Random label draw per example; tgt marks draws matching the gold label.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

        pattern, label = self.pet_pvps[self._pet_names.index(mode)]

        all_orig_ids = []
        all_masked_ids = []

        for hyp, prem, drawn in zip(hypotheses, premises, prep_lbl):
            segments = []
            trim_at = -1

            for seg_idx, raw_seg in enumerate(pattern):
                segments.append(
                    raw_seg.replace("[HYPOTHESIS]", hyp)
                           .replace("[PREMISE]", prem)
                           .replace("[MASK]", label[drawn]))

                # The premise segment is the one eligible for trimming.
                if "[PREMISE]" in raw_seg:
                    trim_at = seg_idx

            orig_ids, masked_ids, _ = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, segments[0], segments[1],
                segments[2], trim_at)
            all_orig_ids.append(orig_ids)
            all_masked_ids.append(masked_ids)

        return torch.tensor(all_orig_ids).to(device), torch.tensor(
            all_masked_ids).to(device), prep_lbl, tgt.to(device)
Example #4
    def prepare_eval_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build an MLM-style PET batch for evaluation.

        One input row is built per candidate label verbalization for every
        (hypothesis, premise) pair, so the returned tensors have
        len(batch) * len(label) rows.

        :param batch: dict with batch["input"]["hypothesis"|"premise"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (input_ids, masked_input_ids) tensors on `device`
        '''
        list_hypothesis = batch["input"]["hypothesis"]
        list_premise = batch["input"]["premise"]

        list_input_ids = []
        list_masked_input_ids = []

        pattern, label = self.pet_pvps[self._pet_names.index(mode)]

        for h, p in zip(list_hypothesis, list_premise):
            # The same mask position is reused across all candidate labels of
            # one pair: the first call picks it, later calls receive it back.
            mask_idx = None

            for lbl in label:
                txt_split_tuple = []
                # Bug fix: txt_trim was previously left unset, raising
                # UnboundLocalError whenever no pattern segment contained
                # "[PREMISE]"; default to -1 as the train-time methods do.
                txt_trim = -1

                for idx, txt_split in enumerate(pattern):
                    txt_split_inp = txt_split.replace(
                        "[HYPOTHESIS]", h).replace("[PREMISE]",
                                                   p).replace("[MASK]", lbl)
                    txt_split_tuple.append(txt_split_inp)

                    # The premise segment is the one eligible for trimming.
                    if "[PREMISE]" in txt_split:
                        txt_trim = idx

                orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
                    self.tokenizer,
                    self.config,
                    txt_split_tuple[0],
                    txt_split_tuple[1],
                    txt_split_tuple[2],
                    txt_trim,
                    mask_idx=mask_idx)
                list_input_ids.append(orig_input_ids)
                list_masked_input_ids.append(masked_input_ids)

        return torch.tensor(list_input_ids).to(device), torch.tensor(
            list_masked_input_ids).to(device)
Example #5
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch from sentence-pair/word examples.

        :param batch: dict with batch["input"]["sentence1"|"sentence2"|"word"]
                      and batch["output"]["lbl"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (orig_input_ids, masked_input_ids, prep_lbl, tgt)
        '''

        sentence1_list = batch["input"]["sentence1"]
        sentence2_list = batch["input"]["sentence2"]
        word_list = batch["input"]["word"]

        bs = len(sentence1_list)

        # Sample one label per example; tgt records where the sample is correct.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

        pattern, label = self.pet_pvps[self._pet_names.index(mode)]

        orig_ids_acc = []
        masked_ids_acc = []

        for sent1, sent2, word, drawn in zip(sentence1_list, sentence2_list,
                                             word_list, prep_lbl):
            segments = []
            trim_at = -1

            for seg_idx, raw_seg in enumerate(pattern):
                filled = (raw_seg.replace("[SENTENCE1]", sent1)
                          .replace("[SENTENCE2]", sent2)
                          .replace("[WORD]", word)
                          .replace("[MASK]", label[drawn]))
                segments.append(filled)

                # The first sentence's segment is the trimmable one.
                if "[SENTENCE1]" in raw_seg:
                    trim_at = seg_idx

            orig_ids, masked_ids, _ = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, segments[0], segments[1],
                segments[2], trim_at)
            orig_ids_acc.append(orig_ids)
            masked_ids_acc.append(masked_ids)

        return torch.tensor(orig_ids_acc).to(device), torch.tensor(
            masked_ids_acc).to(device), prep_lbl, tgt.to(device)
Example #6
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch from text/pronoun/noun triples.

        :param batch: dict with batch["input"]["text"|"pronoun"|"noun"]
                      and batch["output"]["lbl"]
        :param mode: PET pattern name, looked up in self._pet_names
        :return: (orig_input_ids, masked_input_ids, None, tgt)
        '''
        texts = batch["input"]["text"]
        pronouns = batch["input"]["pronoun"]
        nouns = batch["input"]["noun"]
        labels = batch["output"]["lbl"]

        orig_ids_acc = []
        masked_ids_acc = []

        # Constant positive target for this task.
        tgt = torch.tensor([1.]).long()

        # Pattern is fixed for the whole batch; hoisted out of the loop.
        pattern = self.pet_patterns[self._pet_names.index(mode)]

        for txt, pron, noun, _ in zip(texts, pronouns, nouns, labels):
            segments = []
            trim_at = -1

            for seg_idx, raw_seg in enumerate(pattern):
                segments.append(
                    raw_seg.replace("[TEXT]", txt)
                           .replace("[NNP]", pron)
                           .replace("[MASK]", noun))

                # The text segment is the one eligible for trimming.
                if "[TEXT]" in raw_seg:
                    trim_at = seg_idx

            orig_ids, masked_ids, _ = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, segments[0], segments[1],
                segments[2], trim_at)
            orig_ids_acc.append(orig_ids)
            masked_ids_acc.append(masked_ids)

        return torch.tensor(orig_ids_acc).to(device), torch.tensor(
            masked_ids_acc).to(device), None, tgt.to(device)
Example #7
    def prepare_pet_mlm_batch(self, batch, mode="PET1"):
        '''
        Build one MLM-style PET training batch for COPA-style examples.

        :param batch: dict with batch["input"]["premise"|"choice1"|"choice2"|
                      "question"] and batch["output"]["lbl"]
        :param mode: ignored — COPA always uses pattern "PET3"
        :return: (orig_input_ids, masked_input_ids, prep_lbl, tgt)
        :raises ValueError: on an unexpected question type or label value
        '''
        # Always use pattern 3 for COPA
        mode = "PET3"

        list_premise = batch["input"]["premise"]
        list_choice1 = batch["input"]["choice1"]
        list_choice2 = batch["input"]["choice2"]
        list_question = batch["input"]["question"]
        list_lbl = batch["output"]["lbl"]

        bs = len(list_question)

        # Draw one random label per example; tgt marks where the draw
        # coincides with the gold label.
        prep_lbl = np.random.randint(self.num_lbl, size=bs)
        tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

        list_orig_input_ids = []
        list_masked_input_ids = []

        for p, c1, c2, ques, lbl in zip(list_premise, list_choice1,
                                        list_choice2, list_question,
                                        list_lbl):
            txt_split_tuple = []
            txt_trim = -1

            # Pick the pattern family by question type. Bug fix: previously an
            # unexpected `ques` left pet_pvps unbound and the next line raised
            # a confusing NameError; fail loudly with a ValueError instead.
            if ques == "cause":
                pet_pvps = self.pet_patterns_cause
            elif ques == "effect":
                pet_pvps = self.pet_patterns_effect
            else:
                raise ValueError(f"Invalid question type: {ques}")
            pattern = pet_pvps[self._pet_names.index(mode)]

            # Verbalize the gold choice; [:-1] drops the trailing character
            # (presumably sentence-final punctuation — TODO confirm).
            if lbl.item() == 0:
                lbl_choice = c1[:-1]
            elif lbl.item() == 1:
                lbl_choice = c2[:-1]
            else:
                raise ValueError("Invalid Lbl")

            for idx, txt_split in enumerate(pattern):
                txt_split_inp = txt_split.replace("[PREMISE]", p[:-1]).replace(
                    "[CHOICE1]",
                    c1[:-1]).replace("[CHOICE2]",
                                     c2[:-1]).replace("[MASK]", lbl_choice)
                txt_split_tuple.append(txt_split_inp)

                # The premise segment is the one eligible for trimming. The
                # original lbl==0 / lbl==1 branches here were byte-identical
                # (and lbl is already validated above), so they are collapsed.
                if "[PREMISE]" in txt_split:
                    txt_trim = idx

            orig_input_ids, masked_input_ids, _mask_idx = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, txt_split_tuple[0],
                txt_split_tuple[1], txt_split_tuple[2], txt_trim)
            list_orig_input_ids.append(orig_input_ids)
            list_masked_input_ids.append(masked_input_ids)

        return torch.tensor(list_orig_input_ids).to(device), torch.tensor(
            list_masked_input_ids).to(device), prep_lbl, tgt.to(device)