Example #1
import torch
from collections import defaultdict

# encode_instance, collate_tokens, byte_pair_offsets, and MNLI_LABEL2IDX are
# project-specific helpers assumed to be in scope.
def batcher_gradient(batch_instances, labels, tokenizer, model, explainer, cuda_device):
    # Tokenize each instance; every real token gets attention-mask value 1.
    input_ids = [encode_instance(instance, tokenizer) for instance in batch_instances]
    attention_mask = [torch.ones_like(t) for t in input_ids]

    # Right-pad into rectangular batch tensors (RoBERTa's pad token id is 1).
    input_ids = collate_tokens(input_ids, pad_idx=1).to(cuda_device)
    attention_mask = collate_tokens(attention_mask, pad_idx=0).to(cuda_device)

    # Look up the input embeddings and detach them from the encoder graph so
    # they can serve as a leaf tensor for gradient-based attribution.
    inputs_embeds = model.roberta.embeddings(input_ids=input_ids).detach()

    # Map each instance's gold label to its class index.
    true_label_idx_list = [MNLI_LABEL2IDX[labels[instance.id]] for instance in batch_instances]
    true_label_idx_tensor = torch.tensor(true_label_idx_list, dtype=torch.long, device=cuda_device)

    # Attribute the gold-class score to the input embeddings.
    inputs_embeds.requires_grad = True
    expl = explainer.explain(inp={"inputs_embeds": inputs_embeds, "attention_mask": attention_mask},
                             ind=true_label_idx_tensor)

    input_ids_np = input_ids.cpu().numpy()
    expl_np = expl.detach().cpu().numpy()  # detach in case the explainer returns a graph-attached tensor

    # Aggregate byte-pair-level relevance into per-word scores for each sentence.
    relevances = []
    for b_idx in range(input_ids_np.shape[0]):
        sent1_offsets, sent2_offsets = byte_pair_offsets(input_ids_np[b_idx].tolist(), tokenizer)

        relevance_dict = defaultdict(float)
        for offsets, sent_id in zip([sent1_offsets, sent2_offsets], ["sent1", "sent2"]):
            for token_idx, (token_start, token_end) in enumerate(zip(offsets, offsets[1:])):
                relevance = expl_np[b_idx][token_start:token_end].sum()
                relevance_dict[(sent_id, token_idx)] = relevance
        relevances.append(relevance_dict)

    return relevances
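
The final loop above sums byte-pair-level attributions between consecutive offset boundaries to get per-word relevance. A minimal, self-contained sketch of that aggregation step, with made-up offsets and relevance values (`byte_pair_offsets` itself is project-specific):

import numpy as np

# One row of expl_np above; relevance values are made up for illustration.
expl_row = np.array([0.1, 0.3, 0.2, 0.05, 0.4])
# Word i spans byte pairs [offsets[i], offsets[i + 1]); boundaries are illustrative.
offsets = [0, 2, 3, 5]

word_relevance = [expl_row[start:end].sum()
                  for start, end in zip(offsets, offsets[1:])]
print(word_relevance)  # approximately [0.4, 0.2, 0.45]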
Example #2
import torch
import torch.nn.functional as F

# InputInstance, encode_instance, and collate_tokens are the same
# project-specific helpers as in Example #1.
def predict(input_instances, model, tokenizer, cuda_device):
    # Accept a single instance as well as a list of instances.
    if isinstance(input_instances, InputInstance):
        input_instances = [input_instances]

    input_ids = [encode_instance(instance, tokenizer) for instance in input_instances]
    attention_mask = [torch.ones_like(t) for t in input_ids]

    # Right-pad into rectangular batch tensors, as in Example #1.
    input_ids = collate_tokens(input_ids, pad_idx=1).to(cuda_device)
    attention_mask = collate_tokens(attention_mask, pad_idx=0).to(cuda_device)

    # Take the logits from the model output tuple and normalize to probabilities.
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    return F.softmax(logits, dim=-1)
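
The last line normalizes each row of logits into a probability distribution over the classes; callers can then take an argmax for hard predictions. A minimal illustration of that step:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
probs = F.softmax(logits, dim=-1)
print(probs.sum(dim=-1))     # each row sums to 1
print(probs.argmax(dim=-1))  # tensor([0, 2]) -- predicted class per instance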
Example #3
def merge(key, is_list=False, pad_idx=0):
    # `samples` is captured from the enclosing collater's scope.
    if is_list:
        # For list-valued fields, pad each list position across the batch separately.
        res = []
        for i in range(len(samples[0][key])):
            res.append(
                utils.collate_tokens(
                    [s[key][i] for s in samples],
                    pad_idx=pad_idx,
                ))
        return res
    else:
        return utils.collate_tokens([s[key] for s in samples],
                                    pad_idx=pad_idx)
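
All four examples lean on `collate_tokens` to right-pad variable-length 1D tensors into a single 2D batch tensor. A minimal sketch of that padding behavior, modeled on fairseq's `data_utils.collate_tokens` but simplified to right-padding only:

import torch

def collate_tokens_sketch(values, pad_idx):
    # right-pad each 1D tensor to the longest length in the batch
    max_len = max(v.size(0) for v in values)
    out = values[0].new_full((len(values), max_len), pad_idx)
    for i, v in enumerate(values):
        out[i, :v.size(0)] = v
    return out

batch = collate_tokens_sketch([torch.tensor([5, 6]), torch.tensor([7, 8, 9])], pad_idx=1)
print(batch)  # tensor([[5, 6, 1], [7, 8, 9]])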
Example #4
    def collater(self, samples):
        # Assemble a list of samples into a single batch of padded tensors.
        batch = dict()
        batch['id'] = [s['id'] for s in samples]
        batch_size = len(batch['id'])

        # src (op + passage)
        src_tensors = []
        src_len = []
        for ix in range(batch_size):
            # copy so that += below does not mutate the sample's 'op' list in place
            cur_src = list(samples[ix]['op'])
            if self.encode_passage:
                cur_src += samples[ix]['passage']
            src_tensors.append(torch.LongTensor(cur_src))
            src_len.append(len(cur_src))
        batch['enc_src'] = utils.collate_tokens(values=src_tensors,
                                                pad_idx=self.vocab.pad_idx)
        batch['enc_src_len'] = torch.LongTensor(src_len)

        # target: dec_in and dec_out
        # sentence_types
        tgt_len = [len(s['tgt']) for s in samples]
        max_tgt_len = max(tgt_len)
        sent_nums = [len(s['sentence_type']) for s in samples]
        max_sent_nums = max(sent_nums)

        dec_inputs = []
        dec_targets = []
        dec_lens = []
        dec_sent_ids = []
        dec_mask = torch.zeros([batch_size, max_tgt_len - 1], dtype=torch.long)

        # pad with sentence type 2 (EOS); per-sample types are written in below
        sent_types = torch.full([batch_size, max_sent_nums + 1],
                                fill_value=2,
                                dtype=torch.long)

        for ix, item in enumerate(samples):
            cur_tgt = item['tgt']
            dec_inputs.append(torch.LongTensor(cur_tgt[:-1]))
            dec_targets.append(torch.LongTensor(cur_tgt[1:]))
            dec_lens.append(len(cur_tgt) - 1)
            dec_mask[ix][:len(cur_tgt) - 1] = 1

            dec_sent_ids.append(torch.LongTensor(item['tgt_sent_ids'][:-1]))
            sent_types[ix][:sent_nums[ix]] = torch.LongTensor(
                item["sentence_type"])

        batch['dec_in'] = utils.collate_tokens(values=dec_inputs,
                                               pad_idx=self.vocab.pad_idx)
        batch['dec_out'] = utils.collate_tokens(values=dec_targets,
                                                pad_idx=self.vocab.pad_idx)
        batch['dec_in_len'] = torch.LongTensor(dec_lens)
        batch['dec_mask'] = dec_mask
        batch['dec_sent_id'] = utils.collate_tokens(values=dec_sent_ids,
                                                    pad_idx=0)
        batch["sent_types"] = sent_types

        # pad the phrase bank: each sample's phrase bank is a 2D list,
        # so the padded batch is a 3D tensor
        phrase_banks = [s['phrase_bank'] for s in samples]
        phrase_sizes = [[len(ph) for ph in s] for s in phrase_banks]

        max_ph_num = max([len(x) for x in phrase_sizes])
        max_ph_len = max(
            [max([len(p) for p in bank]) for bank in phrase_banks])
        phrase_bank_tensor = torch.zeros([batch_size, max_ph_num, max_ph_len],
                                         dtype=torch.long)
        for ix in range(batch_size):
            cur_ph_bank = phrase_banks[ix]
            for j, ph in enumerate(cur_ph_bank):
                phrase_bank_tensor[ix][j][:len(ph)] = torch.LongTensor(
                    list(ph))
        batch['ph_bank_tensor'] = phrase_bank_tensor
        batch['ph_bank_len_tensor'] = torch.LongTensor(
            [len(x) for x in phrase_banks])

        # create padded tensor for phrase selection indicators;
        # across the batch, 'phrase_bank_sel_ind' forms a 3D list [sample_id, sent_id, phrase_id]
        phrase_sel = [s['phrase_bank_sel_ind'] for s in samples]
        sent_num = [len(x) for x in phrase_sel]
        phrase_sel_ind_tensor = torch.zeros(
            [batch_size, max(sent_num), max_ph_num], dtype=torch.long)
        for ix in range(batch_size):
            cur_sel = phrase_sel[ix]
            for sent_ix, sent_sel in enumerate(cur_sel):
                phrase_sel_ind_tensor[
                    ix, sent_ix, :len(sent_sel)] = torch.LongTensor(sent_sel)
        batch["ph_sel_ind_tensor"] = phrase_sel_ind_tensor

        # 4D list across the batch: batch_size x sent_num x ph_num x ph_len
        phrase_sel = [s['phrase_bank_sel'] for s in samples]
        # recompute maxima over the selected phrases (shadows the bank-level max_ph_len above)
        max_ph_len = 0
        max_ph_per_sent = 0
        for sample in phrase_sel:
            ph_lens = [[len(ph) for ph in sent] for sent in sample]
            max_ph_per_sent = max(max_ph_per_sent,
                                  max([len(item) for item in ph_lens]))
            max_ph_len = max(
                max_ph_len,
                max([max(item) if len(item) > 0 else 0 for item in ph_lens]))

        phrase_sel_tensor = torch.zeros(
            [batch_size,
             max(sent_num), max_ph_per_sent, max_ph_len],
            dtype=torch.long)
        for six, sample in enumerate(phrase_sel):
            for sent_ix, sent in enumerate(sample):
                for ph_ix, ph in enumerate(sent):
                    phrase_sel_tensor[six, sent_ix,
                                      ph_ix, :len(ph)] = torch.LongTensor(ph)
        batch['ph_sel_tensor'] = phrase_sel_tensor

        return batch
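
The phrase-bank and phrase-selection blocks above all follow the same recipe: measure the maxima of each ragged dimension, allocate a zero-initialized tensor, and copy every item into its slice. The same pattern in isolation, with made-up data:

import torch

# a batch of ragged phrase banks (token ids are illustrative)
phrase_banks = [[[4, 5], [6]], [[7, 8, 9]]]
batch_size = len(phrase_banks)
max_ph_num = max(len(bank) for bank in phrase_banks)
max_ph_len = max(len(ph) for bank in phrase_banks for ph in bank)

out = torch.zeros(batch_size, max_ph_num, max_ph_len, dtype=torch.long)
for i, bank in enumerate(phrase_banks):
    for j, ph in enumerate(bank):
        out[i, j, :len(ph)] = torch.tensor(ph)
print(out.shape)  # torch.Size([2, 2, 3])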