    def test_encoder(self):
        # Default Encoder: hidden/cell states follow PyTorch's
        # (num_layers * num_directions, batch, hidden_size) layout.
        docs = torch.Tensor([[1, 2, 3, 4], [1, 2, 2, 4]]).long()
        batch_size = docs.size(0)

        encoder = Encoder(embedding=create_my_embedding(np.random.rand(10, 5)))
        h_n, c_n, _ = encoder(docs)

        self.assertEqual(h_n.shape, (6, batch_size, 512))
        self.assertEqual(c_n.shape, (6, batch_size, 512))
    def test_encoder_step_by_step(self):
        # Feeding tokens one step at a time while carrying (h, c) should match
        # a single full-sequence pass through the encoder.
        batch_size = 2
        seq_len = 2
        vocab_size = 100
        docs = torch.randint(vocab_size, size=(batch_size, seq_len))

        encoder = Encoder(embedding=create_my_embedding(
            np.random.rand(vocab_size, 5)),
                          is_bidirectional=False)
        encoder.eval()
        with torch.no_grad():
            h_n_1, c_n_1, _ = encoder(docs)

            h_n_2, c_n_2, _ = encoder(docs[:, 0:1])
            for step in range(1, seq_len):
                h_n_2, c_n_2, _ = encoder(docs[:, step:step + 1],
                                          (h_n_2, c_n_2))

            self.assertEqual(torch.norm(h_n_1 - h_n_2), 0)
            self.assertEqual(torch.norm(c_n_1 - c_n_2), 0)
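# A minimal standalone sketch (not part of the project) of the PyTorch LSTM
# convention the assertions above rely on: hidden and cell states come back as
# (num_layers * num_directions, batch, hidden_size), so a 3-layer bidirectional
# LSTM with hidden_size=512 presumably explains the (6, batch, 512) shape above.
import torch
import torch.nn as nn

demo_lstm = nn.LSTM(input_size=5, hidden_size=512, num_layers=3,
                    bidirectional=True, batch_first=True)
demo_x = torch.randn(2, 4, 5)            # (batch, seq_len, input_size)
_, (demo_h_n, demo_c_n) = demo_lstm(demo_x)
assert demo_h_n.shape == (6, 2, 512)     # (num_layers * num_directions, batch, hidden_size)
assert demo_c_n.shape == (6, 2, 512)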
    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, end_idx):
        super(Seq2SeqHugeFeedingAttn, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx

        self.encoder = Encoder(vocab_size=src_vocab_size,
                               lstm_size=1024,
                               lstm_num_layer=4)
        _enc_output_size = (2 * self.encoder.lstm_size
                            if self.encoder.is_bidirectional
                            else self.encoder.lstm_size)
        self.flatten_hidden_lstm = FlattenHiddenLSTM(lstm_num_layer=self.encoder.lstm_num_layer,
                                                     is_bidirectional=self.encoder.is_bidirectional)
        self.core_decoder = AttnRawDecoder(vocab_size=tgt_vocab_size, enc_output_size=_enc_output_size,
                                           lstm_size=1024, lstm_num_layer=4)
        self.greedy_infer = DecoderGreedyInfer(core_decoder=self.core_decoder, max_length=self.max_length,
                                               start_idx=start_idx)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.Tensor([[start_idx]]).long())
        self.register_buffer('end_idx', torch.Tensor([[end_idx]]).long())
class Seq2SeqChunk(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, start_idx, padding_idx,
                 max_length):
        super(Seq2SeqChunk, self).__init__()

        pytorch_utils.register_buffer(self, 'lr_rate', 1e-3)
        pytorch_utils.register_buffer(self, 'max_length', max_length)
        pytorch_utils.register_buffer(self, 'chunk_size', 10)

        self.__start_idx_int = start_idx
        self.__padding_idx_int = padding_idx

        self.encoder = Encoder(vocab_size=src_vocab_size,
                               is_bidirectional=False)
        _enc_output_size = (2 * self.encoder.lstm_size.item()
                            if self.encoder.is_bidirectional.item()
                            else self.encoder.lstm_size.item())
        self.flatten_hidden_lstm = FlattenHiddenLSTM(
            lstm_num_layer=3,
            is_bidirectional=bool(self.encoder.is_bidirectional.item()))
        self.core_decoder = AttnRawDecoderWithSrc(
            vocab_size=tgt_vocab_size,
            enc_output_size=_enc_output_size,
            enc_embedding_size=self.encoder.embedding_size.item())
        self.greedy_infer = DecoderGreedyWithSrcInfer(
            core_decoder=self.core_decoder)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.Tensor([start_idx]).long())
        self.register_buffer('padding_idx',
                             torch.Tensor([[padding_idx]]).long())

    def chunk_forward(self, word_input, h_c, starts_idx, *args):
        """
        Encoding procedure is the same, but only decoding the first half of the sequence
        :param word_input: shape == (batch_size, max_len)
        :param h_c: tuple of (h, c). Set it None to indicate the start of the sequence
        :param starts_idx: Tensor shape == (batch)
        :param args:
        :return: Tensor shape == (batch, seq_len)
        """
        # Carry the encoder state over from the previous chunk when available.
        if h_c is None:
            h_n, c_n, outputs = self.encoder(word_input)
        else:
            h_n, c_n, outputs = self.encoder(word_input, h_c)
        h_n, c_n = self.flatten_hidden_lstm(h_n, c_n)

        seq_len = word_input.size(1)
        assert seq_len % 2 == 0
        word_input = word_input[:, :seq_len // 2]
        enc_inputs = self.encoder.embedding(word_input)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        output = self.greedy_infer(h_n, c_n, outputs, enc_inputs, starts_idx)
        return output, (h_n, c_n)

    def forward(self, word_input, *args):
        """

        :param word_input: shape == (batch_size, seq_len)
        :param args:
        :return: Tensor shape == (batch, seq_len)
        """
        __batch_size = word_input.size(0)

        input_chunks = self.__chunking_sequence(word_input)
        h_c = None
        output = []
        previous_starts_idx = self.start_idx.repeat(__batch_size)
        for i_chunk in input_chunks:
            output_chunk, h_c = self.chunk_forward(i_chunk, h_c,
                                                   previous_starts_idx)
            output.append(output_chunk)
            previous_starts_idx = output_chunk[:, -1]

        output = torch.cat(output, dim=1)
        seq_len = word_input.size(1)
        output = output[:, :seq_len]
        return output

    def train(self, mode=True):
        if self.xent is None:
            self.xent = nn.CrossEntropyLoss(reduction='none')
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(),
                                        lr=self.lr_rate.item())
        super().train(mode)

    def get_loss_chunk(self, word_input, target, length, previous_starts_idx):
        """

        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len/2)
        :param length: shape == (batch)
        :param previous_starts_idx: shape == (batch)
        :return: Tensor shape == (batch, seq_len/2)
        """
        assert target.size(1) * 2 == word_input.size(1)
        __half_seq_len = target.size(1)
        __batch_size = word_input.size(0)

        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)

        # shape == (batch_size, seq_len/2)
        dec_inputs = torch.cat(
            (previous_starts_idx.view(__batch_size, 1), target[:, :-1]), dim=1)
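        # Illustrative example: with previous_starts_idx == [2] and
        # target == [[7, 8, 9, 5]], dec_inputs becomes [[2, 7, 8, 9]] --
        # standard teacher forcing, i.e. the target shifted right by one step.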

        # shape == (seq_len/2, batch_size)
        dec_inputs = dec_inputs.permute(1, 0)

        enc_inputs = self.encoder.embedding(word_input[:, :__half_seq_len])
        # shape == (seq_len/2, batch, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)

        # shape == (seq_len/2, batch_size, tgt_vocab_size)
        predict, _ = self.core_decoder(dec_inputs, (enc_h_n, enc_c_n),
                                       enc_outputs,
                                       enc_inputs,
                                       step=None)

        # shape == (batch_size, tgt_vocab_size, seq_len/2)
        predict = predict.permute(1, 2, 0)

        dec_target = target
        loss = self.xent(predict, dec_target)
        __chunk_size = self.chunk_size.item()
        assert __chunk_size == word_input.size(1)

        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=__half_seq_len,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)

        return loss

    def get_loss(self, word_input, target, length):
        """

        :param word_input: shape == (batch, seq_len)
        :param target: shape == (batch, seq_len)
        :param length: shape == (batch)
        :return: scalar Tensor (mean masked loss over the batch)
        """
        __max_length = self.max_length.item()
        __half_chunk_size = int(self.chunk_size.item() / 2)
        __batch_size = word_input.size(0)

        input_chunks = self.__chunking_sequence(word_input)
        target_chunks = self.__chunking_sequence(target)

        # shape == (batch, __max_length)
        mask = pytorch_utils.length_to_mask(length, max_len=__max_length)
        length_chunks = [
            torch.sum(mask[:, i:i + __half_chunk_size], dim=1)
            for i in range(0, __max_length, __half_chunk_size)
        ]
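        # Illustrative example: with __half_chunk_size == 5, __max_length == 15
        # and length == [7], mask == [1]*7 + [0]*8, so length_chunks gives the
        # per-chunk lengths [5, 2, 0].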
        loss = []
        previous_starts_idx = self.start_idx.repeat(__batch_size)
        for idx, (i_chunk,
                  t_chunk) in enumerate(zip(input_chunks, target_chunks)):
            t_chunk = t_chunk[:, :__half_chunk_size]
            length_chunk = length_chunks[idx]
            loss.append(
                self.get_loss_chunk(i_chunk, t_chunk, length_chunk,
                                    previous_starts_idx))
            previous_starts_idx = t_chunk[:, -1]

        loss = torch.cat(loss, dim=1)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss

    def __chunking_sequence(self, word_input):
        """

        :param word_input:
        :return: List chunks
        """
        __chunk_size = self.chunk_size.item()
        __half_chunk_size = __chunk_size // 2
        assert __chunk_size % 2 == 0

        seq_len = word_input.size(1)
        batch_size = word_input.size(0)

        # Pad so that every overlapping window of width chunk_size is complete.
        must_have = (seq_len // __half_chunk_size) * __half_chunk_size + __chunk_size
        num_padding = must_have - seq_len
        padding = self.padding_idx.repeat(batch_size, num_padding)
        word_input = torch.cat((word_input, padding), dim=1)

        # Overlapping windows: width chunk_size, stride chunk_size / 2.
        input_chunks = [
            word_input[:, i:i + __chunk_size]
            for i in range(0, seq_len, __half_chunk_size)
        ]
        return input_chunks

    def train_batch(self, word_input, target, length):
        """

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: loss value as a Python float
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.get_loss(word_input, target, length)
        loss.backward()
        self.optimizer.step()

        return loss.item()
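# A minimal usage sketch, assuming the surrounding project's Encoder,
# FlattenHiddenLSTM, AttnRawDecoderWithSrc, DecoderGreedyWithSrcInfer and
# pytorch_utils modules are importable; the vocabulary sizes, indices and
# tensor contents below are made up for illustration.
def _demo_seq2seq_chunk():
    import torch

    model = Seq2SeqChunk(src_vocab_size=100, tgt_vocab_size=100,
                         start_idx=1, padding_idx=0, max_length=20)
    word_input = torch.randint(2, 100, size=(4, 20))   # (batch, max_length)
    target = torch.randint(2, 100, size=(4, 20))
    length = torch.tensor([20, 17, 12, 5])

    loss_value = model.train_batch(word_input, target, length)  # one optimizer step
    print('loss:', loss_value)

    model.eval()
    with torch.no_grad():
        decoded = model(word_input)                     # greedy decoding, (batch, seq_len)
    print('decoded shape:', decoded.shape)
# Call _demo_seq2seq_chunk() to run the sketch end to end.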
class MainModel(nn.Module):
    def __init__(self, enc_embedding_weight, dec_embedding_weight, start_idx):
        super(MainModel, self).__init__()
        self.lr_rate = 1e-3
        self.max_length = 100
        self.__start_idx_int = start_idx

        self.encoder = Encoder(
            embedding=create_my_embedding(enc_embedding_weight),
            lstm_num_layer=1,
            lstm_size=1024)
        _enc_output_size = (2 * self.encoder.lstm_size
                            if self.encoder.is_bidirectional
                            else self.encoder.lstm_size)
        self.flatten_hidden_lstm = FlattenHiddenLSTM(
            lstm_num_layer=self.encoder.lstm_num_layer,
            is_bidirectional=self.encoder.is_bidirectional)

        self.core_decoder = AttnRawDecoderWithSrc(
            embedding=create_my_embedding(dec_embedding_weight),
            enc_output_size=_enc_output_size,
            use_pred_prob=0.1,
            lstm_size=self.encoder.lstm_size,
            lstm_num_layer=1,
            enc_embedding_size=self.encoder.embedding_size)
        self.greedy_infer = DecoderGreedyInfer(core_decoder=self.core_decoder,
                                               start_idx=start_idx)

        self.xent = None
        self.optimizer = None

        self.register_buffer('start_idx', torch.Tensor([[start_idx]]).long())
        # self.register_buffer('end_idx', torch.Tensor([[end_idx]]).long())

    def forward(self, word_input, *args):
        """
        Greedy decoding for a batch of source sequences.

        :param word_input: shape == (batch_size, max_len)
        :param args:
        :return:
        """
        h_n, c_n, outputs = self.encoder(word_input)
        h_n, c_n = self.flatten_hidden_lstm(h_n, c_n)
        h_c = (h_n, c_n)

        enc_inputs = self.encoder.embedding(word_input)
        enc_inputs = enc_inputs.permute(1, 0, 2)

        output = self.greedy_infer(h_c, outputs, enc_inputs)
        return output

    def train(self, mode=True):
        if self.xent is None:
            self.xent = nn.CrossEntropyLoss(reduction='none')
        if self.optimizer is None:
            self.optimizer = optim.Adam(self.parameters(), lr=self.lr_rate)
        super().train(mode)

    def get_loss(self, word_input, target, length):
        """

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: scalar Tensor (mean masked loss over the batch)
        """
        enc_h_n, enc_c_n, enc_outputs = self.encoder(word_input)
        enc_h_n, enc_c_n = self.flatten_hidden_lstm(enc_h_n, enc_c_n)
        batch_size = enc_h_n.size(1)

        init_words = self.start_idx.repeat(batch_size, 1)
        # shape == (batch_size, max_len)
        dec_input = torch.cat((init_words, target), dim=1)

        # shape == (max_len, batch_size)
        dec_input = dec_input.permute(1, 0)[:-1]

        # shape ==  (batch_size, max_len, _)
        enc_inputs = self.encoder.embedding(word_input)
        # shape ==  (max_len, batch_size, _)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        # shape == (max_len, batch_size, tgt_vocab_size)
        predict, _, _ = self.core_decoder(dec_input, (enc_h_n, enc_c_n),
                                          enc_outputs, None, enc_inputs)

        # shape == (batch_size, tgt_vocab_size, max_len)
        predict = predict.permute(1, 2, 0)

        # end_words = self.end_idx.repeat(batch_size, 1)
        # dec_target = torch.cat((target, end_words), dim=1)
        dec_target = target

        loss = self.xent(predict, dec_target)
        loss_mask = pytorch_utils.length_to_mask(length,
                                                 max_len=self.max_length,
                                                 dtype=torch.float)
        loss = torch.mul(loss, loss_mask)
        loss = torch.div(loss.sum(dim=1), length.float())
        loss = loss.mean(dim=0)
        return loss

    def train_batch(self, word_input, target, length):
        """

        :param word_input: shape == (batch_size, max_len)
        :param target: shape == (batch_size, max_len)
        :param length: shape == (batch_size)
        :return: loss value as a Python float
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.get_loss(word_input, target, length)
        loss.backward()
        self.optimizer.step()

        return loss.item()
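# A minimal usage sketch for MainModel, assuming the project's Encoder,
# FlattenHiddenLSTM, AttnRawDecoderWithSrc, DecoderGreedyInfer, pytorch_utils
# and create_my_embedding are importable; the embedding matrices and batches
# below are random placeholders (100-word vocabularies, embedding size 5).
def _demo_main_model():
    import numpy as np
    import torch

    enc_weight = np.random.rand(100, 5)
    dec_weight = np.random.rand(100, 5)
    model = MainModel(enc_embedding_weight=enc_weight,
                      dec_embedding_weight=dec_weight,
                      start_idx=1)

    word_input = torch.randint(2, 100, size=(4, model.max_length))  # (batch, max_len)
    target = torch.randint(2, 100, size=(4, model.max_length))
    length = torch.tensor([100, 80, 60, 10])

    loss_value = model.train_batch(word_input, target, length)      # one optimizer step
    print('loss:', loss_value)

    model.eval()
    with torch.no_grad():
        decoded = model(word_input)                                  # greedy decoding
    print('decoded shape:', decoded.shape)
# Call _demo_main_model() to run the sketch end to end.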