import numpy as np
import torch

# Module-level dependencies assumed by the readers below: `device` is the
# global torch device and `tokenize_pet_mlm_txt` is the repo's tokenization
# helper returning (orig_input_ids, masked_input_ids, mask_idx).


def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (passage/question/entity inputs).
    :param batch: dict with input passage/question/true_entity/false_entities
                  and output lbl
    :param mode: name of the PET pattern to apply
    :return: original input ids, masked input ids, sampled labels, binary target
    '''
    list_passage = batch["input"]["passage"]
    list_question = batch["input"]["question"]
    list_true_entity = batch["input"]["true_entity"]
    list_false_entities = batch["input"]["false_entities"]
    list_lbl = batch["output"]["lbl"]

    bs = len(batch["input"]["passage"])

    # Sample a random label per example; the binary target marks whether
    # the sample matches the gold label.
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

    list_orig_input_ids = []
    list_masked_input_ids = []

    for b_idx, (p, q, te, fe, lbl) in enumerate(
            zip(list_passage, list_question, list_true_entity,
                list_false_entities, list_lbl)):
        txt_split_tuple = []

        # Longest candidate entity, in label tokens, across the true and
        # false entities (computed for bookkeeping; not used below).
        true_num_lbl_tok = self.get_lbl_num_lbl_tok(te)
        max_num_lbl_tok = true_num_lbl_tok
        for idx, wrong_enty in enumerate(fe):
            num_lbl_tok = self.get_lbl_num_lbl_tok(wrong_enty)
            if num_lbl_tok > max_num_lbl_tok:
                max_num_lbl_tok = num_lbl_tok

        txt_trim = -1
        pattern = self.pet_patterns[self._pet_names.index(mode)]

        for idx, txt_split in enumerate(pattern):
            # Fill the pattern and render the passage's @highlight markers
            # as dashes.
            txt_split_inp = txt_split.replace("[PASSAGE]", p).replace(
                "[QUESTION]", q + " [SEP]").replace("@highlight", "-")
            txt_split_tuple.append(txt_split_inp)

            # Trim the passage segment if the input is too long
            if "[PASSAGE]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl, tgt.to(device))

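# The reader above calls `get_lbl_num_lbl_tok` to measure each candidate
# entity in label tokens. A minimal sketch of what such a helper could look
# like, assuming a HuggingFace-style tokenizer and an assumed
# `config.max_num_lbl_tok` cap (neither is confirmed by this listing):

def get_lbl_num_lbl_tok(self, lbl):
    # Count the entity's subword tokens (no special tokens), capped at the
    # configured maximum label length.
    num_lbl_tok = len(self.tokenizer(lbl, add_special_tokens=False)["input_ids"])
    return min(num_lbl_tok, self.config.max_num_lbl_tok)
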
def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (paragraph/question/answer inputs).
    :param batch: dict with input question/passage/answer and output lbl
    :param mode: name of the PET pattern-verbalizer pair to apply
    :return: original input ids, masked input ids, sampled labels, binary target
    '''
    list_question = batch["input"]["question"]
    list_passage = batch["input"]["passage"]
    list_answer = batch["input"]["answer"]

    bs = len(batch["input"]["answer"])
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

    pattern, label = self.pet_pvps[self._pet_names.index(mode)]

    list_orig_input_ids = []
    list_masked_input_ids = []

    for b_idx, (p, q, a, lbl) in enumerate(
            zip(list_passage, list_question, list_answer, prep_lbl)):
        txt_split_tuple = []
        txt_trim = -1

        for idx, txt_split in enumerate(pattern):
            # Fill the pattern; the sampled label's verbalizer replaces [MASK].
            txt_split_inp = txt_split.replace("[PARAGRAPH]", p).replace(
                "[QUESTION]", q).replace("[ANSWER]", a).replace(
                "[MASK]", label[lbl])
            txt_split_tuple.append(txt_split_inp)

            # Trim the paragraph if the input is too long
            if "[PARAGRAPH]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl, tgt.to(device))

def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (hypothesis/premise inputs).
    :param batch: dict with input hypothesis/premise and output lbl
    :param mode: name of the PET pattern-verbalizer pair to apply
    :return: original input ids, masked input ids, sampled labels, binary target
    '''
    list_hypothesis = batch["input"]["hypothesis"]
    list_premise = batch["input"]["premise"]

    bs = len(batch["input"]["hypothesis"])
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

    pattern, label = self.pet_pvps[self._pet_names.index(mode)]

    list_orig_input_ids = []
    list_masked_input_ids = []

    for b_idx, (h, p, lbl) in enumerate(
            zip(list_hypothesis, list_premise, prep_lbl)):
        txt_split_tuple = []
        txt_trim = -1

        for idx, txt_split in enumerate(pattern):
            # Fill the pattern; the sampled label's verbalizer replaces [MASK].
            txt_split_inp = txt_split.replace("[HYPOTHESIS]", h).replace(
                "[PREMISE]", p).replace("[MASK]", label[lbl])
            txt_split_tuple.append(txt_split_inp)

            # Trim the premise if the input is too long
            if "[PREMISE]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl, tgt.to(device))

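# The training variants above all share one trick: a label index is sampled
# uniformly per example, and the binary target records whether the sample
# equals the gold label, so the model learns to judge whether the verbalizer
# filled into the pattern is correct. A self-contained demonstration of just
# that target construction (names and sizes here are illustrative):

import numpy as np
import torch

num_lbl = 2                                    # e.g. two verbalizers
gold = torch.tensor([0, 1, 1, 0])              # gold labels for a batch of 4
prep_lbl = np.random.randint(num_lbl, size=4)  # uniformly sampled labels
tgt = torch.from_numpy(prep_lbl).long() == gold
# tgt is True exactly where the sampled verbalizer matches the gold label.
print(prep_lbl, tgt)
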
def prepare_eval_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for evaluation (hypothesis/premise inputs).
    :param batch: dict with input hypothesis/premise
    :param mode: name of the PET pattern-verbalizer pair to apply
    :return: original input ids and masked input ids, one row per
             (example, label) pair
    '''
    list_hypothesis = batch["input"]["hypothesis"]
    list_premise = batch["input"]["premise"]

    list_input_ids = []
    list_masked_input_ids = []

    pattern, label = self.pet_pvps[self._pet_names.index(mode)]

    for b_idx, (h, p) in enumerate(zip(list_hypothesis, list_premise)):
        mask_idx = None

        # Build one filled pattern per candidate label, reusing the mask
        # positions chosen for the first label so every candidate is
        # masked at the same token positions.
        for l_idx, lbl in enumerate(label):
            txt_split_tuple = []
            txt_trim = -1

            for idx, txt_split in enumerate(pattern):
                txt_split_inp = txt_split.replace(
                    "[HYPOTHESIS]", h).replace("[PREMISE]", p).replace(
                    "[MASK]", lbl)
                txt_split_tuple.append(txt_split_inp)

                # Trim the premise if the input is too long
                if "[PREMISE]" in txt_split:
                    txt_trim = idx

            orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
                self.tokenizer, self.config, txt_split_tuple[0],
                txt_split_tuple[1], txt_split_tuple[2], txt_trim,
                mask_idx=mask_idx)
            list_input_ids.append(orig_input_ids)
            list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device))

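# Because the eval loop above appends one row per (example, label) pair,
# label-major within each example, downstream scoring can reshape per-row
# scores back to [batch, num_lbl] and take an argmax. A sketch under that
# assumption (function and variable names are illustrative):

import torch

def group_eval_rows(row_scores, bs, num_lbl):
    # Rows arrive as [ex0/lbl0, ex0/lbl1, ..., ex1/lbl0, ...], so a simple
    # reshape recovers a per-example score for each label.
    return row_scores.view(bs, num_lbl)

row_scores = torch.randn(8)  # e.g. bs=4 examples x num_lbl=2 labels
preds = group_eval_rows(row_scores, bs=4, num_lbl=2).argmax(dim=1)
print(preds)
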
def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (sentence-pair/word inputs).
    :param batch: dict with input sentence1/sentence2/word and output lbl
    :param mode: name of the PET pattern-verbalizer pair to apply
    :return: original input ids, masked input ids, sampled labels, binary target
    '''
    list_sentence1 = batch["input"]["sentence1"]
    list_sentence2 = batch["input"]["sentence2"]
    list_word = batch["input"]["word"]

    bs = len(batch["input"]["sentence1"])
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

    pattern, label = self.pet_pvps[self._pet_names.index(mode)]

    list_orig_input_ids = []
    list_masked_input_ids = []

    for b_idx, (s1, s2, w, lbl) in enumerate(
            zip(list_sentence1, list_sentence2, list_word, prep_lbl)):
        txt_split_tuple = []
        txt_trim = -1

        for idx, txt_split in enumerate(pattern):
            # Fill the pattern; the sampled label's verbalizer replaces [MASK].
            txt_split_inp = txt_split.replace("[SENTENCE1]", s1).replace(
                "[SENTENCE2]", s2).replace("[WORD]", w).replace(
                "[MASK]", label[lbl])
            txt_split_tuple.append(txt_split_inp)

            # Trim the first sentence if the input is too long
            if "[SENTENCE1]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl, tgt.to(device))

def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (text/pronoun/noun inputs).
    :param batch: dict with input text/pronoun/noun and output lbl
    :param mode: name of the PET pattern to apply
    :return: original input ids, masked input ids, None (no sampled
             labels), and a constant positive target
    '''
    list_text = batch["input"]["text"]
    list_pronoun = batch["input"]["pronoun"]
    list_noun = batch["input"]["noun"]
    list_lbl = batch["output"]["lbl"]

    list_orig_input_ids = []
    list_masked_input_ids = []

    # Only correct noun substitutions are trained on here, so the binary
    # target is the constant 1.
    tgt = torch.tensor([1.]).long()

    for b_idx, (t, p, n, lbl) in enumerate(
            zip(list_text, list_pronoun, list_noun, list_lbl)):
        txt_trim = -1
        pattern = self.pet_patterns[self._pet_names.index(mode)]
        txt_split_tuple = []

        for idx, txt_split in enumerate(pattern):
            # The candidate noun fills the [MASK] slot; [NNP] is the pronoun.
            txt_split_inp = txt_split.replace("[TEXT]", t).replace(
                "[NNP]", p).replace("[MASK]", n)
            txt_split_tuple.append(txt_split_inp)

            # Trim the text if the input is too long
            if "[TEXT]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            None, tgt.to(device))

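# The constant tgt above means this reader trains only on the correct noun
# substitution. If the loss were a binary cross-entropy over a per-example
# correctness score, the single-element target would simply be expanded
# across the batch; a speculative sketch (the repo's actual loss code is
# not shown in this listing):

import torch
import torch.nn.functional as F

per_example_logits = torch.randn(4)  # illustrative batch of 4 scores
tgt = torch.tensor([1.]).long()
loss = F.binary_cross_entropy_with_logits(
    per_example_logits, tgt.float().expand_as(per_example_logits))
print(loss)
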
def prepare_pet_mlm_batch(self, batch, mode="PET1"):
    '''
    Prepare a PET MLM batch for training (COPA premise/choice inputs).
    :param batch: dict with input premise/choice1/choice2/question and
                  output lbl
    :param mode: ignored; pattern 3 is always used for COPA
    :return: original input ids, masked input ids, sampled labels, binary target
    '''
    # Always use pattern 3 for COPA
    mode = "PET3"

    list_premise = batch["input"]["premise"]
    list_choice1 = batch["input"]["choice1"]
    list_choice2 = batch["input"]["choice2"]
    list_question = batch["input"]["question"]
    list_lbl = batch["output"]["lbl"]

    bs = len(batch["input"]["question"])
    prep_lbl = np.random.randint(self.num_lbl, size=bs)
    tgt = torch.from_numpy(prep_lbl).long() == batch["output"]["lbl"]

    list_orig_input_ids = []
    list_masked_input_ids = []

    for b_idx, (p, c1, c2, ques, lbl) in enumerate(
            zip(list_premise, list_choice1, list_choice2, list_question,
                list_lbl)):
        txt_split_tuple = []
        txt_trim = -1

        # Pick the pattern set matching the question type.
        if ques == "cause":
            pet_pvps = self.pet_patterns_cause
        elif ques == "effect":
            pet_pvps = self.pet_patterns_effect
        else:
            raise ValueError("Invalid question: %s" % ques)
        pattern = pet_pvps[self._pet_names.index(mode)]

        # The gold choice, with its trailing period stripped, fills [MASK].
        if lbl.item() == 0:
            lbl_choice = c1[:-1]
        elif lbl.item() == 1:
            lbl_choice = c2[:-1]
        else:
            raise ValueError("Invalid Lbl")

        for idx, txt_split in enumerate(pattern):
            txt_split_inp = txt_split.replace("[PREMISE]", p[:-1]).replace(
                "[CHOICE1]", c1[:-1]).replace("[CHOICE2]", c2[:-1]).replace(
                "[MASK]", lbl_choice)
            txt_split_tuple.append(txt_split_inp)

            # Trim the premise if the input is too long
            if "[PREMISE]" in txt_split:
                txt_trim = idx

        orig_input_ids, masked_input_ids, mask_idx = tokenize_pet_mlm_txt(
            self.tokenizer, self.config, txt_split_tuple[0],
            txt_split_tuple[1], txt_split_tuple[2], txt_trim)
        list_orig_input_ids.append(orig_input_ids)
        list_masked_input_ids.append(masked_input_ids)

    return (torch.tensor(list_orig_input_ids).to(device),
            torch.tensor(list_masked_input_ids).to(device),
            prep_lbl, tgt.to(device))

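# Every reader above delegates to `tokenize_pet_mlm_txt`. Its exact behavior
# is not shown in this listing; below is a simplified sketch of the contract
# it appears to satisfy, assuming a HuggingFace-style tokenizer and assumed
# config fields `max_text_length` and `mask_alpha` (the masking rate):

import random

def tokenize_pet_mlm_txt(tokenizer, config, txt1, txt2, txt3, txt_trim,
                         mask_idx=None):
    # Encode the three pattern segments, trim the segment at `txt_trim` so
    # the sequence fits max_text_length, then mask a random subset of
    # positions (or reuse `mask_idx`, as the eval reader does, so every
    # candidate label is masked at the same positions).
    segs = [tokenizer.encode(t, add_special_tokens=False)
            for t in (txt1, txt2, txt3)]
    overflow = sum(len(s) for s in segs) - (config.max_text_length - 2)
    if overflow > 0 and txt_trim >= 0:
        segs[txt_trim] = segs[txt_trim][:len(segs[txt_trim]) - overflow]

    orig_input_ids = ([tokenizer.cls_token_id] + segs[0] + segs[1] + segs[2]
                      + [tokenizer.sep_token_id])
    if mask_idx is None:
        n_mask = min(max(1, int(len(orig_input_ids) * config.mask_alpha)),
                     len(orig_input_ids) - 2)
        mask_idx = random.sample(range(1, len(orig_input_ids) - 1), n_mask)

    masked_input_ids = list(orig_input_ids)
    for i in mask_idx:
        masked_input_ids[i] = tokenizer.mask_token_id

    # Pad both sequences to a fixed length so the callers can stack them
    # with torch.tensor(...).
    pad = [tokenizer.pad_token_id] * (config.max_text_length - len(orig_input_ids))
    return orig_input_ids + pad, masked_input_ids + pad, mask_idx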