Example #1
    def get_loss(self, word_input, target, length):
        """

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return:
        """
        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)
        batch_size = enc_h_n.size(1)

        init_words = self.start_idx.repeat(batch_size, 1)
        # shape == (batch_size, max_len + 1)
        dec_input = torch.cat((init_words, target), dim=1)

        # shape == (max_len + 1, batch_size)
        dec_input = dec_input.permute(1, 0)

        # shape == (max_len+1, batch_size, tgt_vocab_size)
        predict, _, _ = self.core_decoder(dec_input, (enc_h_n, enc_c_n), enc_outputs, None)

        # shape == (batch_size, tgt_vocab_size, max_len+1)
        predict = predict.permute(1, 2, 0)

        end_words = self.end_idx.repeat(batch_size, 1)
        dec_target = torch.cat((target, end_words), dim=1)

        loss = self.xent(predict, dec_target)
        loss_mask = pytorch_utils.length_to_mask(length+1, max_len=self.max_length+1, dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), (length+1).float())
        loss = loss.mean(dim=0)
        return loss
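
All of these examples lean on pytorch_utils.length_to_mask to zero out the loss at padded positions. That helper is not shown on this page; the sketch below is a minimal implementation consistent with how it is called here (a (batch,) length tensor in, a (batch, max_len) mask out), not necessarily the actual pytorch_utils code.

import torch

def length_to_mask(length, max_len=None, dtype=None):
    """Minimal sketch: mask[i, j] is 1 where j < length[i], else 0."""
    max_len = max_len if max_len is not None else int(length.max().item())
    positions = torch.arange(max_len, device=length.device)
    # Broadcast (1, max_len) against (batch, 1) to get a (batch, max_len) mask.
    mask = positions.unsqueeze(0) < length.unsqueeze(1)
    return mask if dtype is None else mask.to(dtype)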
Example #2
    def get_loss(self, word_input, target, length):
        """

        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len)
        :param length: shape == (batch)
        :return: Tensor shape == (batch, seq_len)
        """
        __max_length = self.max_length.item()
        __half_chunk_size = int(self.chunk_size.item() / 2)
        __batch_size = word_input.size(0)

        input_chunks = self.__chunking_sequence(word_input)
        target_chunks = self.__chunking_sequence(target)

        # shape == (batch, __max_length)
        mask = pytorch_utils.length_to_mask(length, max_len=__max_length)
        length_chunks = [torch.sum(mask[:, i:i+__half_chunk_size], dim=1)
                         for i in range(0, __max_length, __half_chunk_size)]
        loss = []
        previous_starts_idx = self.start_idx.repeat(__batch_size)
        for idx, (i_chunk, t_chunk) in enumerate(zip(input_chunks, target_chunks)):
            t_chunk = t_chunk[:, :__half_chunk_size]
            length_chunk = length_chunks[idx]
            loss.append(self.get_loss_chunk(i_chunk, t_chunk, length_chunk, previous_starts_idx))
            previous_starts_idx = t_chunk[:, -1]

        loss = torch.cat(loss, dim=1)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss
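
self.__chunking_sequence is private to the class and not shown here. Judging from the surrounding code (length_chunks steps by __half_chunk_size, and get_loss_chunk asserts each chunk is exactly chunk_size wide), it plausibly produces overlapping windows of width chunk_size with stride chunk_size / 2. A hypothetical standalone version:

import torch

def chunking_sequence(seq, chunk_size, pad_idx=0):
    # seq: (batch, max_len); assumes max_len is a multiple of chunk_size // 2.
    half = chunk_size // 2
    batch, max_len = seq.size()
    # Pad half a chunk at the end so the final window is still full width.
    padded = torch.cat((seq, seq.new_full((batch, half), pad_idx)), dim=1)
    return [padded[:, i:i + chunk_size] for i in range(0, max_len, half)]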
Example #3
    def get_loss(self, inputs):
        src, tgt, seq_len = inputs
        max_length = src.size(1)
        assert max_length >= torch.max(seq_len).int().item()

        # shape == (batch, max_len, vocab_size)
        predict = self.model.get_logits(src)
        # shape == (batch, vocab_size, max_len)
        predict = predict.permute(0, 2, 1)
        loss = self.xent(predict, tgt)
        loss_mask = pytorch_utils.length_to_mask(seq_len, max_len=max_length, dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), seq_len.float())
        loss = loss.mean(dim=0)
        return loss
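
For the torch.mul(loss, loss_mask) step to work, self.xent has to return one loss value per token rather than a reduced scalar. The natural choice, assumed here since the constructor is not shown, is nn.CrossEntropyLoss(reduction='none'), which accepts (batch, vocab_size, max_len) logits against (batch, max_len) targets:

import torch
import torch.nn as nn

xent = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(2, 10, 7)        # (batch, vocab_size, max_len)
targets = torch.randint(10, (2, 7))   # (batch, max_len)
per_token = xent(logits, targets)     # (batch, max_len): one loss per token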
Example #4
    def get_loss_chunk(self, word_input, target, length, previous_starts_idx):
        """

        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len/2)
        :param length: shape == (batch)
        :param previous_starts_idx: shape == (batch)
        :return: Tensor shape == (batch, seq_len/2)
        """
        assert target.size(1) * 2 == word_input.size(1)
        __half_seq_len = target.size(1)
        __batch_size = word_input.size(0)

        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)

        # shape == (batch_size, seq_len/2)
        dec_inputs = torch.cat(
            (previous_starts_idx.view(__batch_size, 1), target[:, :-1]), dim=1)

        # shape == (seq_len/2, batch_size)
        dec_inputs = dec_inputs.permute(1, 0)

        enc_inputs = self.encoder.embedding(word_input[:, :__half_seq_len])
        # shape == (seq_len/2, batch, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)

        # shape == (seq_len/2, batch_size, tgt_vocab_size)
        predict, _ = self.core_decoder(dec_inputs, (enc_h_n, enc_c_n),
                                       enc_outputs,
                                       enc_inputs,
                                       step=None)

        # shape == (batch_size, tgt_vocab_size, seq_len/2)
        predict = predict.permute(1, 2, 0)

        loss = self.xent(predict, target)
        __chunk_size = self.chunk_size.item()
        assert __chunk_size == word_input.size(1)

        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=__half_seq_len,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)

        return loss
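
The dec_inputs construction above is standard teacher forcing carried across chunk boundaries: the decoder input is the target shifted right by one, with the last target token of the previous chunk (or start_idx for the first chunk) prepended. A tiny illustration with made-up token ids:

import torch

target = torch.tensor([[5, 6, 7, 8]])    # (batch=1, seq_len/2)
previous_starts_idx = torch.tensor([2])  # last target token of the previous chunk
dec_inputs = torch.cat((previous_starts_idx.view(1, 1), target[:, :-1]), dim=1)
print(dec_inputs)  # tensor([[2, 5, 6, 7]]): shifted-right targets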
Example #5
        def loss(pred, sequence_len, y):
            """

            :param pred: shape == (batch_size, seq_len, no_class)
            :param sequence_len: shape == (batch_size, seq_len)
            :param y: shape == (batch_size, seq_len)
            :return:
            """
            max_len = pred.size(1)

            # shape == (batch_size, no_class, seq_len)
            pred = pred.permute(0, 2, 1)

            # shape == (batch_size, max_len)
            loss = mll_loss(pred, y)

            mask_loss = pytorch_utils.length_to_mask(length=sequence_len, max_len=max_len)
            loss *= mask_loss.float().to(device)

            return loss.sum(dim=1).mean(dim=0)
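
mll_loss is a closure variable that is not defined in this snippet. Given that pred is permuted to (batch_size, no_class, seq_len) and the result is masked per token, one definition consistent with this usage would be nn.NLLLoss(reduction='none') applied to log-probabilities (an assumption; the outer scope is not shown):

import torch
import torch.nn as nn

mll_loss = nn.NLLLoss(reduction='none')
log_probs = torch.log_softmax(torch.randn(2, 5, 7), dim=1)  # (batch, no_class, seq_len)
y = torch.randint(5, (2, 7))                                # (batch, seq_len)
per_token = mll_loss(log_probs, y)                          # (batch, seq_len)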
Example #6
    def get_loss(self, word_input, target, length):
        """

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return:
        """
        max_length = word_input.size(1)

        # shape == (batch, max_len, vocab_size)
        predict = self.inner_forward(word_input)
        # shape == (batch, vocab_size, max_len)
        predict = predict.permute(0, 2, 1)

        loss = self.xent(predict, target)
        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=max_length,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss
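
One detail shared by these examples: each divides the masked loss sum by that sequence's own length before taking the batch mean, so every sequence contributes equally regardless of length. That is not the same as averaging over all valid tokens, as this small worked example shows:

import torch

per_token = torch.tensor([[1.0, 1.0, 0.0],    # length 2
                          [4.0, 0.0, 0.0]])   # length 1 (padding already masked)
length = torch.tensor([2, 1])

per_sequence = per_token.sum(dim=1) / length.float()  # tensor([1., 4.])
print(per_sequence.mean())             # tensor(2.5000): sequences weigh equally
print(per_token.sum() / length.sum())  # tensor(2.0000): tokens weigh equally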