def evaluate_batch(self, input_batch, input_chunk, out_batch, input_mask,
                       target_mask, kb, kb_mask, sentient_orig):
        """
        evaluating batch
        :param input_batch:
        :param out_batch:
        :param input_mask:
        :param target_mask:
        :return:
        """
        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        encoder_outputs, encoder_hidden = self.encoder(input_batch, input_mask)
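        # encoder_outputs feed the decoder's attention over source positions;
        # encoder_hidden is the final (h, c) pair of the (assumed LSTM)
        # encoder.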
        b_size = input_batch.size(1)
        target_len = out_batch.size(0)
        # Cap the number of decoding steps at the model's maximum response
        # length; the result may be a zero-dim tensor, so coerce it to int.
        max_target_length = min(target_len, self.max_r)
        if not isinstance(max_target_length, int):
            max_target_length = int(max_target_length)

        # Prepare the first decoder input (<SOS> for every batch element) and
        # per-step buffers for gate values, vocab distributions and greedy
        # predictions.
        decoder_input = Variable(torch.Tensor([self.sos_tok] * b_size)).long()
        sentinel_values = Variable(torch.zeros(max_target_length, b_size))
        all_decoder_outputs_vocab = Variable(
            torch.zeros(max_target_length, b_size, self.output_size))
        decoded_words = Variable(torch.zeros(max_target_length, b_size))
        if self.use_cuda:
            decoder_input = decoder_input.cuda()
            sentinel_values = sentinel_values.cuda()
            all_decoder_outputs_vocab = all_decoder_outputs_vocab.cuda()
            decoded_words = decoded_words.cuda()
        decoder_hidden = (encoder_hidden[0][:self.decoder.n_layers],
                          encoder_hidden[1][:self.decoder.n_layers])
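        # The decoder may use fewer layers than the encoder, so keep only the
        # first n_layers of each of the (h, c) states.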
        # Greedy decoding: feed the argmax prediction back in at each step.
        for t in range(max_target_length):
            decoder_vocab, decoder_hidden = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs, input_mask)
            # The sentinel gate scores whether to copy a KB object or emit a
            # vocabulary word, conditioned on the previous step's gate value
            # (at t == 0 this wraps to the last row, which is still all zeros).
            sentient_gate, obj = self.sentinel_g(input_batch, input_chunk,
                                                 input_mask, decoder_input, kb,
                                                 kb_mask,
                                                 sentinel_values[t - 1])
            # Blend the two distributions with gate s in (0, 1): KB object
            # scores are scaled by s, vocabulary scores by (1 - s), and the
            # object scores are scattered into their vocabulary indices.
            s = torch.sigmoid(sentient_gate)
            sentinel_values[t] = s.squeeze()
            obj = s * obj
            decoder_vocab = (1 - s) * decoder_vocab
            decoder_vocab = decoder_vocab.scatter_add(
                1,
                self.kg_vocab.repeat(b_size).view(b_size, self.kb_max_size),
                obj)
            all_decoder_outputs_vocab[t] = decoder_vocab
            topv, topi = decoder_vocab.data.topk(
                1)  # get prediction from decoder
            decoder_input = Variable(
                topi.view(-1))  # use this in the next time-steps
            decoded_words[t] = (topi.view(-1))
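            # decoded_words[t] keeps the greedy token ids, presumably mapped
            # back to strings by the caller.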

        target_mask = target_mask.transpose(0, 1).contiguous()

        loss_Vocab = masked_cross_entropy(
            all_decoder_outputs_vocab.transpose(
                0, 1).contiguous(),  # -> B x S X VOCAB
            out_batch.transpose(0, 1).contiguous(),  # -> B x S
            target_mask)

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)

        return decoded_words, loss_Vocab

    def train_batch(self, input_batch, input_chunk, out_batch, input_mask,
                    target_mask, kb, kb_mask, sentient_orig):
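        """
        Train on a single batch, optionally with teacher forcing.

        Parameters mirror evaluate_batch; sentient_orig additionally supplies
        the target sentinel-gate values for the auxiliary gate loss.
        """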

        self.encoder.train(True)
        self.decoder.train(True)

        b_size = input_batch.size(1)
        # Zero gradients of both optimizers
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        # Run words through encoder
        encoder_outputs, encoder_hidden = self.encoder(input_batch, input_mask)

        target_len = out_batch.size(0)
        # Same length cap and int coercion as in evaluate_batch.
        max_target_length = min(target_len, self.max_r)
        if not isinstance(max_target_length, int):
            max_target_length = int(max_target_length)

        # Prepare the first decoder input (<SOS>) and per-step buffers.
        decoder_input = Variable(torch.Tensor([self.sos_tok] * b_size)).long()
        sentinel_values = Variable(torch.zeros(max_target_length, b_size))
        all_decoder_outputs_vocab = Variable(
            torch.zeros(max_target_length, b_size, self.output_size))
        if self.use_cuda:
            decoder_input = decoder_input.cuda()
            sentinel_values = sentinel_values.cuda()
            all_decoder_outputs_vocab = all_decoder_outputs_vocab.cuda()

        decoder_hidden = (encoder_hidden[0][:self.decoder.n_layers],
                          encoder_hidden[1][:self.decoder.n_layers])
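        # Initialise the decoder state from the encoder, as in evaluate_batch.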

        # Choose whether to teacher-force this whole batch. Note that
        # random.randint(0, 10) is inclusive on both ends, so this compares
        # against a teacher_forcing_ratio expressed on a 0-10 scale.
        use_teacher_forcing = random.randint(0,
                                             10) < self.teacher_forcing_ratio

        if use_teacher_forcing:
            for t in range(max_target_length):
                decoder_vocab, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs, input_mask)
                # Sentinel gating: blend KB object scores into the vocabulary
                # distribution, as in evaluate_batch.
                sentient_gate, obj = self.sentinel_g(input_batch, input_chunk,
                                                     input_mask, decoder_input,
                                                     kb, kb_mask,
                                                     sentinel_values[t - 1])
                s = torch.sigmoid(sentient_gate)
                sentinel_values[t] = s.squeeze()
                obj = s * obj
                decoder_vocab = (1 - s) * decoder_vocab
                decoder_vocab = decoder_vocab.scatter_add(
                    1,
                    self.kg_vocab.repeat(b_size).view(b_size,
                                                      self.kb_max_size), obj)
                all_decoder_outputs_vocab[t] = decoder_vocab
                decoder_input = out_batch[t].long()  # next input is the target
        else:
            # No teacher forcing: feed back the argmax prediction, as in
            # evaluate_batch.
            for t in range(max_target_length):
                decoder_vocab, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs, input_mask)
                sentient_gate, obj = self.sentinel_g(input_batch, input_chunk,
                                                     input_mask, decoder_input,
                                                     kb, kb_mask,
                                                     sentinel_values[t - 1])
                s = torch.sigmoid(sentient_gate)
                sentinel_values[t] = s.squeeze()
                obj = s * obj
                decoder_vocab = (1 - s) * decoder_vocab
                decoder_vocab = decoder_vocab.scatter_add(
                    1,
                    self.kg_vocab.repeat(b_size).view(b_size,
                                                      self.kb_max_size), obj)
                all_decoder_outputs_vocab[t] = decoder_vocab
                topv, topi = decoder_vocab.data.topk(
                    1)  # get prediction from decoder
                decoder_input = Variable(
                    topi.view(-1))  # use this in the next time-steps

        # masked_cross_entropy expects batch-first tensors (B x S).
        target_mask = target_mask.transpose(0, 1).contiguous()
        loss_Vocab = masked_cross_entropy(
            all_decoder_outputs_vocab.transpose(
                0, 1).contiguous(),  # -> B x S X VOCAB
            out_batch.transpose(0, 1).contiguous(),  # -> B x S
            target_mask)
        # Auxiliary loss supervising the gate values with the gold labels.
        sentinel_loss = self.sentient_loss(sentinel_values, sentient_orig)

        loss = loss_Vocab + sentinel_loss
        loss.backward()

        # Clip gradients; self.parameters() also covers the sentinel-gate
        # module (and re-clips the encoder/decoder parameters).
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)

        # Update parameters with optimizers
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        self.loss += loss.item()
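        # Accumulate the running loss; this method returns nothing, so callers
        # presumably read self.loss.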