Example #1
    def compute_logits_step_by_step(self, src_seq, tgt_seq):
        encoded_source, src_mask = self.encoder(src_seq)
        decoding_cell = self.decoder.get_conditionalized_cell(encoded_source, src_mask)

        logits, decoder_state = decoding_cell.get_initial_logits()

        from nmt_chainer.models.feedforward.utils import pad_data
        padded_tgt = pad_data(tgt_seq, pad_value=0)

        # Move the padded target to the decoder's device, if one is set.
        decoder_device = self.decoder.get_device()
        if decoder_device is not None:
            padded_tgt = self.xp.array(padded_tgt)

        max_tgt_length = padded_tgt.shape[1]
        # Split the padded batch into per-step columns of shape (mb_size, 1).
        seq_padded_tgt = [padded_tgt[:, i:i + 1] for i in range(max_tgt_length)]

        mb_size = len(src_seq)
        result = [logits]

        # Feed the gold previous word at each step and collect the logits.
        for num_step in range(max_tgt_length):
            prev_word = seq_padded_tgt[num_step]
            assert prev_word.shape == (mb_size, 1)

            logits, decoder_state = decoding_cell(decoder_state, prev_word)
            result.append(logits)

        return result
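pad_data here is the project's helper from nmt_chainer.models.feedforward.utils; it is not shown on this page. As a rough, hypothetical sketch of the behavior these examples rely on (a reimplementation for illustration, not the library code):

import numpy as np

def pad_data_sketch(seq_list, pad_value=0, add_eos=None):
    # Pad variable-length integer sequences into one (mb_size, max_len) array,
    # optionally appending an EOS index to every sequence first.
    if add_eos is not None:
        seq_list = [list(seq) + [add_eos] for seq in seq_list]
    max_len = max(len(seq) for seq in seq_list)
    padded = np.full((len(seq_list), max_len), pad_value, dtype=np.int32)
    for i, seq in enumerate(seq_list):
        padded[i, :len(seq)] = seq
    return padded

# pad_data_sketch([[4, 5], [6]], pad_value=0, add_eos=2)
# -> [[4, 5, 2], [6, 2, 0]]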
Example #2
    def compute_loss_step_by_step(self, src_seq, tgt_seq):
        encoded_source, src_mask = self.encoder(src_seq)
        decoding_cell = self.decoder.get_conditionalized_cell(
            encoded_source, src_mask)

        logits, decoder_state = decoding_cell.get_initial_logits()

        from nmt_chainer.models.feedforward.utils import pad_data
        padded_tgt = pad_data(tgt_seq, pad_value=0)

        max_tgt_length = padded_tgt.shape[1]
        # Split the padded batch into per-step columns of shape (mb_size, 1).
        seq_padded_tgt = [
            padded_tgt[:, i:i + 1] for i in six.moves.range(max_tgt_length)
        ]

        mb_size = len(src_seq)
        result = [[] for _ in six.moves.range(mb_size)]

        num_step = 0
        while True:
            prev_word = seq_padded_tgt[num_step]
            assert prev_word.shape == (mb_size, 1)
            for i in six.moves.range(mb_size):
                result[i].append(prev_word[i, 0])

            # Once a sentence has emitted EOS, keep feeding the padding index 0.
            prev_word = self.xp.where(prev_word == self.decoder.eos_idx, 0,
                                      prev_word)
            num_step += 1
            if num_step >= max_tgt_length:
                break
            logits, decoder_state = decoding_cell(decoder_state, prev_word)
        return result
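The xp.where call above is what lets the loop run past each sentence's end: once EOS appears, the padding index 0 is fed instead. A minimal numpy illustration (eos_idx=2 is an assumed value):

import numpy as np

eos_idx = 2
prev_word = np.array([[5], [2]], dtype=np.int32)  # second sentence just emitted EOS
prev_word = np.where(prev_word == eos_idx, 0, prev_word)
# -> [[5], [0]]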
Example #3
    def compute_loss(self, seq_list, encoded_input, mask_input, reduce="mean"):
        logits = self.compute_logits(seq_list, encoded_input, mask_input)
        # pad_value=-1 matches the default ignore_label of softmax_cross_entropy,
        # so padded positions do not contribute to the loss.
        padded_target_with_eos = pad_data(seq_list,
                                          pad_value=-1,
                                          add_eos=self.eos_idx)
        padded_target_with_eos = self.move_np_array_to_correct_device(
            padded_target_with_eos)
        loss = F.softmax_cross_entropy(F.reshape(logits, (-1, self.V + 1)),
                                       padded_target_with_eos.reshape(-1),
                                       reduce=reduce)
        return loss
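This works because pad_value=-1 matches the default ignore_label of Chainer's softmax_cross_entropy, so padded positions drop out of the loss. A minimal standalone illustration:

import numpy as np
import chainer.functions as F

# Two positions, three classes; the second label is the padding value -1,
# which softmax_cross_entropy ignores by default (ignore_label=-1).
logits = np.array([[2.0, 0.5, 0.1],
                   [0.3, 0.2, 0.9]], dtype=np.float32)
labels = np.array([0, -1], dtype=np.int32)
loss = F.softmax_cross_entropy(logits, labels)  # only the first row contributes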
Example #4
    def make_batch(self, seq_list):
        padded_data = pad_data(seq_list, pad_value=0)
        seq_length = [len(x) + 1 for x in seq_list]  # +1 to account for BOS
        max_length_1 = max(seq_length)
        max_length_2 = max_length_1
        mb_size = len(seq_list)
        mask = make_batch_mask(
            mb_size,
            self.n_heads,
            max_length_1,
            max_length_2,
            # key_seq_lengths=seq_length,  # actually not needed here
            future_mask=True,
            mask_value=-10000)

        padded_data = self.move_np_array_to_correct_device(padded_data)
        mask = self.move_np_array_to_correct_device(mask)

        return padded_data, mask
Example #5
    def make_batch(self, seq_list):
        padded_data = pad_data(seq_list, pad_value=0)
        seq_length = [len(x) for x in seq_list]
        max_length_1 = max(seq_length)
        max_length_2 = max_length_1
        mb_size = len(seq_list)
        mask = make_batch_mask(mb_size,
                               self.n_heads,
                               max_length_1,
                               max_length_2,
                               key_seq_lengths=seq_length,
                               future_mask=False,
                               mask_value=-10000)

        # Transfer the padded batch and the mask to the model's device, if any.
        device = self.get_device()
        if device is not None:
            with device:
                padded_data = self.xp.array(padded_data)
                mask = self.xp.array(mask)

        return padded_data, mask
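make_batch_mask is another project helper not shown here. Judging from the call sites above, it presumably builds an additive attention mask of shape (mb_size, n_heads, max_length_1, max_length_2), with mask_value at blocked positions. A hypothetical sketch under that assumption:

import numpy as np

def make_batch_mask_sketch(mb_size, n_heads, max_length_1, max_length_2,
                           key_seq_lengths=None, future_mask=False,
                           mask_value=-10000):
    # Allowed positions get 0; blocked positions get mask_value, which is
    # added to the attention scores before the softmax.
    mask = np.zeros((mb_size, n_heads, max_length_1, max_length_2),
                    dtype=np.float32)
    if key_seq_lengths is not None:
        # Block attention to padded key positions (cf. Example #5).
        for b, length in enumerate(key_seq_lengths):
            mask[b, :, :, length:] = mask_value
    if future_mask:
        # Block attention to future positions (cf. Example #4).
        future = np.triu(np.ones((max_length_1, max_length_2),
                                 dtype=np.float32), k=1)
        mask += future * mask_value
    return mask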