def batcher_gradient(batch_instances, labels, tokenizer, model, explainer, cuda_device):
    """Compute per-token relevance scores for a batch of instances.

    Every instance is encoded with ``encode_instance`` and the batch is padded
    with ``collate_tokens`` (pad id 1 for token ids, 0 for the attention mask).
    The output of ``model.roberta.embeddings`` is detached, marked as requiring
    gradients and passed to ``explainer.explain`` together with the index of
    each instance's label, looked up as ``MNLI_LABEL2IDX[labels[instance.id]]``.

    Args:
        batch_instances: iterable of instances exposing an ``id`` attribute.
        labels: mapping from instance id to a key of ``MNLI_LABEL2IDX``.
        tokenizer: forwarded to ``encode_instance`` and ``byte_pair_offsets``.
        model: object exposing ``roberta.embeddings``.
        explainer: object exposing ``explain(inp=..., ind=...)``.
        cuda_device: device the batch tensors are moved to.

    Returns:
        A list with one ``defaultdict(float)`` per instance. Keys are
        ``(sent_id, token_idx)`` with ``sent_id`` in ``{"sent1", "sent2"}``;
        values are the explanation scores summed over the span between two
        consecutive offsets returned by ``byte_pair_offsets``.
    """
    encoded = [encode_instance(inst, tokenizer) for inst in batch_instances]
    masks = [torch.ones_like(ids) for ids in encoded]
    input_ids = collate_tokens(encoded, pad_idx=1).to(cuda_device)
    attention_mask = collate_tokens(masks, pad_idx=0).to(cuda_device)

    # Work on a detached copy of the embeddings so it can be flagged as
    # requiring gradients on its own.
    inputs_embeds = model.roberta.embeddings(input_ids=input_ids).detach()
    gold_label_indices = torch.tensor(
        [MNLI_LABEL2IDX[labels[inst.id]] for inst in batch_instances],
        dtype=torch.long,
        device=cuda_device,
    )
    inputs_embeds.requires_grad = True

    explanation = explainer.explain(
        inp={"inputs_embeds": inputs_embeds, "attention_mask": attention_mask},
        ind=gold_label_indices,
    )

    ids_np = input_ids.cpu().numpy()
    # NOTE(review): ``.numpy()` only works on tensors that do not require
    # grad -- confirm the explainer returns a detached tensor.
    scores_np = explanation.cpu().numpy()

    relevances = []
    for b_idx in range(ids_np.shape[0]):
        sent1_offsets, sent2_offsets = byte_pair_offsets(ids_np[b_idx].tolist(), tokenizer)
        per_token = defaultdict(float)
        for sent_id, offsets in (("sent1", sent1_offsets), ("sent2", sent2_offsets)):
            # Consecutive offsets delimit the positions that belong to one token.
            for token_idx, (start, end) in enumerate(zip(offsets, offsets[1:])):
                per_token[(sent_id, token_idx)] = scores_np[b_idx][start:end].sum()
        relevances.append(per_token)
    return relevances
def predict(input_instances, model, tokenizer, cuda_device):
    """Return softmax probabilities for one instance or a list of instances.

    A single ``InputInstance`` is wrapped into a one-element list. Instances
    are encoded with ``encode_instance``, padded with ``collate_tokens``
    (pad id 1 for token ids, 0 for the attention mask) and moved to
    ``cuda_device``. The first element of the model output is treated as the
    logits and normalised with a softmax over the last dimension.
    """
    if isinstance(input_instances, InputInstance):
        instances = [input_instances]
    else:
        instances = input_instances

    token_ids = [encode_instance(inst, tokenizer) for inst in instances]
    masks = [torch.ones_like(ids) for ids in token_ids]

    batch_ids = collate_tokens(token_ids, pad_idx=1).to(cuda_device)
    batch_mask = collate_tokens(masks, pad_idx=0).to(cuda_device)

    outputs = model(input_ids=batch_ids, attention_mask=batch_mask)
    return F.softmax(outputs[0], dim=-1)
