Example 1
    def forward(self, seq, seq_lens):

        if self.training:
            # Word dropout: sample a binary keep mask over the token ids
            # (1 keeps the id, 0 zeroes it out) and apply it on seq's device.
            word_dropout = Bernoulli(self.word_dropout).sample(seq.shape)
            word_dropout = word_dropout.long().to(seq.device)
            seq = seq * word_dropout

        embedded_seq = self.embed(seq)
        embedded_seq = self.input_dropout(embedded_seq)

        # Pack the padded batch so the recurrent encoder skips padded positions;
        # enforce_sorted=False accepts sequences in any length order.
        encoder_input = nn.utils.rnn.pack_padded_sequence(embedded_seq,
                                                          seq_lens,
                                                          batch_first=True,
                                                          enforce_sorted=False)

        # Run the recurrent encoder over the packed batch; the second return
        # value holds the final hidden and cell states, which are unused here.
        encoder_hidden, (h_n, c_n) = self.encoder(encoder_input)
        encoder_hidden, _ = nn.utils.rnn.pad_packed_sequence(encoder_hidden,
                                                             batch_first=True)
        encoder_hidden = self.output_dropout(encoder_hidden)

        # Select the output at the last valid (non-padded) timestep of each sequence.
        final_hidden = encoder_hidden[torch.arange(encoder_hidden.size(0)),
                                      seq_lens - 1, :]

        # TODO Highway layers

        return final_hidden, encoder_hidden
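
The word-dropout step above boils down to multiplying the token ids by a Bernoulli keep mask, so dropped positions collapse to id 0. A minimal, self-contained sketch of just that step (the toy tensor and the 0.75 keep probability are illustrative, not values taken from the model above):

    import torch
    from torch.distributions import Bernoulli

    seq = torch.randint(1, 100, (2, 5))            # toy batch of token ids
    keep_mask = Bernoulli(0.75).sample(seq.shape)  # 1 = keep the id, 0 = drop it
    dropped_seq = seq * keep_mask.long()           # dropped positions become id 0
    print(dropped_seq)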
Example 2
    def word_dropout(self, input, lens):
        if not self.training:
            return input
        output = []
        for inp, _len in zip(input, lens):
            # Keep each token between position 1 and _len with probability
            # 1 - word_dropout_p; dropped positions are first zeroed and then
            # remapped to the UNK index. Position 0 and the padding tail are untouched.
            keep_mask = Bernoulli(1 - self.word_dropout_p).sample(
                inp[1:_len].shape).type(torch.LongTensor)
            inp = inp.cpu()
            inp[1:_len] = inp[1:_len] * keep_mask
            inp[1:_len][inp[1:_len] == 0] = self.embeddings.unk_idx
            output.append(inp)

        return torch.stack(output, 0).cuda()
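
Compared with Example 1, dropped tokens here are remapped to the UNK id instead of being left as id 0, and the first position plus the padding tail are never touched. A sketch of that remapping step in isolation; the unk_id value and the 0.9 keep probability are made up for illustration:

    import torch
    from torch.distributions import Bernoulli

    unk_id = 1                                         # hypothetical UNK index
    tokens = torch.randint(2, 100, (7,))               # one unpadded sequence (ids >= 2)
    keep = Bernoulli(0.9).sample(tokens.shape).long()  # keep each token with prob 0.9
    tokens = tokens * keep                             # dropped positions become 0 ...
    tokens[tokens == 0] = unk_id                       # ... and are then remapped to UNK
    print(tokens)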
Example 3
    def forward(
        self,
        representation,
        seq,
        initial_state=None,
        context_batch_mask=None,
        encoder_hidden_states=None,
    ):
        class_name = self.__class__.__name__
        if self.MODE == "PRETRAIN":
            representation = self.representation_layer(
                self.embeddings(seq).mean(1))
        assert representation.shape == torch.Size(
            [seq.shape[0], self.repr_hidden_size])
        logits = []
        predictions = []
        attention = []
        batch_size, seq_len = seq.shape
        # Keep an untouched reference to the input ids; it provides the labels later.
        original_seq = seq
        # batch_size, 1 -- decoding starts from the first token of each sequence
        seq_i = seq[:, 0].unsqueeze(1)
        if self.MODE == "TRAIN":
            word_dropout = Bernoulli(0.75).sample(seq[:, 1:].shape)
            word_dropout = word_dropout.type(torch.LongTensor)
            seq = seq.cpu()
            seq[:, 1:] = seq[:, 1:] * word_dropout
            seq = seq.cuda()
        # 1, batch_size, hidden_x_dirs
        if initial_state is not None:
            decoder_hidden_tuple_i = initial_state.unsqueeze(0)
        else:
            decoder_hidden_tuple_i = None
        # Teacher forcing is decided once per forward pass: if p is below the
        # threshold, ground-truth tokens are fed to the decoder at every step.
        p = random.random()

        self.attention = []

        # We skip the last (EOS) token as decoder input, so the loop runs
        # seq_len - 1 steps.
        for i in range(seq_len - 1):

            # One decoding step: next-token logits from the current input token,
            # the previous decoder state, and the sentence representation
            # (plus the optional attention inputs).
            decoder_hidden_tuple_i, logits_i = self.generate(
                seq_i,
                decoder_hidden_tuple_i,
                representation,
                context_batch_mask,
                encoder_hidden_states,
            )

            # batch_size
            _, predictions_i = logits_i.max(1)

            logits.append(logits_i)
            predictions.append(predictions_i)

            if self.training and p <= self.teacher_forcing_p:
                # batch_size, 1
                seq_i = seq[:, i + 1].unsqueeze(1)
            else:
                # batch_size, 1
                seq_i = predictions_i.unsqueeze(1)
                seq_i = seq_i.cuda()

        # (seq_len - 1, batch_size)
        predictions = torch.stack(predictions, 0)

        # (batch_size, seq_len - 1)
        predictions = predictions.t().contiguous()

        # (seq_len - 1, batch_size, output_size)
        logits = torch.stack(logits, 0)

        # (batch_size, seq_len - 1, output_size)
        logits = logits.transpose(0, 1).contiguous()

        # (batch_size*(seq_len - 1), output_size)
        flat_logits = logits.view(batch_size * (seq_len - 1), -1)

        # (batch_size, seq_len - 1): targets are the input ids shifted by one position
        labels = original_seq[:, 1:].contiguous()

        # (batch_size*(seq_len - 1))
        flat_labels = labels.view(-1)

        loss = self.loss_function(flat_logits, flat_labels)
        log_ppl = F.cross_entropy(flat_logits,
                                  flat_labels,
                                  ignore_index=self.embeddings.padding_idx)
        return loss, predictions, log_ppl
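
The loss at the end of Example 3 depends on flattening the (batch, steps, vocab) logits and the (batch, steps) labels so that F.cross_entropy sees one row per decoded token. A standalone sketch of just that reshaping, with toy sizes and a padding index of 0 chosen purely for illustration:

    import torch
    import torch.nn.functional as F

    batch_size, steps, vocab = 2, 4, 10
    logits = torch.randn(batch_size, steps, vocab)         # fake decoder outputs
    labels = torch.randint(1, vocab, (batch_size, steps))  # fake target ids (non-padding)

    flat_logits = logits.view(batch_size * steps, -1)      # (batch*steps, vocab)
    flat_labels = labels.view(-1)                          # (batch*steps,)
    loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=0)
    print(loss.item())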