Example #1
    def forward(self, expl, s1_embed, s2_embed, mode, classif_lbl):
        # expl: Variable(seqlen x bsize x worddim)
        # s1/2_embed: Variable(bsize x sent_dim)

        assert mode in ['forloop', 'teacher'], mode

        batch_size = expl.size(1)
        assert_sizes(s1_embed, 2, [batch_size, self.sent_dim])
        assert_sizes(s2_embed, 2, [batch_size, self.sent_dim])
        assert_sizes(expl, 3, [expl.size(0), batch_size, self.word_emb_dim])

        # Build the decoder context; only_diff_prod takes precedence over
        # use_diff_prod_sent_embed, matching the original flag priority.
        if self.only_diff_prod:
            context = torch.cat(
                [torch.abs(s1_embed - s2_embed), s1_embed * s2_embed],
                1).unsqueeze(0)
        elif self.use_diff_prod_sent_embed:
            context = torch.cat([
                s1_embed, s2_embed,
                torch.abs(s1_embed - s2_embed), s1_embed * s2_embed
            ], 1).unsqueeze(0)
        else:
            context = torch.cat([s1_embed, s2_embed], 1).unsqueeze(0)

        assert_sizes(
            context, 3,
            [1, batch_size, self.context_mutiply_coef * self.sent_dim])

        # init decoder
        context_init = torch.cat([s1_embed, s2_embed], 1).unsqueeze(0)
        if self.use_init:
            if 2 * self.sent_dim != self.dec_rnn_dim:
                init_0 = self.proj_init(
                    context_init.expand(self.n_layers_dec, batch_size,
                                        2 * self.sent_dim))
            else:
                init_0 = context_init
        else:
            init_0 = Variable(
                torch.zeros(self.n_layers_dec, batch_size,
                            self.dec_rnn_dim)).cuda()

        init_state = init_0
        if self.decoder_type == 'lstm':
            init_state = (init_0, init_0)

        self.decoder_rnn.flatten_parameters()

        if mode == "teacher":
            input_dec = torch.cat([
                expl,
                context.expand(expl.size(0), batch_size,
                               self.context_mutiply_coef * self.sent_dim)
            ], 2)
            # functional dropout respects self.training; a freshly built
            # nn.Dropout module stays in training mode even during eval
            input_dec = self.proj_inp_dec(
                nn.functional.dropout(input_dec, p=self.dpout_dec,
                                      training=self.training))

            out, _ = self.decoder_rnn(input_dec, init_state)
            dp_out = nn.functional.dropout(out, p=self.dpout_dec,
                                           training=self.training)

            if not self.use_vocab_proj:
                return self.vocab_layer(dp_out)
            return self.vocab_layer(self.vocab_proj(dp_out))

        else:
            assert classif_lbl is not None
            assert_sizes(classif_lbl, 1, [batch_size])
            pred_expls = []
            finished = []
            for i in range(batch_size):
                pred_expls.append("")
                finished.append(False)

            dec_inp_t = torch.cat([expl[0, :, :].unsqueeze(0), context], 2)
            dec_inp_t = self.proj_inp_dec(dec_inp_t)

            ht = init_state
            t = 0
            while t < self.max_T_decoder and not array_all_true(finished):
                t += 1
                word_embed = torch.zeros(1, batch_size, self.word_emb_dim)
                assert_sizes(dec_inp_t, 3, [1, batch_size, self.inp_dec_dim])
                dec_out_t, ht = self.decoder_rnn(dec_inp_t, ht)
                assert_sizes(dec_out_t, 3, [1, batch_size, self.dec_rnn_dim])
                if self.use_vocab_proj:
                    out_t_proj = self.vocab_proj(dec_out_t)
                    out_t = self.vocab_layer(out_t_proj).data
                else:
                    out_t = self.vocab_layer(
                        dec_out_t
                    ).data  # TODO: Use torch.stack with variables instead
                assert_sizes(out_t, 3, [1, batch_size, self.n_vocab])
                i_t = torch.max(out_t, 2)[1]
                assert_sizes(i_t, 2, [1, batch_size])
                pred_words = get_keys_from_vals(
                    i_t, self.word_index
                )  # list of batch_size predicted words at this timestep
                assert len(pred_words) == batch_size, "pred_words " + str(
                    len(pred_words)) + " batch_size " + str(batch_size)
                for i in range(batch_size):
                    if pred_words[i] == '</s>':
                        finished[i] = True
                    if not finished[i]:
                        pred_expls[i] += " " + pred_words[i]
                    if t > 1:
                        #print "self.word_vec[pred_words[i]]", type(self.word_vec[pred_words[i]])
                        word_embed[0, i] = torch.from_numpy(
                            self.word_vec[pred_words[i]])
                        #print "type(word_embed[0, i]) ", word_embed[0, i]
                        #assert False
                    else:
                        # first step: feed the label predicted by the classifier
                        classif_label = get_key_from_val(
                            classif_lbl[i], NLI_DIC_LABELS)
                        assert classif_label in [
                            'entailment', 'contradiction', 'neutral'
                        ], classif_label
                        word_embed[0, i] = torch.from_numpy(
                            self.word_vec[classif_label])
                word_embed = Variable(word_embed.cuda())
                assert_sizes(word_embed, 3, [1, batch_size, self.word_emb_dim])
                dec_inp_t = self.proj_inp_dec(
                    torch.cat([word_embed, context], 2))
            return pred_expls
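
Both examples call a few small helpers that are not part of the listing
(assert_sizes, array_all_true, get_key_from_val, get_keys_from_vals). Below
is a minimal sketch of what they plausibly do, inferred from how they are
used above; the names match the call sites, but the bodies are assumptions,
not the repository's actual code:

# Hypothetical reconstructions of the unshown helpers, inferred from usage.
def assert_sizes(t, n_dims, sizes):
    # t must be an n_dims-dimensional tensor with exactly the given shape
    assert t.dim() == n_dims, str(t.dim())
    assert list(t.size()) == sizes, str(list(t.size()))


def array_all_true(arr):
    # True once every flag in a list of booleans is True
    return all(arr)


def get_key_from_val(val, dic):
    # inverse lookup: the key whose value equals val (values are unique)
    for k, v in dic.items():
        if v == val:
            return k
    return None


def get_keys_from_vals(vals, dic):
    # vals: 1 x batch_size tensor of vocabulary indices; returns the
    # corresponding words as a plain list of length batch_size
    return [get_key_from_val(v, dic) for v in vals.view(-1).tolist()]
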
Example #2
    def forward(self, expl, enc_out_s1, enc_out_s2, s1_embed, s2_embed, mode,
                visualize):
        # expl: Variable(seqlen x bsize x worddim)
        # s1/2_embed: Variable(bsize x sent_dim)

        assert mode in ['forloop', 'teacher'], mode

        current_T_dec = expl.size(0)
        batch_size = expl.size(1)
        assert_sizes(s1_embed, 2, [batch_size, self.sent_dim])
        assert_sizes(s2_embed, 2, [batch_size, self.sent_dim])
        assert_sizes(expl, 3, [current_T_dec, batch_size, self.word_emb_dim])
        assert_sizes(enc_out_s1, 3,
                     [self.max_T_encoder, batch_size, 2 * self.enc_rnn_dim])
        assert_sizes(enc_out_s2, 3,
                     [self.max_T_encoder, batch_size, 2 * self.enc_rnn_dim])

        context = torch.cat([
            s1_embed, s2_embed,
            torch.abs(s1_embed - s2_embed), s1_embed * s2_embed
        ], 1).unsqueeze(0)
        assert_sizes(context, 3, [1, batch_size, 4 * self.sent_dim])

        # init decoder
        if self.use_init:
            init_0 = self.context_proj(context).expand(self.n_layers_dec,
                                                       batch_size,
                                                       self.dec_rnn_dim)
        else:
            init_0 = Variable(
                torch.zeros(self.n_layers_dec, batch_size,
                            self.dec_rnn_dim)).cuda()

        init_state = init_0
        if self.decoder_type == 'lstm':
            init_state = (init_0, init_0)

        self.decoder_rnn.flatten_parameters()

        out_expl = None
        state_t = init_state
        context = self.context_proj(context)
        if mode == "teacher":
            for t_dec in range(current_T_dec):
                # attention over premise
                context1 = self.att_context_proj1(context).permute(1, 0, 2)
                assert_sizes(context1, 3, [batch_size, 1, self.att_hid_dim])

                inp_att_1 = self.att_ht_proj1(enc_out_s1).transpose(
                    1, 0).transpose(2, 1)
                assert_sizes(
                    inp_att_1, 3,
                    [batch_size, self.att_hid_dim, self.max_T_encoder])

                dot_prod_att_1 = torch.bmm(context1, inp_att_1)
                assert_sizes(dot_prod_att_1, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_weights_1 = self.softmax_att(dot_prod_att_1)
                assert_sizes(att_weights_1, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_applied_1 = torch.bmm(
                    att_weights_1,
                    self.att_ht_before_weighting_proj1(enc_out_s1).permute(
                        1, 0, 2))
                assert_sizes(att_applied_1, 3,
                             [batch_size, 1, self.att_hid_dim])

                att_applied_perm_1 = att_applied_1.permute(1, 0, 2)
                assert_sizes(att_applied_perm_1, 3,
                             [1, batch_size, self.att_hid_dim])

                # attention over hypothesis
                context2 = self.att_context_proj2(context).permute(1, 0, 2)
                assert_sizes(context2, 3, [batch_size, 1, self.att_hid_dim])

                inp_att_2 = self.att_ht_proj2(enc_out_s2).transpose(
                    1, 0).transpose(2, 1)
                assert_sizes(
                    inp_att_2, 3,
                    [batch_size, self.att_hid_dim, self.max_T_encoder])

                dot_prod_att_2 = torch.bmm(context2, inp_att_2)
                assert_sizes(dot_prod_att_2, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_weights_2 = self.softmax_att(dot_prod_att_2)
                assert_sizes(att_weights_2, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_applied_2 = torch.bmm(
                    att_weights_2,
                    self.att_ht_before_weighting_proj2(enc_out_s2).permute(
                        1, 0, 2))
                assert_sizes(att_applied_2, 3,
                             [batch_size, 1, self.att_hid_dim])

                att_applied_perm_2 = att_applied_2.permute(1, 0, 2)
                assert_sizes(att_applied_perm_2, 3,
                             [1, batch_size, self.att_hid_dim])

                input_dec = torch.cat([
                    expl[t_dec].unsqueeze(0), att_applied_perm_1,
                    att_applied_perm_2
                ], 2)
                # functional dropout respects self.training; a fresh
                # nn.Dropout module is always in training mode
                input_dec = nn.functional.dropout(
                    self.proj_inp_dec(input_dec), p=self.dpout_dec,
                    training=self.training)

                out_dec, state_t = self.decoder_rnn(input_dec, state_t)
                assert_sizes(out_dec, 3, [1, batch_size, self.dec_rnn_dim])
                if self.decoder_type == 'lstm':
                    context = state_t[0]
                else:
                    context = state_t

                if out_expl is None:
                    out_expl = out_dec
                else:
                    out_expl = torch.cat([out_expl, out_dec], 0)

            out_expl = self.vocab_layer(out_expl)
            assert_sizes(out_expl, 3,
                         [current_T_dec, batch_size, self.n_vocab])
            return out_expl

        else:
            pred_expls = []
            finished = []
            for i in range(batch_size):
                pred_expls.append("")
                finished.append(False)

            t_dec = 0
            word_t = expl[0].unsqueeze(0)
            while t_dec < self.max_T_decoder and not array_all_true(finished):
                #print "\n\n\n t: ", t_dec

                assert_sizes(word_t, 3, [1, batch_size, self.word_emb_dim])
                word_embed = torch.zeros(1, batch_size, self.word_emb_dim)

                # attention over premise
                context1 = self.att_context_proj1(context).permute(1, 0, 2)
                assert_sizes(context1, 3, [batch_size, 1, self.att_hid_dim])

                inp_att_1 = self.att_ht_proj1(enc_out_s1).transpose(
                    1, 0).transpose(2, 1)
                assert_sizes(
                    inp_att_1, 3,
                    [batch_size, self.att_hid_dim, self.max_T_encoder])

                dot_prod_att_1 = torch.bmm(context1, inp_att_1)
                assert_sizes(dot_prod_att_1, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_weights_1 = self.softmax_att(dot_prod_att_1)
                assert_sizes(att_weights_1, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_applied_1 = torch.bmm(
                    att_weights_1,
                    self.att_ht_before_weighting_proj1(enc_out_s1).permute(
                        1, 0, 2))
                assert_sizes(att_applied_1, 3,
                             [batch_size, 1, self.att_hid_dim])

                att_applied_perm_1 = att_applied_1.permute(1, 0, 2)
                assert_sizes(att_applied_perm_1, 3,
                             [1, batch_size, self.att_hid_dim])

                # attention over hypothesis
                context2 = self.att_context_proj2(context).permute(1, 0, 2)
                assert_sizes(context2, 3, [batch_size, 1, self.att_hid_dim])

                inp_att_2 = self.att_ht_proj2(enc_out_s2).transpose(
                    1, 0).transpose(2, 1)
                assert_sizes(
                    inp_att_2, 3,
                    [batch_size, self.att_hid_dim, self.max_T_encoder])

                dot_prod_att_2 = torch.bmm(context2, inp_att_2)
                assert_sizes(dot_prod_att_2, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_weights_2 = self.softmax_att(dot_prod_att_2)
                assert_sizes(att_weights_2, 3,
                             [batch_size, 1, self.max_T_encoder])

                att_applied_2 = torch.bmm(
                    att_weights_2,
                    self.att_ht_before_weighting_proj2(enc_out_s2).permute(
                        1, 0, 2))
                assert_sizes(att_applied_2, 3,
                             [batch_size, 1, self.att_hid_dim])

                att_applied_perm_2 = att_applied_2.permute(1, 0, 2)
                assert_sizes(att_applied_perm_2, 3,
                             [1, batch_size, self.att_hid_dim])

                input_dec = torch.cat(
                    [word_t, att_applied_perm_1, att_applied_perm_2], 2)
                input_dec = self.proj_inp_dec(input_dec)

                #print "att_weights_1[0] ", att_weights_1[0]
                #print "att_weights_2[0] ", att_weights_2[0]

                # get one visualization from the current batch
                if visualize:
                    if t_dec == 0:
                        weights_1 = att_weights_1[0]
                        weights_2 = att_weights_2[0]
                    else:
                        weights_1 = torch.cat([weights_1, att_weights_1[0]], 0)
                        weights_2 = torch.cat([weights_2, att_weights_2[0]], 0)

                # sanity check: each attention distribution sums to 1
                for ii in range(batch_size):
                    assert abs(att_weights_1[ii].data.sum() - 1) < 1e-5, str(
                        att_weights_1[ii].data.sum())
                    assert abs(att_weights_2[ii].data.sum() - 1) < 1e-5, str(
                        att_weights_2[ii].data.sum())

                out_t, state_t = self.decoder_rnn(input_dec, state_t)
                assert_sizes(out_t, 3, [1, batch_size, self.dec_rnn_dim])
                out_t = self.vocab_layer(out_t)
                if self.decoder_type == 'lstm':
                    context = state_t[0]
                else:
                    context = state_t

                i_t = torch.max(out_t, 2)[1].data
                assert_sizes(i_t, 2, [1, batch_size])
                pred_words = get_keys_from_vals(
                    i_t, self.word_index
                )  # list of batch_size predicted words at this timestep
                assert len(pred_words) == batch_size, "pred_words " + str(
                    len(pred_words)) + " batch_size " + str(batch_size)
                for i in range(batch_size):
                    if pred_words[i] == '</s>':
                        finished[i] = True
                    if not finished[i]:
                        pred_expls[i] += " " + pred_words[i]
                    word_embed[0, i] = torch.from_numpy(
                        self.word_vec[pred_words[i]])
                word_t = Variable(word_embed.cuda())

                t_dec += 1

            if visualize:
                assert weights_1.dim() == 2
                assert weights_1.size(1) == self.max_T_encoder
                assert weights_2.dim() == 2
                assert weights_2.size(1) == self.max_T_encoder
                pred_expls = [pred_expls, weights_1, weights_2]
            return pred_expls
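
The premise and hypothesis attention blocks above are four copies of the
same dot-product attention step. Below is a behavior-equivalent sketch of
that step factored into a single function; dot_product_attention is an
illustrative helper (not part of the original class), and the projection
arguments stand in for att_context_proj*, att_ht_proj* and
att_ht_before_weighting_proj*:

import torch


def dot_product_attention(context, enc_out, context_proj, key_proj,
                          value_proj, softmax):
    # context: 1 x bsize x dim, enc_out: T x bsize x 2*enc_rnn_dim
    query = context_proj(context).permute(1, 0, 2)            # bsize x 1 x att_hid
    keys = key_proj(enc_out).transpose(1, 0).transpose(2, 1)  # bsize x att_hid x T
    weights = softmax(torch.bmm(query, keys))                 # bsize x 1 x T
    values = value_proj(enc_out).permute(1, 0, 2)             # bsize x T x att_hid
    applied = torch.bmm(weights, values)                      # bsize x 1 x att_hid
    return applied.permute(1, 0, 2), weights                  # 1 x bsize x att_hid

Each duplicated block would then collapse to a call such as:

att_applied_perm_1, att_weights_1 = dot_product_attention(
    context, enc_out_s1, self.att_context_proj1, self.att_ht_proj1,
    self.att_ht_before_weighting_proj1, self.softmax_att)
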
def eval_datasets_without_expl(esnli_net, which_set, data, word_vec,
                               word_emb_dim, batch_size, print_every,
                               current_run_dir):

    dict_labels = NLI_DIC_LABELS

    esnli_net.eval()
    correct = 0.
    correct_labels_expl = 0.

    s1 = data['s1']
    s2 = data['s2']
    label = data['label']
    label_expl = data['label_expl']

    headers = [
        "gold_label", "Premise", "Hypothesis", "pred_label", "pred_expl",
        "pred_lbl_decoder"
    ]
    expl_csv = os.path.join(
        current_run_dir,
        time.strftime("%d:%m") + "_" + time.strftime("%H:%M:%S") + "_" +
        which_set + ".csv")
    remove_file(expl_csv)
    expl_f = open(expl_csv, "a")
    writer = csv.writer(expl_f)
    writer.writerow(headers)

    for i in range(0, len(s1), batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[i:i + batch_size], word_vec)
        s2_batch, s2_len = get_batch(s2[i:i + batch_size], word_vec)

        current_bs = s1_batch.size(1)
        assert_sizes(s1_batch, 3, [s1_batch.size(0), current_bs, word_emb_dim])
        assert_sizes(s2_batch, 3, [s2_batch.size(0), current_bs, word_emb_dim])

        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(
            s2_batch.cuda())
        tgt_label_batch = Variable(torch.LongTensor(label[i:i +
                                                          batch_size])).cuda()
        tgt_label_expl_batch = label_expl[i:i + batch_size]

        expl_t0 = Variable(
            torch.from_numpy(word_vec['<s>']).float().unsqueeze(0).expand(
                current_bs, word_emb_dim).unsqueeze(0)).cuda()
        assert_sizes(expl_t0, 3, [1, current_bs, word_emb_dim])

        # model forward
        pred_expls, out_lbl = esnli_net((s1_batch, s1_len), (s2_batch, s2_len),
                                        expl_t0,
                                        mode="forloop")
        assert len(pred_expls) == current_bs, "pred_expls: " + str(
            len(pred_expls)) + " current_bs: " + str(current_bs)

        for b in range(len(pred_expls)):
            assert tgt_label_expl_batch[b] in [
                'entailment', 'neutral', 'contradiction'
            ]
            if len(pred_expls[b]) > 0:
                words = pred_expls[b].strip().split(" ")
                if words[0] == tgt_label_expl_batch[b]:
                    correct_labels_expl += 1

        # accuracy
        pred = out_lbl.data.max(1)[1]
        correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()

        # write one csv row per prediction, following the headers order
        for j in range(len(pred_expls)):
            # the decoder emits the predicted label as its first token,
            # followed by the explanation itself
            words = pred_expls[j].strip().split(" ")
            row = []
            row.append(get_key_from_val(label[i + j], dict_labels))
            row.append(' '.join(s1[i + j][1:-1]))
            row.append(' '.join(s2[i + j][1:-1]))
            row.append(get_key_from_val(pred[j], dict_labels))
            row.append(' '.join(words[1:]))
            row.append(words[0])
            writer.writerow(row)

        # print example
        if i % print_every == 0:
            print(which_set.upper() + " example: ")
            print("Premise:  ", ' '.join(s1[i]), " LENGHT: ", s1_len[0])
            print("Hypothesis:  ", ' '.join(s2[i]), " LENGHT: ", s2_len[0])
            print("Gold label:  ", get_key_from_val(label[i], dict_labels))
            print("Predicted label:  ", get_key_from_val(pred[0], dict_labels))
            print("Predicted explanation:  ", pred_expls[0], "\n\n\n")

    eval_acc = round(100 * correct / len(s1), 2)
    eval_acc_label_expl = round(100 * correct_labels_expl / len(s1), 2)
    print(which_set.upper() + " accuracy (no training): ", eval_acc, '\n\n\n')
    expl_f.close()
    return eval_acc, eval_acc_label_expl
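
eval_datasets_without_expl relies on a get_batch helper that embeds and pads
a batch of tokenized sentences into a seqlen x bsize x word_emb_dim tensor.
A plausible minimal sketch, inferred from how s1_batch and s1_len are used
above (the repository's actual version may differ, e.g. in padding or dtype):

import numpy as np
import torch


def get_batch(batch, word_vec):
    # batch: list of tokenized sentences whose words all appear in word_vec
    lengths = np.array([len(sent) for sent in batch])
    max_len = int(lengths.max())
    emb_dim = len(word_vec[batch[0][0]])
    # zero-padded embedding tensor: seqlen x bsize x word_emb_dim
    embed = np.zeros((max_len, len(batch), emb_dim), dtype=np.float32)
    for i, sent in enumerate(batch):
        for j, word in enumerate(sent):
            embed[j, i, :] = word_vec[word]
    return torch.from_numpy(embed), lengths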