# --- Example 1 ---
    def get_lstm_states(self, gaz_list, reverse_gaz_list, word_ids, word_masks,
                        lengths):
        """Run the bidirectional lattice LSTM and return its hidden states.

        Word embeddings and gazetteer (lexicon-match) embeddings are
        concatenated per direction and fed to a two-direction LSTM; the
        backward direction receives sequence-reversed inputs.

        :param gaz_list: per-sentence gazetteer matches (project format —
            consumed by ``get_batch_gaz``; exact structure not visible here)
        :param reverse_gaz_list: gazetteer matches for the reversed sentences
        :param word_ids: LongTensor of word indices, [batch, seq_len]
        :param word_masks: mask tensor, [batch, seq_len]; only used here to
            read batch size and sequence length
        :param lengths: true (unpadded) sequence lengths, used by
            ``reverse_padded_sequence`` to flip only the valid prefix
        :return: ``lstm_outs`` — hidden states from ``self.lstm``
        """
        batch_size = word_masks.size(0)
        seq_len = word_masks.size(1)

        ## get batch gaz ids
        # NOTE(review): the *forward* direction is batched from
        # reverse_gaz_list and the *backward* one from gaz_list — this
        # crossing looks deliberate (gazetteer spans anchored at end
        # positions?) but confirm against the data-preparation code.
        batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(
            reverse_gaz_list, batch_size, seq_len, self.gpu)

        reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(
            gaz_list, batch_size, seq_len, self.gpu)
        # Flip the backward-direction gaz tensors so they line up with the
        # reversed word sequence fed to the backward LSTM.
        reverse_batch_gaz_ids = reverse_padded_sequence(
            reverse_batch_gaz_ids, lengths)
        reverse_batch_gaz_length = reverse_padded_sequence(
            reverse_batch_gaz_length, lengths)
        reverse_batch_gaz_mask = reverse_padded_sequence(
            reverse_batch_gaz_mask, lengths)

        ## word embedding
        word_embs = self.word_embedding(word_ids)
        reverse_word_embs = reverse_padded_sequence(word_embs, lengths)

        ## gaz embedding
        gaz_embs = self.gaz_embed(
            (batch_gaz_ids, batch_gaz_length, batch_gaz_mask))
        reverse_gaz_embs = self.gaz_embed(
            (reverse_batch_gaz_ids, reverse_batch_gaz_length,
             reverse_batch_gaz_mask))

        ## lstm
        # Each direction gets [word_emb ; gaz_emb] concatenated on the
        # feature axis; the backward stream is fully reversed.
        forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1)
        backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs),
                                    dim=-1)

        lstm_outs, _ = self.lstm((forward_inputs, backward_inputs), lengths)

        return lstm_outs
# --- Example 2 ---
    def forward(self, inputs, lengths):
        """Concatenate forward and realigned backward CNN features.

        :param inputs: model inputs, passed through to ``get_cnn_features``
        :param lengths: true sequence lengths, used to undo the padding-aware
            reversal of the backward features
        :return: ``(cnn_out, (f_cnn_out, b_cnn_out))`` where ``cnn_out`` has
            shape [N, M, 2 * out_size]
        """
        fwd_feats, bwd_feats = self.get_cnn_features(inputs)

        # The backward features were computed over reversed sequences; flip
        # them back so every position lines up with the forward stream.
        realigned_bwd = reverse_padded_sequence(bwd_feats, lengths)

        # [N, M, 2 * out_size]
        combined = torch.cat((fwd_feats, realigned_bwd), dim=-1)

        return combined, (fwd_feats, bwd_feats)
# --- Example 3 ---
    def forward(self, inputs, lengths):
        """Concatenate forward and realigned backward LSTM features.

        :param inputs: model inputs, passed through to ``get_lstm_features``
        :param lengths: true sequence lengths, used to undo the padding-aware
            reversal of the backward features
        :return: ``(lstm_out, (f_lstm_out, b_lstm_out))`` with ``lstm_out``
            the last-dim concatenation of both directions
        """
        fwd_states, bwd_states = self.get_lstm_features(inputs)

        # Backward states come from reversed sequences — flip them back so
        # they align position-wise with the forward states.
        realigned_bwd = reverse_padded_sequence(bwd_states, lengths)

        combined = torch.cat((fwd_states, realigned_bwd), dim=-1)

        return combined, (fwd_states, bwd_states)