Example #1
    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()
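
The per-step loss pattern above recurs in most of the examples on this page: gather the gold token's probability, take its negative log, optionally add the coverage penalty, mask out padding, and normalize by decoder length. Below is a minimal, self-contained sketch of that computation; the helper name and default values are illustrative and not taken from any of these repositories.

import torch

def masked_nll_with_coverage(final_dists, attn_dists, targets, dec_padding_mask,
                             dec_lens, cov_loss_wt=1.0, eps=1e-12, use_coverage=True):
    # final_dists, attn_dists: lists of T tensors (B x V and B x L); targets: B x T (long);
    # dec_padding_mask: B x T (float); dec_lens: B (float).
    coverage = torch.zeros_like(attn_dists[0])
    step_losses = []
    for di, (final_dist, attn_dist) in enumerate(zip(final_dists, attn_dists)):
        gold_probs = torch.gather(final_dist, 1, targets[:, di].unsqueeze(1)).squeeze(1)
        step_loss = -torch.log(gold_probs + eps)
        if use_coverage:
            # coverage penalty: overlap between current attention and accumulated attention
            step_loss = step_loss + cov_loss_wt * torch.sum(torch.min(attn_dist, coverage), 1)
            coverage = coverage + attn_dist
        step_losses.append(step_loss * dec_padding_mask[:, di])
    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    return torch.mean(sum_losses / dec_lens)
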
Example #2
    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
                                                                encoder_outputs, enc_padding_mask, c_t_1,
                                                                extra_zeros, enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()
Example #3
    def train_one_batch(self, batch, iter):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []

        words = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            words.append(self.vocab.id2word(final_dist[0].argmax().item()))

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            # print('step_loss',step_loss)
            # print('step_loss.size()',step_loss.size())
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        if iter % 100 == 0:
            print(words)
            print([self.vocab.id2word(idx.item()) for idx in dec_batch[0]])
            print([self.vocab.id2word(idx.item()) for idx in target_batch[0]])

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #4
    def train_one_batch(self, batch, iter):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        # print(target_batch[:, 1:].contiguous().view(-1)[-10:])
        # print(dec_batch[:, 1:].contiguous().view(-1)[-10:])

        in_seq = enc_batch
        in_pos = self.get_pos_data(enc_padding_mask)
        tgt_seq = dec_batch
        tgt_pos = self.get_pos_data(dec_padding_mask)
        
        # padding is already done in previous function (see batcher.py - init_decoder_seq & init_decoder_seq - Batch class)
        self.optimizer.zero_grad()
        #logits = self.model.forward(in_seq, in_pos, tgt_seq, tgt_pos)
        logits = self.model.forward(in_seq, in_pos, tgt_seq, tgt_pos, extra_zeros, enc_batch_extend_vocab)

        # compute loss from logits
        loss = self.loss_func(logits, target_batch.contiguous().view(-1))

    

        # target_batch[torch.gather(logits, 2, target_batch.unsqueeze(2)).squeeze(2) == 0] = 1
        # target_batch = target_batch.contiguous().view(-1)
        # logits = logits.reshape(-1, logits.size()[2])
        # print(target_batch)
        # print('\n')
        # print(logits.size(), target_batch.size())
        # print('\n')
        #loss = self.loss_func(logits, target_batch)

        #print(loss)
        #sum_losses = torch.mean(torch.stack(losses, 1), 1)

        if iter % 50 == 0 and False:
            print(iter, loss)
            print('\n')
            # print(logits.max(1)[1][:20])
            # print('\n')
            # print(target_batch.contiguous().view(-1)[:20])
            # print('\n')
            #print(target_batch.contiguous().view(-1)[-10:])

        loss.backward()

        #print(logits.max(1)[1])
        #print('\n')
        #print(tgt_seq[:, 1:].contiguous().view(-1)[:10])
        #print(tgt_seq[:, 1:].contiguous().view(-1)[-10:])
        
        self.norm = clip_grad_norm_(self.model.parameters(), config.max_grad_norm)

        #self.optimizer.step()
        self.optimizer.step_and_update_lr()

        return loss.item()
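
Example #4 calls self.optimizer.step_and_update_lr(), which points to a Transformer-style scheduled optimizer wrapper. A minimal sketch of such a wrapper, assuming the usual Noam warmup schedule; the class name, arguments, and formula are assumptions rather than this repository's code.

import torch

class ScheduledOptim:
    # Wraps a base optimizer and sets the learning rate with Noam warmup before each step.
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step_and_update_lr(self):
        self.step_num += 1
        lr = (self.d_model ** -0.5) * min(self.step_num ** -0.5,
                                          self.step_num * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

# usage sketch: optimizer = ScheduledOptim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), d_model=512)
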
Example #5
    def train_one_batch(self, batch):
        article_oovs, enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        enc_batch = [outputids2words(ids, self.vocab, article_oovs[i]) for i, ids in enumerate(enc_batch.numpy())]
        enc_batch_list = []
        for words in enc_batch:
            temp_list = []
            for w in words:
                l = ft_model.get_numpy_vector(w)
                temp_list.append(l)
            enc_batch_list.append(temp_list)
        enc_batch_list = torch.Tensor(enc_batch_list)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch_list, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            # for i, id in enumerate(y_t_1):
            #     print (id)
            #     myid2word(id, self.vocab, article_oovs[i])
            y_t_1 = [myid2word(id, self.vocab, article_oovs[i]) for i, id in enumerate(y_t_1.numpy())]
            y_t_1 = torch.Tensor([ft_model.get_numpy_vector(w) for w in y_t_1])
            final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab,
                                                                           coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
                
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
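
Example #5 bypasses the embedding layer and feeds pre-computed fastText vectors for both the encoder input and the teacher-forced decoder input. A small self-contained sketch of that words-to-tensor conversion; get_numpy_vector here is a toy stand-in for pyfasttext's ft_model.get_numpy_vector, and the dimensionality is illustrative.

import numpy as np
import torch

EMB_DIM = 4  # toy dimensionality; real fastText vectors are e.g. 300-d

def get_numpy_vector(word, dim=EMB_DIM):
    # deterministic fake embedding so the sketch runs without a fastText model
    rng = np.random.default_rng(abs(hash(word)) % (2 ** 32))
    return rng.standard_normal(dim).astype(np.float32)

def words_to_tensor(batch_of_words):
    # B token lists of equal length -> B x T x EMB_DIM float tensor
    return torch.tensor(np.stack([[get_numpy_vector(w) for w in words]
                                  for words in batch_of_words]))

enc_batch_list = words_to_tensor([["the", "cat", "sat"], ["a", "dog", "ran"]])
assert enc_batch_list.shape == (2, 3, EMB_DIM)
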
Example #6
    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        if not config.is_hierarchical:
            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
                enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            enc_sent_pos = get_sent_position(enc_batch, stop_id)
            dec_sent_pos = get_sent_position(dec_batch, stop_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)
            s_t_1, sent_s_t_1 = self.model.reduce_state(
                encoder_hidden, sent_enc_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if not config.is_hierarchical:
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage, di)

            else:

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    y_t_1, s_t_1, enc_sent_pos, encoder_outputs,
                    encoder_feature, enc_padding_mask, sent_s_t_1,
                    sent_enc_outputs, sent_enc_feature, sent_enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist,
                                      dim=1,
                                      index=target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.data.item()
Example #7
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len,
                            config.max_dec_steps)):  # max_dec_steps=100
            """
            >>> a=np.array([[1,2,3],[4,5,6]])
            >>> a[:,2]
            array([3, 6])
            >>> a.shape
            (2, 3)
            """
            y_t_1 = dec_batch[:,
                              di]  # Teacher forcing, the last id of each sample.
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #8
    def train_one_batch(self, batch, forcing_ratio=1):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, device)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, device)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        y_t_1_hat = None
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            # decide the next input
            if di == 0 or random.random() < forcing_ratio:
                x_t = y_t_1  # teacher forcing, use label from last time step as input
            else:
                # use embedding of UNK for all oov word
                y_t_1_hat[y_t_1_hat >= self.vocab.size()] = self.vocab.word2id(
                    UNKNOWN_TOKEN)
                x_t = y_t_1_hat.flatten(
                )  # use prediction from last time step as input
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                x_t, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
            _, y_t_1_hat = final_dist.data.topk(1)
            target = target_batch[:, di].unsqueeze(1)
            step_loss = cal_NLLLoss(target, final_dist)
            if config.is_coverage:  # if not using coverage, keep coverage=None
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:,
                                         di]  # padding in target should not count into loss
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
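
Example #8 relies on a cal_NLLLoss helper that is not shown here. A plausible minimal implementation, assuming it mirrors the gather-and-negative-log step used by the other examples; the epsilon constant is a placeholder for config.eps.

import torch

EPS = 1e-12  # placeholder for config.eps

def cal_NLLLoss(target, final_dist):
    # target: B x 1 gold token ids; final_dist: B x V output distribution
    gold_probs = torch.gather(final_dist, 1, target).squeeze(1)
    return -torch.log(gold_probs + EPS)
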
Example #9
    def train_one_batch(self, batch, it):
#        self.change_lr(it)

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)


        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if config.scratchpad:
                final_dist, s_t_1, _, attn_dist, p_gen, encoder_outputs = \
                    self.model.decoder(
                        y_t_1, s_t_1, encoder_outputs, encoder_feature, \
                        enc_padding_mask, c_t_1, extra_zeros, \
                        enc_batch_extend_vocab, coverage, di \
                    )
            else:
                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = \
                    self.model.decoder(
                        y_t_1, s_t_1, encoder_outputs, encoder_feature, \
                        enc_padding_mask, c_t_1, extra_zeros, \
                        enc_batch_extend_vocab, coverage, di \
                    )

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        if it % config.update_every == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        return loss.item()
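
Example #9 only calls optimizer.step() every config.update_every iterations, i.e. it accumulates gradients over several batches. A standalone sketch of that pattern; dividing the loss by update_every is a common convention added here, not something the example above does, and loss_fn is an assumed callable.

import torch

def train_with_accumulation(model, optimizer, loss_fn, batches, update_every=4):
    # loss_fn(model, batch) is assumed to return a scalar loss tensor
    optimizer.zero_grad()
    for it, batch in enumerate(batches, start=1):
        loss = loss_fn(model, batch) / update_every  # scale so the accumulated step matches one large batch
        loss.backward()
        if it % update_every == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()
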
Example #10
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens, self.vocab, batch.art_oovs[0])
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di, self.vocab, batch.art_oovs[0])
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)
        if config.is_mixed_precision_training:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        #Gradient clipping
        if config.is_mixed_precision_training:
            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                           config.max_grad_norm)
        else:
            self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                        config.max_grad_norm)
            clip_grad_norm_(self.model.decoder.parameters(),
                            config.max_grad_norm)
            clip_grad_norm_(self.model.reduce_state.parameters(),
                            config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
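
Example #10 wraps backward() in apex's amp.scale_loss, which only works after the model and optimizer have been registered with amp. A minimal self-contained sketch of that setup on a toy model; the opt_level and toy model are illustrative, and NVIDIA apex plus a CUDA device are assumed.

import torch
from apex import amp  # NVIDIA apex, assumed to be installed

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

x = torch.randn(4, 8, device='cuda')
loss = model(x).pow(2).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward on the loss-scaled value, as in the example above
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 2.0)
optimizer.step()
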
Example #11
    def train_one_batch(self, batch):
        enc_batch, query_enc_batch, enc_padding_mask, query_enc_padding_mask, enc_lens, query_enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(
            enc_batch, enc_lens)
        query_encoder_outputs, query_encoder_hidden, max_query_encoder_output = self.model.query_encoder(
            query_enc_batch, query_enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        q_s_t_1 = self.model.query_reduce_state(query_encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(
                y_t_1, s_t_1, q_s_t_1, encoder_outputs, query_encoder_outputs,
                enc_padding_mask, query_enc_padding_mask, c_t_1, extra_zeros,
                enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.query_encoder.parameters(),
                        config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)
        clip_grad_norm_(self.model.query_reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #12
    def encode_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)

        h, c = self.model.reduce_state(encoder_hidden)
        h, c = h.squeeze(0), c.squeeze(0)
        encodes = torch.cat((h, c), 1)

        for id, encode in zip(batch.original_abstracts, encodes):
            print(encode)
            self.output[id] = encode
Example #13
    def train_one_batch(self, batch, iter):

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch)

        self.optimizer.clear_gradients()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                                   c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)

            target = target_batch[:, di]
            add_index = paddle.arange(0, target.shape[0])
            new_index = paddle.stack([add_index, target], axis=1)
            gold_probs = paddle.gather_nd(final_dist, new_index).squeeze()
            step_loss = -paddle.log(gold_probs + config.eps)

            if config.is_coverage:
                step_coverage_loss = paddle.sum(
                    paddle.minimum(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = paddle.sum(paddle.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = paddle.mean(batch_avg_loss)

        loss.backward()
        self.optimizer.minimize(loss)

        return loss.numpy()[0]
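
Example #13 is a PaddlePaddle port; it replaces torch.gather with paddle.gather_nd plus an explicit (row, target) index. The same row selection written in PyTorch on toy tensors, for comparison:

import torch

final_dist = torch.tensor([[0.1, 0.7, 0.2],
                           [0.5, 0.3, 0.2]])
target = torch.tensor([1, 0])

# torch.gather along dim=1, as used by the other examples on this page
gold_gather = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)

# explicit (row, column) indexing, which is what the stacked index for gather_nd expresses
rows = torch.arange(final_dist.size(0))
gold_indexed = final_dist[rows, target]

assert torch.allclose(gold_gather, gold_indexed)
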
Example #14
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1 = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = get_output_from_batch(
            batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, _, _ = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)
        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #15
    def eval_one_batch(self, batch):
        enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs_list = []
        encoder_feature_list = []
        s_t_1 = None
        s_t_1_0 = None
        s_t_1_1 = None
        for enc_batch, enc_lens in zip(enc_batch_list, enc_lens_list):
            sorted_indices = sorted(range(len(enc_lens)),
                                    key=enc_lens.__getitem__)
            sorted_indices.reverse()
            inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
            for index, position in enumerate(sorted_indices):
                inverse_sorted_indices[position] = index
            sorted_enc_batch = torch.index_select(
                enc_batch, 0,
                torch.LongTensor(sorted_indices)
                if not use_cuda else torch.LongTensor(sorted_indices).cuda())
            sorted_enc_lens = enc_lens[sorted_indices]
            sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = self.model.encoder(
                sorted_enc_batch, sorted_enc_lens)
            encoder_outputs = torch.index_select(
                sorted_encoder_outputs, 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda())
            encoder_feature = torch.index_select(
                sorted_encoder_feature.view(encoder_outputs.shape), 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda()).view(
                    sorted_encoder_feature.shape)
            encoder_hidden = tuple([
                torch.index_select(
                    sorted_encoder_hidden[0], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda()),
                torch.index_select(
                    sorted_encoder_hidden[1], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda())
            ])
            #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            encoder_outputs_list.append(encoder_outputs)
            encoder_feature_list.append(encoder_feature)
            if s_t_1 is None:
                s_t_1 = self.model.reduce_state(encoder_hidden)
                s_t_1_0, s_t_1_1 = s_t_1
            else:
                s_t_1_new = self.model.reduce_state(encoder_hidden)
                s_t_1_0 = s_t_1_0 + s_t_1_new[0]
                s_t_1_1 = s_t_1_1 + s_t_1_new[1]
            s_t_1 = tuple([s_t_1_0, s_t_1_1])

        #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        #s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        target_words = []
        output_words = []
        id_to_words = {v: k for k, v in self.vocab.word_to_id.items()}
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list,
                enc_padding_mask_list, c_t_1_list, extra_zeros_list,
                enc_batch_extend_vocab_list, coverage_list, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            output_ids = final_dist.max(1)[1]
            output_2_candidates = final_dist.topk(2, 1)[1]
            for ind in range(output_ids.shape[0]):
                if self.vocab.word_to_id['X'] == output_ids[ind].item():
                    output_ids[ind] = output_2_candidates[ind][1]
            target_step = []
            output_step = []
            step_mask = dec_padding_mask[:, di]
            for i in range(target.shape[0]):
                if target[i].item() >= len(id_to_words) or step_mask[i].item(
                ) == 0:
                    target[i] = 0
                target_step.append(id_to_words[target[i].item()])
                if output_ids[i].item() >= len(
                        id_to_words) or step_mask[i].item() == 0:
                    output_ids[i] = 0
                output_step.append(id_to_words[output_ids[i].item()])
            target_words.append(target_step)
            output_words.append(output_step)
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                #step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                #step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                #coverage = next_coverage
                step_coverage_loss = 0.0
                for ind in range(len(coverage_list)):
                    step_coverage_loss += torch.sum(
                        torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
                    coverage_list[ind] = next_coverage_list[ind]
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        self.write_words(output_words, "output.txt")
        self.write_words(target_words, "input.txt")

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()
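
Example #15 manually sorts each encoder batch by length and later restores the original order with an inverse index. A compact self-contained sketch of that sort/unsort bookkeeping; the helper name is assumed.

import torch

def sort_and_restore(enc_batch, enc_lens):
    # returns the batch sorted by descending length plus the indices that undo the sort
    sorted_lens, sorted_idx = torch.sort(torch.as_tensor(enc_lens), descending=True)
    inverse_idx = torch.argsort(sorted_idx)
    return enc_batch[sorted_idx], sorted_lens, inverse_idx

batch = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]])
lens = [2, 3, 1]
sorted_batch, sorted_lens, inverse_idx = sort_and_restore(batch, lens)
assert torch.equal(sorted_batch[inverse_idx], batch)
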
Example #16
    def train_one_batch(self, batch, alpha, beta):

        #
        # print("BATCH")
        # print(batch)


        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        #
        # print("ENC_BATCH")
        # print(len(enc_batch))
        # print(len(enc_batch[0]))
        # print((enc_batch[0]))
        #
        # print("enc_padding_mask")
        # print(enc_padding_mask)
        # print(len(enc_padding_mask))
        # print(len(enc_padding_mask[0]))

        # print("enc_lens")
        # print(enc_lens)
        # print("enc_batch_extend_vocab")
        # print(enc_batch_extend_vocab)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)

        # print("encoder_outputs")
        # print(encoder_outputs.siz)

        s_t_1 = self.model.reduce_state(encoder_hidden)

        nll_list = []

        # what is sample_size?

        gen_summary = torch.LongTensor(
            config.batch_size * [config.sample_size * [[2]]])  # B x S x 1

        # print("gen_summary")
        # print(gen_summary.size())
        # print(gen_summary)

        if use_cuda: gen_summary = gen_summary.cuda()
        preds_y = gen_summary.squeeze(2)  # B x S

        # TODO: Print Gold Here!!!!
        # print("preds_y")
        # print(preds_y.size())
        # print(preds_y)
        # print(self.vocab.size())
        # print("temp")
        # from data import outputids2words
        # temp = outputids2words(list(map(lambda x : x.item(), dec_batch[1])),self.vocab,None)
        # print(temp)
        # # for item in dec_batch[1]:
        # #     temp = self.vocab.id2word(item.item())
        # #     from data import outputids2words(dec_batch[1])
        # #     print(temp)

        from data import outputids2words

        # print("dec_batch")
        # print(dec_batch[0])
        # temp = outputids2words(list(map(lambda x : x.item(), dec_batch[0])),self.vocab,None)
        # print(temp)
        # print()
        # print("target_batch")
        # print(target_batch[0])
        # temp = outputids2words(list(map(lambda x : x.item(), target_batch[0])),self.vocab,None)
        # print(temp)
        # print()

        for di in range(min(config.max_dec_steps, dec_batch.size(1))):
            # Select the current input word
            p1 = np.random.uniform()
            if p1 < alpha:  # use ground truth word
                y_t_1 = dec_batch[:, di]
            else:  # use decoded word
                y_t_1 = preds_y[:, 0]

            # print("y_t_1")
            # # print(y_t_1)
            # print("dec_batch")
            # print(dec_batch)
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)

            # Select the current output word
            p2 = np.random.uniform()
            if p2 < beta:  # sample the ground truth word
                target = target_batch[:, di]
                sampled_batch = torch.stack(config.sample_size * [target],
                                            1)  # B x S
            else:  # randomly sample a word with given probabilities
                sampled_batch = torch.multinomial(final_dist,
                                                  config.sample_size,
                                                  replacement=True)  # B x S

            # Compute the NLL
            probs = torch.gather(final_dist, 1, sampled_batch).squeeze()
            step_nll = -torch.log(probs + config.eps)

            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_nll = step_nll + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
            nll_list.append(step_nll)

            # Store the decoded words in preds_y
            preds_y = gen_preds(sampled_batch, use_cuda)
            # Add the decoded words into gen_summary (mixed with ground truth and decoded words)
            gen_summary = torch.cat((gen_summary, preds_y.unsqueeze(2)),
                                    2)  # B x S x L

        # compute the REINFORCE score
        nll = torch.sum(torch.stack(nll_list, 2), 2)  # B x S
        all_rewards, avg_reward = compute_reward(batch, gen_summary,
                                                 self.vocab, config.mode,
                                                 use_cuda)  # B x S, 1
        batch_loss = torch.sum(nll * all_rewards, dim=1)  # B
        loss = torch.mean(batch_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()
        return loss.item(), avg_reward.item()
Example #17
    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h =[]
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1)

            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
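
Example #17's beam search depends on self.sort_beams, which is not shown here. A plausible minimal version, assuming hypotheses are ordered best-first by average token log-probability as in the usual pointer-generator beam search:

def sort_beams(beams):
    # order hypotheses best-first by average log-probability of their tokens
    return sorted(beams, key=lambda h: sum(h.log_probs) / len(h.log_probs), reverse=True)
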
Example #18
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if self.opt.train_mle == "yes":
            step_losses = []
            for di in range(min(max_dec_len, config.max_dec_steps)):
                y_t_1 = dec_batch[:, di]  # Teacher forcing
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage, di)
                target = target_batch[:, di]
                gold_probs = torch.gather(final_dist, 1,
                                          target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs + config.eps)
                if config.is_coverage:
                    step_coverage_loss = torch.sum(
                        torch.min(attn_dist, coverage), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                    coverage = next_coverage

                step_mask = dec_padding_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)

            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            mle_loss = torch.mean(batch_avg_loss)
        else:
            mle_loss = get_cuda(torch.FloatTensor([0]))
            # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":  # perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(
                encoder_outputs,
                encoder_hidden,
                enc_padding_mask,
                encoder_feature,
                enc_batch_extend_vocab,
                extra_zeros,
                c_t_1,
                batch.art_oovs,
                coverage,
                greedy=False)
            with torch.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.train_batch_RL(encoder_outputs,
                                                      encoder_hidden,
                                                      enc_padding_mask,
                                                      encoder_feature,
                                                      enc_batch_extend_vocab,
                                                      extra_zeros,
                                                      c_t_1,
                                                      batch.art_oovs,
                                                      coverage,
                                                      greedy=True)

            sample_reward = self.reward_function(sample_sents,
                                                 batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents,
                                                   batch.original_abstracts)
            # if iter%200 == 0:
            #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
            rl_loss = -(
                sample_reward - baseline_reward
            ) * RL_log_probs  # Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = torch.mean(rl_loss)

            batch_reward = torch.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(torch.FloatTensor([0]))
            batch_reward = 0
        #loss.backward()
        (self.opt.mle_weight * mle_loss +
         self.opt.rl_weight * rl_loss).backward()
        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return mle_loss.item(), batch_reward
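
Example #18 mixes a maximum-likelihood loss with a self-critical policy-gradient loss (eq. 15 in https://arxiv.org/pdf/1705.04304.pdf). A toy sketch of that mixing with placeholder rewards, log-probabilities, and weights:

import torch

sample_reward = torch.tensor([0.42, 0.30])    # e.g. ROUGE of multinomial samples (placeholder values)
baseline_reward = torch.tensor([0.40, 0.35])  # reward of the greedy (baseline) decodes (placeholder values)
rl_log_probs = torch.tensor([-12.0, -9.5])    # summed log-probs of the sampled sequences (placeholder values)

# self-critic policy gradient: push up samples that beat the greedy baseline
rl_loss = torch.mean(-(sample_reward - baseline_reward) * rl_log_probs)

mle_loss = torch.tensor(3.1)        # placeholder cross-entropy loss
mle_weight, rl_weight = 0.25, 0.75  # assumed mixing weights
total_loss = mle_weight * mle_loss + rl_weight * rl_loss
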
Example #19
    def beam_search(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output, _, _ = self.model.encoder(
            enc_batch, enc_lens, enc_batch, enc_batch)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                'decode', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage_t_1, steps)

            topk_log_probs, topk_ids = torch.topk(final_dist,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Example #20
    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            # to do here
            # ...
            latest_tokens = [h.latest_token for h in beams]
            # latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
            #                  for t in latest_tokens]

            y_t_1 = np.zeros((len(latest_tokens), config.emb_dim))
            for i, t in enumerate(latest_tokens):
                if t < self.vocab.size():
                    w = self.vocab.id2word(t)
                else:
                    idx = t - self.vocab.size()
                    if idx >= len(batch.art_oovs[i]):
                        w = data.UNKNOWN_TOKEN
                    else:
                        try:
                            w = batch.art_oovs[i][idx]
                        except Exception as e:
                            print(e)
                            w = data.UNKNOWN_TOKEN
                if w == data.START_DECODING:
                    embedd = embedding.start_decoding_embedd
                elif w == data.STOP_DECODING:
                    embedd = embedding.stop_decoding_embedd
                elif w == data.PAD_TOKEN:
                    embedd = embedding.padding_embedd
                elif w == data.UNKNOWN_TOKEN:
                    embedd = embedding.unknown_decoding_embedd
                else:
                    embedd = embedding.fasttext.get_numpy_vector(w.lower())
                y_t_1[i] = embedd

            # y_t_1 = Variable(torch.LongTensor(latest_tokens))
            y_t_1 = Variable(torch.FloatTensor(y_t_1))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
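
Example #20 maps extended-vocabulary ids back to words by hand; elsewhere the examples rely on `data.outputids2words` for the same job. A rough standalone sketch of that mapping follows; the function name, argument order, and `UNKNOWN_TOKEN` constant are assumptions, and only the indexing convention (ids at or above the vocabulary size point into the per-article OOV list) is taken from the code above.

UNKNOWN_TOKEN = '[UNK]'  # placeholder; the real constant lives in the data module

def outputids2words(id_list, vocab, art_oovs=None):
    # Map extended-vocabulary ids back to words for a pointer-generator model.
    words = []
    for t in id_list:
        if t < vocab.size():
            words.append(vocab.id2word(t))       # in-vocabulary word
        else:
            oov_idx = t - vocab.size()           # index into this article's OOV list
            if art_oovs is not None and oov_idx < len(art_oovs):
                words.append(art_oovs[oov_idx])  # copied source word
            else:
                words.append(UNKNOWN_TOKEN)      # id points outside the OOV list
    return words
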
Example #21
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage, enc_aspects = get_input_from_batch(
            batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch, dec_aspects = get_output_from_batch(
            batch, use_cuda)
        """
        dec_aspects and enc_aspects are a lists of aspects 
        """
        aspect_ids = [int(aspect[0]) for aspect in enc_aspects]

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens, aspect_ids)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
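
Example #21 (and most of the training loops below) computes the same per-step loss: the negative log-likelihood of the gold token plus, when coverage is enabled, the sum of the element-wise minimum of the attention and coverage vectors. Below is a self-contained sketch of that single-step computation on dummy tensors; `eps` and `cov_loss_wt` stand in for `config.eps` and `config.cov_loss_wt`.

import torch

eps, cov_loss_wt = 1e-12, 1.0                           # stand-ins for config.eps / config.cov_loss_wt
B, V, L = 4, 10, 7                                      # batch size, extended vocab size, source length
final_dist = torch.softmax(torch.randn(B, V), dim=1)    # decoder output distribution
attn_dist = torch.softmax(torch.randn(B, L), dim=1)     # attention over source tokens
coverage = torch.rand(B, L)                             # accumulated attention so far
target = torch.randint(0, V, (B,))                      # gold token ids for this step
step_mask = torch.ones(B)                               # 0 for padded decoder positions

gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
step_loss = -torch.log(gold_probs + eps)                            # NLL of the gold token
step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)   # attention/coverage overlap
step_loss = (step_loss + cov_loss_wt * step_coverage_loss) * step_mask
print(step_loss.shape)  # torch.Size([4]) -- one loss value per example
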
Example #22
    def train_one_batch(self, batch):
        enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs_list = []
        encoder_feature_list = []
        s_t_1 = None
        s_t_1_0 = None
        s_t_1_1 = None
        for enc_batch, enc_lens in zip(enc_batch_list, enc_lens_list):
            sorted_indices = sorted(range(len(enc_lens)),
                                    key=enc_lens.__getitem__)
            sorted_indices.reverse()
            inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
            for index, position in enumerate(sorted_indices):
                inverse_sorted_indices[position] = index
            sorted_enc_batch = torch.index_select(
                enc_batch, 0,
                torch.LongTensor(sorted_indices)
                if not use_cuda else torch.LongTensor(sorted_indices).cuda())
            sorted_enc_lens = enc_lens[sorted_indices]
            sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = self.model.encoder(
                sorted_enc_batch, sorted_enc_lens)
            encoder_outputs = torch.index_select(
                sorted_encoder_outputs, 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda())
            encoder_feature = torch.index_select(
                sorted_encoder_feature.view(encoder_outputs.shape), 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda()).view(
                    sorted_encoder_feature.shape)
            encoder_hidden = tuple([
                torch.index_select(
                    sorted_encoder_hidden[0], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda()),
                torch.index_select(
                    sorted_encoder_hidden[1], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda())
            ])
            #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            encoder_outputs_list.append(encoder_outputs)
            encoder_feature_list.append(encoder_feature)
            if s_t_1 is None:
                s_t_1 = self.model.reduce_state(encoder_hidden)
                s_t_1_0, s_t_1_1 = s_t_1
            else:
                s_t_1_new = self.model.reduce_state(encoder_hidden)
                s_t_1_0 = s_t_1_0 + s_t_1_new[0]
                s_t_1_1 = s_t_1_1 + s_t_1_new[1]
            s_t_1 = tuple([s_t_1_0, s_t_1_1])

        #c_t_1_list = [c_t_1]
        #coverage_list = [coverage]

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list,
                enc_padding_mask_list, c_t_1_list, extra_zeros_list,
                enc_batch_extend_vocab_list, coverage_list, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = 0.0
                for ind in range(len(coverage_list)):
                    step_coverage_loss += torch.sum(
                        torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
                    coverage_list[ind] = next_coverage_list[ind]
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
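
The index bookkeeping in Example #22 exists because the encoder (presumably built on packed sequences) expects its batch sorted by decreasing length, after which the outputs must be scattered back into the original order. A compact sketch of the same sort-then-restore pattern on dummy tensors, with no encoder involved:

import torch

enc_lens = torch.tensor([5, 9, 3, 7])                # lengths of the 4 sequences
enc_batch = torch.arange(4 * 9).view(4, 9)           # dummy (B, T) batch

sorted_lens, sorted_idx = torch.sort(enc_lens, descending=True)
inverse_idx = torch.argsort(sorted_idx)              # permutation that undoes the sort

sorted_batch = enc_batch.index_select(0, sorted_idx)
# ... run the length-sorted encoder on (sorted_batch, sorted_lens) here ...
restored_batch = sorted_batch.index_select(0, inverse_idx)
assert torch.equal(restored_batch, enc_batch)        # original order is recovered
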
Example #23
    def decode(self):
        start = time.time()
        counter = 0
        #batch = self.batcher.next_batch()
        #print(batch.enc_batch)

        keep = True
        for batch in self.batches:
            #keep = False # one batch only

            # Run beam search to get best Hypothesis
            enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = get_input_from_batch(
                batch, use_cuda)

            enc_batch = enc_batch[0:1, :]
            enc_padding_mask = enc_padding_mask[0:1, :]

            in_seq = enc_batch
            in_pos = self.get_pos_data(enc_padding_mask)
            #print("enc_padding_mask", enc_padding_mask)

            #print("Summarizing one batch...")

            batch_hyp, batch_scores = self.summarize_batch(in_seq, in_pos)

            # Extract the output ids from the hypothesis and convert back to words
            output_words = np.array(batch_hyp)
            output_words = output_words[:, 0, 1:]

            for i, out_sent in enumerate(output_words):

                decoded_words = data.outputids2words(
                    out_sent, self.vocab,
                    (batch.art_oovs[0] if config.pointer_gen else None))

                original_abstract_sents = batch.original_abstracts_sents[i]

                write_for_rouge(original_abstract_sents, decoded_words,
                                counter, self._rouge_ref_dir,
                                self._rouge_dec_dir)
                counter += 1

            if counter % 1 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            #batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
Example #24
    def train_one_batch(self, batch, alpha, beta):
        # import pdb;
        # pdb.set_trace()
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        nll_list = []
        gen_summary = torch.LongTensor(
            config.batch_size * [config.sample_size * [[2]]])  # B x S x 1
        if use_cuda: gen_summary = gen_summary.cuda()
        preds_y = gen_summary.squeeze(2)  # B x S
        for di in range(min(config.max_dec_steps, dec_batch.size(1))):
            # Select the current input word
            p1 = np.random.uniform()
            if p1 < alpha:  # use ground truth word
                y_t_1 = dec_batch[:, di]
            else:  # use decoded word
                y_t_1 = preds_y[:, 0]

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)

            # Select the current output word
            p2 = np.random.uniform()
            if p2 < beta:  # sample the ground truth word
                target = target_batch[:, di]
                sampled_batch = torch.stack(config.sample_size * [target],
                                            1)  # B x S
            else:  # randomly sample a word with given probabilities
                sampled_batch = torch.multinomial(final_dist,
                                                  config.sample_size,
                                                  replacement=True)  # B x S

            # Compute the NLL
            probs = torch.gather(final_dist, 1, sampled_batch).squeeze()
            step_nll = -torch.log(probs + config.eps)

            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_nll = step_nll + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
            nll_list.append(step_nll)

            # Store the decoded words in preds_y
            preds_y = gen_preds(sampled_batch, use_cuda)
            # Add the decoded words into gen_summary (mixed with ground truth and decoded words)
            gen_summary = torch.cat((gen_summary, preds_y.unsqueeze(2)),
                                    2)  # B x S x L

        # compute the REINFORCE score
        nll = torch.sum(torch.stack(nll_list, 2), 2)  # B x S
        all_rewards, avg_reward = compute_reward(batch, gen_summary,
                                                 self.vocab, config.mode,
                                                 use_cuda)  # B x S, 1
        batch_loss = torch.sum(nll * all_rewards, dim=1)  # B
        loss = torch.mean(batch_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()
        return loss.item(), avg_reward.item()
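
Example #24 mixes teacher forcing with the model's own predictions: with probability `alpha` the gold token is fed at the next step, otherwise the previously decoded token is. A minimal sketch of just that sampling decision, using made-up tensors in place of `dec_batch` and `preds_y`:

import numpy as np
import torch

alpha = 0.75                                  # probability of feeding the gold token
gold_tokens = torch.tensor([3, 8, 1, 5])      # dec_batch[:, di]
pred_tokens = torch.tensor([3, 2, 1, 7])      # model's predictions from the previous step

if np.random.uniform() < alpha:
    y_t_1 = gold_tokens                       # teacher forcing
else:
    y_t_1 = pred_tokens                       # feed back the model's own output
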
Example #25
    def train_one_batch(self, batch):

        ########### The two calls below just unpack the encoder/decoder batches: sizes, vocab ids, lengths, etc. ###########
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()
        #print("train_one_batch function ......")
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(
            encoder_hidden
        )  ### the reduced encoder final hidden state serves as the decoder's initial state at timestep 0
        #print("s_t_1 : ",len(s_t_1),s_t_1[0].shape,s_t_1[1].shape)

        #print("steps.....")
        #print("max_dec_len = ",max_dec_len)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            ############ Training [ Teacher Forcing ] ###########
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            #print("y_t_1 : ",len(y_t_1))
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            #print("attn_dist : ",len(attn_dist),attn_dist[0].shape)
            #print("final_dist : ",len(final_dist),final_dist[0].shape) ############## vocab_Size
            target = target_batch[:, di]
            #print("target = ",len(target))

            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(
                gold_probs + config.eps
            )  #################################################### Eqn_6
            if config.is_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(attn_dist, coverage),
                    1)  ###############################Eqn_13a
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss  ###############################Eqn_13b
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #26
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        s_t_1_origin = s_t_1

        batch_size = batch.batch_size
        step_losses = []

        sample_idx = []
        sample_log_probs = Variable(torch.zeros(batch_size))
        baseline_idx = []

        for di in range(min(max_dec_len, config.max_dec_steps)):

            y_t_1 = dec_batch[:, di]  # Teacher forcing, shape [batch_size]
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                                                           encoder_outputs,
                                                                                           encoder_feature,
                                                                                           enc_padding_mask, c_t_1,
                                                                                           extra_zeros,
                                                                                           enc_batch_extend_vocab,
                                                                                           coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

            # sample
            if di == 0:  # use decoder input[0], which is <BOS>
                sample_t_1 = dec_batch[:, di]
                s_t_sample = s_t_1_origin
                c_t_sample = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))

            final_dist, s_t_sample, c_t_sample, attn_dist, p_gen, next_coverage = self.model.decoder(sample_t_1,
                                                                                                     s_t_sample,
                                                                                                     encoder_outputs,
                                                                                                     encoder_feature,
                                                                                                     enc_padding_mask,
                                                                                                     c_t_sample,
                                                                                                     extra_zeros,
                                                                                                     enc_batch_extend_vocab,
                                                                                                     coverage, di)
            # according to final_dist to sample
            # change sample_t_1
            dist = torch.distributions.Categorical(final_dist)
            sample_t_1 = Variable(dist.sample())
            # record sample idx
            sample_idx.append(sample_t_1)  # tensor list
            # compute sample probability
            sample_log_probs += torch.log(
                final_dist.gather(1, sample_t_1.view(-1, 1))).squeeze(1)  # gather the sampled tokens' probabilities along axis=1

            # baseline
            if di == 0:  # use decoder input[0], which is <BOS>
                baseline_t_1 = dec_batch[:, di]
                s_t_baseline = s_t_1_origin
                c_t_baseline = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))

            final_dist, s_t_baseline, c_t_baseline, attn_dist, p_gen, next_coverage = self.model.decoder(baseline_t_1,
                                                                                                         s_t_baseline,
                                                                                                         encoder_outputs,
                                                                                                         encoder_feature,
                                                                                                         enc_padding_mask,
                                                                                                         c_t_baseline,
                                                                                                         extra_zeros,
                                                                                                         enc_batch_extend_vocab,
                                                                                                         coverage, di)
            # according to final_dist to get baseline
            # change baseline_t_1
            baseline_t_1 = final_dist.max(1)[1]  # greedy baseline: argmax token ids along axis=1
            # record baseline probability
            baseline_idx.append(baseline_t_1)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        # according to sample_idx and baseline_idx to compute RL loss
        # map sample/baseline_idx to string
        # compute rouge score
        # compute loss
        sample_idx = torch.stack(sample_idx, dim=1).squeeze()  # expect shape (batch_size, seq_len)
        baseline_idx = torch.stack(baseline_idx, dim=1).squeeze()
        rl_loss = torch.zeros(batch_size)
        for i in range(sample_idx.shape[0]):  # each example in a batch
            sample_y = data.outputids2words(sample_idx[i], self.vocab,
                                            (batch.art_oovs[i] if config.pointer_gen else None))
            baseline_y = data.outputids2words(baseline_idx[i], self.vocab,
                                              (batch.art_oovs[i] if config.pointer_gen else None))
            true_y = batch.original_abstracts[i]

            sample_score = rouge_l_f(sample_y, true_y)
            baseline_score = rouge_l_f(baseline_y, true_y)

            sample_score = Variable(sample_score)
            baseline_score = Variable(baseline_score)

            rl_loss[i] = baseline_score - sample_score
        rl_loss = torch.mean(rl_loss * sample_log_probs)

        gamma = 0.9984
        loss = (1 - gamma) * loss + gamma * rl_loss

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
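
Example #26 is a self-critical sequence training step: the sampled summary's log-probability is weighted by how far its ROUGE score falls below a greedy baseline, and the result is mixed with the maximum-likelihood loss. A minimal sketch of the loss combination alone, with the ROUGE scores and log-probabilities given as plain tensors (the decoding and `rouge_l_f` scoring are assumed to happen elsewhere):

import torch

sample_log_probs = torch.tensor([-12.3, -9.8, -15.1])    # summed log p of sampled summaries, shape (B,)
sample_score = torch.tensor([0.31, 0.42, 0.27])          # reward of sampled summaries (e.g. ROUGE-L F1)
baseline_score = torch.tensor([0.35, 0.40, 0.30])        # reward of greedy (argmax) summaries

# Reward-weighted likelihood: minimizing this pushes up log p of samples that beat the baseline.
rl_loss = torch.mean((baseline_score - sample_score) * sample_log_probs)

gamma = 0.9984                    # mixing weight between ML and RL terms
ml_loss = torch.tensor(3.7)       # the usual teacher-forced NLL (scalar)
loss = (1 - gamma) * ml_loss + gamma * rl_loss
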
Example #27
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        if not config.is_hierarchical:
            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            pad_id  = self.vocab.word2id('[PAD]')
            enc_sent_pos = get_sent_position(enc_batch, stop_id, pad_id)
            dec_sent_pos = get_sent_position(dec_batch, stop_id, pad_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask, sent_lens, seq_lens2 = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)

            s_t_1, sent_s_t_1 = self.model.reduce_state(encoder_hidden, sent_enc_hidden)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if not config.is_hierarchical:
                # start = datetime.now()

                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(y_t_1, s_t_1,
                                                            encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                            extra_zeros, enc_batch_extend_vocab,
                                                                               coverage, di)
                # print('NO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()
            else:
                # start = datetime.now()
                max_doc_len = enc_batch.size(1)
                final_dist, sent_s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, sent_s_t_1,
                                                            encoder_outputs, encoder_feature, enc_padding_mask, seq_lens2,
                                                            sent_s_t_1, sent_enc_outputs, sent_enc_feature, sent_enc_padding_mask,
                                                            sent_lens, max_doc_len,
                                                            c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
                # print('DO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()


            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        # start = datatime.now()
        loss.backward()
        # print('{} HIER Time: {}'.format(config.is_hierarchical ,datetime.now() - start))
        # import pdb; pdb.set_trace()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
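
`self.model.reduce_state` appears in every example but is never defined here; it is assumed to map the bidirectional encoder's final `(h, c)` pair, of shape `(2, B, hidden)`, down to the single-layer decoder's initial state of shape `(1, B, hidden)`. A rough sketch under that assumption (the layer names and the ReLU are illustrative, not taken from the original code):

import torch
import torch.nn as nn

class ReduceState(nn.Module):
    def __init__(self, hidden_dim):
        super(ReduceState, self).__init__()
        self.reduce_h = nn.Linear(hidden_dim * 2, hidden_dim)
        self.reduce_c = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, encoder_hidden):
        h, c = encoder_hidden                                 # each (2, B, hidden) from a BiLSTM
        h_in = h.transpose(0, 1).contiguous().view(-1, h.size(-1) * 2)
        c_in = c.transpose(0, 1).contiguous().view(-1, c.size(-1) * 2)
        reduced_h = torch.relu(self.reduce_h(h_in)).unsqueeze(0)   # (1, B, hidden)
        reduced_c = torch.relu(self.reduce_c(c_in)).unsqueeze(0)
        return (reduced_h, reduced_c)                         # decoder initial (h, c)
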
Example #28
    def train_one_batch(self, batch):
        loss = 0
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        # print("Encoding lengths", enc_lens)  #(1,8)
        # print("Encoding batch", enc_batch.size()) #(8, 400)
        # print("c_t_1 is ", c_t_1.size()) # (8, 512)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        # print("Max Decoding lengths", max_dec_len)
        # print("Decoding lengths", dec_lens_var)
        # print("Decoding vectors", dec_batch[0])
        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        # print("encoder_outputs", encoder_outputs.size()) # (8, 400, 512)
        # print("encoder_feature", encoder_feature.size()) # (3200, 512)
        # print("encoder_hidden", encoder_hidden[1].size()) # (2, 8, 256)


        s_t_1 = self.model.reduce_state(encoder_hidden) # (1, 8, 256)
        # print("After reduce_state, the hidden state s_t_1 is", s_t_1[0].size()) 

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)): 
            # print("Decoder step = ", di)
            y_t_1 = dec_batch[:, di]  # Teacher forcing: the di-th target word of every example in the batch, shape (8,)
            final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab,
                                                                           coverage, di)

            # print("for di=", di, " final_dist", final_dist.size()) # (8, 50009)
            # print("for di=", di, " s_t_1", encoder_feature.size()) # (3200, 512)
            # print("for di=", di, " c_t_1", c_t_1.size()) # (8, 512)
            # print("for di=", di, " attn_dist", attn_dist.size()) # (8, 400)
            # print("for di=", di, " p_gen", p_gen.size()) # (8, 1)
            # print("for di=", di, " next_coverage", next_coverage.size())

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
                
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()
        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm) # gradient clipping
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #29
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        # debug(batch.original_articles[0])
        # debug(batch.original_abstracts[0])
        loss_mask = self.get_loss_mask(enc_batch, dec_batch, batch.absts)
        # debug('loss_mask',loss_mask)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage, tau = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)

            # debug('enc_batch',enc_batch.size())
            # debug('dec_batch',dec_batch.size())
            # debug('final_dist', final_dist.size())
            # debug('target',target)
            # debug('gold_probs',gold_probs)

            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            # debug('step_loss_before',step_loss)
            # debug('config.loss_mask',config.loss_mask)
            if config.loss_mask:
                step_loss = step_loss * loss_mask
                # pass
            # debug('step_loss_after',step_loss)
            step_losses.append(step_loss)

            if config.DEBUG:
                # break
                pass

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        if not config.DEBUG:
            loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item(), tau
Example #30
    def train_one_batch(self, batch, steps, batch_ds):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        enc_batch_ds, enc_padding_mask_ds, enc_lens_ds, _, _, _, _, _, _, _, _ = \
            get_input_from_batch(batch_ds, use_cuda)

        self.optimizer.zero_grad()
        encoder_outputs, encoder_hidden, max_encoder_output, enc_batch_ds_emb, dec_batch_emb = self.model.encoder(
            enc_batch, enc_lens, enc_batch_ds, dec_batch)
        if config.DS_train:
            ds_final_loss = self.ds_loss(enc_batch_ds_emb, enc_padding_mask_ds,
                                         dec_batch_emb, dec_padding_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        s_t_0 = s_t_1
        c_t_0 = c_t_1
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output
            c_t_0 = c_t_1

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        if config.DS_train:
            ds_final_loss = Variable(torch.FloatTensor([ds_final_loss]),
                                     requires_grad=False)
            ds_final_loss = ds_final_loss.cuda()
            loss = (config.pi - ds_final_loss) * torch.mean(batch_avg_loss)
        else:
            loss = torch.mean(batch_avg_loss)
        if steps > config.traintimes:
            scores = []
            sample_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    sample_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    sample_latest_tokens = sample_y[-1]
                    sample_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in sample_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(sample_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                sample_select = torch.multinomial(final_dist, 1).view(-1)
                sample_log_probs = torch.gather(
                    final_dist, 1, sample_select.unsqueeze(1)).squeeze()
                sample_y.append(sample_select.cpu().numpy().tolist())
                sample_step_loss = -torch.log(sample_log_probs + config.eps)
                sample_step_mask = dec_padding_mask[:, di]
                sample_step_loss = sample_step_loss * sample_step_mask
                scores.append(sample_step_loss)
            sample_sum_losses = torch.sum(torch.stack(scores, 1), 1)
            sample_batch_avg_loss = sample_sum_losses / dec_lens_var

            sample_y = np.transpose(sample_y).tolist()

            base_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    base_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    base_latest_tokens = base_y[-1]
                    base_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in base_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(base_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                base_log_probs, base_ids = torch.topk(final_dist, 1)
                base_y.append(base_ids[:, 0].cpu().numpy().tolist())

            base_y = np.transpose(base_y).tolist()

            refs = dec_batch.cpu().numpy().tolist()
            sample_dec_lens_var = list(map(int, dec_lens_var.cpu().numpy().tolist()))
            sample_rougeL = [
                self.calc_Rouge_L(sample[:reflen],
                                  ref[:reflen]) for sample, ref, reflen in zip(
                                      sample_y, refs, sample_dec_lens_var)
            ]
            base_rougeL = [
                self.calc_Rouge_L(base[:reflen], ref[:reflen])
                for base, ref, reflen in zip(base_y, refs, sample_dec_lens_var)
            ]
            sample_rougeL = Variable(torch.FloatTensor(sample_rougeL),
                                     requires_grad=False)
            base_rougeL = Variable(torch.FloatTensor(base_rougeL),
                                   requires_grad=False)
            sample_rougeL = sample_rougeL.cuda()
            base_rougeL = base_rougeL.cuda()
            word_loss = -sample_batch_avg_loss * (base_rougeL - sample_rougeL)
            reinforce_loss = torch.mean(word_loss)
            loss = (1 - config.rein) * loss + config.rein * reinforce_loss

        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
Example #31
    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses (everything is repeated)
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Example #32
    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        # encoder_outputs shape = (batch_size, max_seq_len, 2*hidden_size)
        # encoder_feature shape = (batch_size*max_seq_len, 2*hidden_size)
        # encoder_hidden[0] shape = (batch, 2, hidden_size)
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        # s_t_1[0] shape = (1, batch_size, hidden_size)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        '''
        print('Actual enc_batch:')
        en_words = [self.vocab._id_to_word[idx] for idx in enc_batch[0].numpy()]
        print(en_words)
        print('Actual de_batch:')
        de_words = [self.vocab._id_to_word[idx] for idx in dec_batch[0].numpy()]
        print(de_words)
        print('Actual tar_batch:')
        tar_words = [self.vocab._id_to_word[idx] for idx in target_batch[0].numpy()]
        print(tar_words)
        '''

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()