def merge(key, is_list=False, pad_idx=0):
    """Collate the ``key`` field of every sample with ``utils.collate_tokens``.

    ``samples`` is taken from the surrounding scope. When ``is_list`` is false
    the field of each sample is collated directly. Otherwise the field is
    treated as a list and one collated result is produced per position, with
    the number of positions taken from the first sample.
    """
    if not is_list:
        return utils.collate_tokens([s[key] for s in samples], pad_idx=pad_idx)

    num_positions = len(samples[0][key])
    return [
        utils.collate_tokens([s[key][pos] for s in samples], pad_idx=pad_idx)
        for pos in range(num_positions)
    ]
def collater(self, samples):
    """Merge a list of samples into one padded mini-batch dictionary.

    Args:
        samples: list of dicts. Each one is read for the keys ``id``, ``op``,
            ``passage`` (only when ``self.encode_passage`` is truthy),
            ``tgt``, ``tgt_sent_ids``, ``sentence_type``, ``phrase_bank``,
            ``phrase_bank_sel_ind`` and ``phrase_bank_sel``. The samples are
            not modified.

    Returns:
        dict with the following entries:

        * ``id``: list of sample ids.
        * ``enc_src``: ``op`` (plus ``passage`` when enabled) collated with
          ``utils.collate_tokens`` using ``self.vocab.pad_idx``.
        * ``enc_src_len``: LongTensor of unpadded source lengths.
        * ``dec_in`` / ``dec_out``: ``tgt[:-1]`` / ``tgt[1:]`` collated with
          ``self.vocab.pad_idx``.
        * ``dec_in_len``: LongTensor of ``len(tgt) - 1`` per sample.
        * ``dec_mask``: ``[batch, max_tgt_len - 1]`` LongTensor, 1 on real
          decoder positions and 0 on padding.
        * ``dec_sent_id``: ``tgt_sent_ids[:-1]`` collated with pad value 0.
        * ``sent_types``: ``[batch, max_sent_num + 1]`` LongTensor filled
          with 2 and overwritten with each sample's ``sentence_type``.
        * ``ph_bank_tensor``: ``[batch, max_ph_num, max_ph_len]`` zero-padded
          phrase bank.
        * ``ph_bank_len_tensor``: LongTensor with the number of phrases in
          each sample's bank.
        * ``ph_sel_ind_tensor``: ``[batch, max_sent, max_ph_num]`` zero-padded
          copy of ``phrase_bank_sel_ind``.
        * ``ph_sel_tensor``: ``[batch, max_sent, max_ph_per_sent,
          max_sel_ph_len]`` zero-padded copy of ``phrase_bank_sel``.

    Raises:
        ValueError: if ``samples`` is empty (``max()`` of an empty sequence).
    """
    batch = {}
    batch['id'] = [s['id'] for s in samples]
    batch_size = len(batch['id'])

    # src (op + passage)
    src_tensors = []
    src_len = []
    for sample in samples:
        # Copy first: extending sample['op'] itself (as ``+=`` on a list does)
        # would append the passage again every time the sample is collated.
        cur_src = list(sample['op'])
        if self.encode_passage:
            cur_src.extend(sample['passage'])
        src_tensors.append(torch.LongTensor(cur_src))
        src_len.append(len(cur_src))

    batch['enc_src'] = utils.collate_tokens(values=src_tensors,
                                            pad_idx=self.vocab.pad_idx)
    batch['enc_src_len'] = torch.LongTensor(src_len)

    # target: dec_in and dec_out, plus sentence types
    max_tgt_len = max(len(s['tgt']) for s in samples)
    sent_nums = [len(s['sentence_type']) for s in samples]
    max_sent_nums = max(sent_nums)

    dec_inputs = []
    dec_targets = []
    dec_lens = []
    dec_sent_ids = []
    dec_mask = torch.zeros([batch_size, max_tgt_len - 1], dtype=torch.long)

    # One extra column for the EOS sentence type; unfilled slots keep the
    # fill value 2 (presumably the EOS/padding type id -- confirm with the
    # model code).
    sent_types = torch.full([batch_size, max_sent_nums + 1],
                            fill_value=2,
                            dtype=torch.long)

    for ix, item in enumerate(samples):
        cur_tgt = item['tgt']
        # Decoder input drops the last token, decoder target drops the first.
        dec_inputs.append(torch.LongTensor(cur_tgt[:-1]))
        dec_targets.append(torch.LongTensor(cur_tgt[1:]))
        dec_lens.append(len(cur_tgt) - 1)
        dec_mask[ix][:len(cur_tgt) - 1] = 1
        dec_sent_ids.append(torch.LongTensor(item['tgt_sent_ids'][:-1]))
        sent_types[ix][:sent_nums[ix]] = torch.LongTensor(
            item["sentence_type"])

    batch['dec_in'] = utils.collate_tokens(values=dec_inputs,
                                           pad_idx=self.vocab.pad_idx)
    batch['dec_out'] = utils.collate_tokens(values=dec_targets,
                                            pad_idx=self.vocab.pad_idx)
    batch['dec_in_len'] = torch.LongTensor(dec_lens)
    batch['dec_mask'] = dec_mask
    batch['dec_sent_id'] = utils.collate_tokens(values=dec_sent_ids,
                                                pad_idx=0)
    batch["sent_types"] = sent_types

    # Pad the phrase bank: each sample's bank is a 2D list
    # (phrase x token), so the batch becomes a 3D tensor.
    phrase_banks = [s['phrase_bank'] for s in samples]
    max_ph_num = max(len(bank) for bank in phrase_banks)
    # ``default=0`` keeps an empty bank from raising on max() of nothing.
    max_ph_len = max(
        max((len(ph) for ph in bank), default=0) for bank in phrase_banks)

    phrase_bank_tensor = torch.zeros([batch_size, max_ph_num, max_ph_len],
                                     dtype=torch.long)
    for ix, cur_ph_bank in enumerate(phrase_banks):
        for j, ph in enumerate(cur_ph_bank):
            phrase_bank_tensor[ix][j][:len(ph)] = torch.LongTensor(list(ph))

    batch['ph_bank_tensor'] = phrase_bank_tensor
    batch['ph_bank_len_tensor'] = torch.LongTensor(
        [len(bank) for bank in phrase_banks])

    # Padded tensor for the phrase selection indicators.
    # sample['phrase_bank_sel_ind'] is a 2D list [sent_id, phrase_id], so the
    # batch is 3D: [sample_id, sent_id, phrase_id].
    phrase_sel_ind = [s['phrase_bank_sel_ind'] for s in samples]
    max_sent_num = max(len(sel) for sel in phrase_sel_ind)

    phrase_sel_ind_tensor = torch.zeros(
        [batch_size, max_sent_num, max_ph_num], dtype=torch.long)
    for ix, cur_sel in enumerate(phrase_sel_ind):
        for sent_ix, sent_sel in enumerate(cur_sel):
            phrase_sel_ind_tensor[ix, sent_ix, :len(sent_sel)] = \
                torch.LongTensor(sent_sel)
    batch["ph_sel_ind_tensor"] = phrase_sel_ind_tensor

    # Selected phrases: each sample is a 3D list (sent x phrase x token), so
    # the batch is 4D: batch_size x sent_num x ph_num x ph_len.
    phrase_sel = [s['phrase_bank_sel'] for s in samples]
    max_ph_per_sent = 0
    max_sel_ph_len = 0
    for sample_sel in phrase_sel:
        for sent in sample_sel:
            max_ph_per_sent = max(max_ph_per_sent, len(sent))
            for ph in sent:
                max_sel_ph_len = max(max_sel_ph_len, len(ph))

    phrase_sel_tensor = torch.zeros(
        [batch_size, max_sent_num, max_ph_per_sent, max_sel_ph_len],
        dtype=torch.long)
    for six, sample_sel in enumerate(phrase_sel):
        for sent_ix, sent in enumerate(sample_sel):
            for ph_ix, ph in enumerate(sent):
                phrase_sel_tensor[six, sent_ix, ph_ix, :len(ph)] = \
                    torch.LongTensor(ph)

    batch['ph_sel_tensor'] = phrase_sel_tensor
    return batch