Example #1
    def test_encoder_step_by_step(self):
        # Encoding the sequence token by token while passing (h_n, c_n) along must
        # give the same final state as encoding the whole sequence in one call.
        batch_size = 2
        seq_len = 2
        vocab_size = 100
        docs = torch.randint(vocab_size, size=(batch_size, seq_len))

        encoder = Encoder(vocab_size=vocab_size, is_bidirectional=False)
        encoder.eval()
        with torch.no_grad():
            h_n_1, c_n_1, _ = encoder(docs)

            h_n_2, c_n_2, _ = encoder(docs[:, 0:1])
            for step in range(1, seq_len):
                h_n_2, c_n_2, _ = encoder(docs[:, step:step+1], (h_n_2, c_n_2))

            self.assertEqual(torch.norm(h_n_1 - h_n_2), 0)
            self.assertEqual(torch.norm(c_n_1 - c_n_2), 0)
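The equivalence this test checks can be reproduced with a plain unidirectional nn.LSTM, which is presumably what Encoder wraps here; the sketch below is illustrative only and does not depend on the project's Encoder.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=3)
lstm.eval()
x = torch.randn(5, 2, 8)  # (seq_len, batch, input_size)

with torch.no_grad():
    _, (h_full, c_full) = lstm(x)              # encode the whole sequence at once

    state = None
    for t in range(x.size(0)):                 # encode one timestep at a time,
        _, state = lstm(x[t:t + 1], state)     # threading (h, c) through each call
    h_step, c_step = state

print(torch.allclose(h_full, h_step))  # True
print(torch.allclose(c_full, c_step))  # True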
Example #2
    def test_encoder(self):

        docs = torch.tensor([[1, 2, 3, 4], [1, 2, 2, 4]], dtype=torch.long)
        batch_size = docs.size(0)

        encoder = Encoder(vocab_size=5)
        h_n, c_n, _ = encoder(docs)

        self.assertEqual(h_n.shape, (6, batch_size, 512))
        self.assertEqual(c_n.shape, (6, batch_size, 512))
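The expected (6, batch_size, 512) shape follows from PyTorch's LSTM state layout, (num_layers * num_directions, batch, hidden_size). The sketch below assumes the Encoder defaults to a 3-layer bidirectional LSTM with hidden size 512, which is what the assertion implies; only the shape arithmetic is certain.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=512, num_layers=3, bidirectional=True)
x = torch.randn(4, 2, 32)  # (seq_len, batch, input_size)
_, (h_n, c_n) = lstm(x)

print(h_n.shape)  # torch.Size([6, 2, 512]) == (3 layers * 2 directions, batch, hidden)
print(c_n.shape)  # torch.Size([6, 2, 512])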
Example #3
    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, end_idx, beam_width, device):
        super(Seq2SeqBeamAttnWithSrc, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx

        self.encoder = Encoder(vocab_size=src_vocab_size)
        _enc_output_size = 2 * self.encoder.lstm_size if self.encoder.is_bidirectional else self.encoder.lstm_size
        self.flatten_hidden_lstm = FlattenHiddenLSTM(lstm_num_layer=3, is_bidirectional=self.encoder.is_bidirectional)
        self.core_decoder = AttnRawDecoderWithSrc(vocab_size=tgt_vocab_size, enc_output_size=_enc_output_size,
                                                  enc_embedding_size=self.encoder.embedding_size)
        self.infer_module = BeamSearchWithSrcInfer(core_decoder=self.core_decoder, start_idx=start_idx,
                                                   beam_width=beam_width, device=device)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.tensor([[start_idx]], dtype=torch.long))
        self.register_buffer('end_idx', torch.tensor([[end_idx]], dtype=torch.long))
Example #4
    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, end_idx):
        super(Seq2Seq, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx

        self.encoder = Encoder(vocab_size=src_vocab_size)
        self.flatten_hidden_lstm = FlattenHiddenLSTM(lstm_num_layer=3, is_bidirectional=self.encoder.is_bidirectional)
        self.core_decoder = RawDecoder(vocab_size=tgt_vocab_size)
        self.greedy_infer = DecoderGreedyInfer(core_decoder=self.core_decoder, max_length=self.max_length,
                                               start_idx=start_idx)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.tensor([[start_idx]], dtype=torch.long))
        self.register_buffer('end_idx', torch.tensor([[end_idx]], dtype=torch.long))
Example #5
class Seq2SeqAttnWithSrc(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, end_idx):
        super(Seq2SeqAttnWithSrc, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx

        self.encoder = Encoder(vocab_size=src_vocab_size)
        _enc_output_size = 2 * self.encoder.lstm_size if self.encoder.is_bidirectional else self.encoder.lstm_size
        self.flatten_hidden_lstm = FlattenHiddenLSTM(
            lstm_num_layer=3, is_bidirectional=self.encoder.is_bidirectional)
        self.core_decoder = AttnRawDecoderWithSrc(
            vocab_size=tgt_vocab_size,
            enc_output_size=_enc_output_size,
            enc_embedding_size=self.encoder.embedding_size)
        self.greedy_infer = DecoderGreedyWithSrcInfer(
            core_decoder=self.core_decoder, start_idx=start_idx)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.tensor([[start_idx]], dtype=torch.long))
        self.register_buffer('end_idx', torch.tensor([[end_idx]], dtype=torch.long))

    def forward(self, word_input, *args):
        """

        :param word_input: shape == (batch_size, max_len)
        :param args:
        :return:
        """
        h_n, c_n, outputs = self.encoder(word_input)
        h_n, c_n = self.flatten_hidden_lstm(h_n, c_n)

        enc_inputs = self.encoder.embedding(word_input)
        # shape == (seq_len, batch, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        output = self.greedy_infer(h_n, c_n, outputs, enc_inputs)
        return output

    def train(self, mode=True):
        # Lazily create the loss and optimizer the first time training mode is enabled.
        if self.xent is None:
            self.xent = nn.CrossEntropyLoss(reduction='none')
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr_rate)
        return super().train(mode)

    def get_loss(self, word_input, target, length):
        """Compute the length-masked cross-entropy loss for one batch.

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size,)
        :return: scalar loss averaged over the batch
        """

        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)

        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)
        batch_size = enc_h_n.size(1)
        init_words = self.start_idx.repeat(batch_size, 1)
        end_words = self.end_idx.repeat(batch_size, 1)

        # shape == (batch_size, max_len + 1)
        dec_input = torch.cat((init_words, target), dim=1)

        # shape == (max_len + 1, batch_size)
        dec_input = dec_input.permute(1, 0)

        enc_inputs = self.encoder.embedding(word_input)
        # shape == (seq_len, batch, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        end_words_embedding = self.encoder.embedding(end_words)
        # shape == (1, batch, _)
        end_words_embedding = end_words_embedding.permute(1, 0, 2)
        # shape == (seq_len+1, batch, _)
        enc_inputs = torch.cat((enc_inputs, end_words_embedding), dim=0)
        # shape == (max_len+1, batch_size, tgt_vocab_size)
        predict, _ = self.core_decoder(dec_input, (enc_h_n, enc_c_n),
                                       enc_outputs,
                                       enc_inputs,
                                       step=None)

        # shape == (batch_size, tgt_vocab_size, max_len+1)
        predict = predict.permute(1, 2, 0)

        dec_target = torch.cat((target, end_words), dim=1)

        loss = self.xent(predict, dec_target)
        # shape == (batch_size, max_length + 1); assumes targets are padded to self.max_length
        loss_mask = pytorch_utils.length_to_mask(length + 1,
                                                 max_len=self.max_length + 1,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), (length + 1).float())
        loss = loss.mean(dim=0)
        return loss

    def train_batch(self, word_input, target, length):
        """Run a single optimization step on one batch.

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size,)
        :return: the batch loss as a Python float
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.get_loss(word_input, target, length)
        loss.backward()
        self.optimizer.step()

        return loss.item()
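A hypothetical end-to-end usage sketch for Seq2SeqAttnWithSrc: the vocabulary sizes, special-token indices, and padded length below are illustrative, and the project's own modules (Encoder, AttnRawDecoderWithSrc, DecoderGreedyWithSrcInfer, pytorch_utils) are assumed to be importable.

import torch

src_vocab_size, tgt_vocab_size = 100, 120
start_idx, end_idx = 1, 2

model = Seq2SeqAttnWithSrc(src_vocab_size, tgt_vocab_size, start_idx, end_idx)

# get_loss builds a mask of width model.max_length + 1, so targets are padded
# to model.max_length (100) in this sketch.
batch_size, max_len = 4, model.max_length
word_input = torch.randint(src_vocab_size, (batch_size, max_len))
target = torch.randint(tgt_vocab_size, (batch_size, max_len))
length = torch.randint(1, max_len + 1, (batch_size,))

loss_value = model.train_batch(word_input, target, length)  # one Adam step
print(loss_value)

model.eval()
with torch.no_grad():
    decoded = model(word_input)  # greedy decoding via the inference module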