Example No. 1
    def encode(self, input_sequence, length):
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state per batch element:
            # (factor, batch, hidden) -> (batch, factor * hidden)
            hidden = hidden.transpose(0, 1).contiguous().view(
                batch_size, self.hidden_size * self.hidden_factor)
        else:
            hidden = hidden.squeeze(0)

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean
        return mean, std, z
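
A minimal, self-contained sketch of the reparameterization step above (illustrative shapes, not from the original class): logv is treated as log(sigma^2), so std = exp(0.5 * logv), and drawing eps separately keeps z differentiable with respect to mean and logv.

import torch

# logv is log(sigma^2), so std = exp(0.5 * logv)
batch_size, latent_size = 4, 16
mean = torch.zeros(batch_size, latent_size)
logv = torch.zeros(batch_size, latent_size)

std = torch.exp(0.5 * logv)
eps = torch.randn(batch_size, latent_size)   # noise is sampled, not learned
z = eps * std + mean                         # differentiable in mean and logv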
Example No. 2
    def encode_condition(self, cond_sequence, cond_length):
        assert self.is_conditional is not None and cond_sequence is not None and cond_length is not None
        batch_size = cond_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(cond_length, descending=True)
        cond_sequence = cond_sequence[sorted_idx]
        _, reversed_idx = torch.sort(sorted_idx)

        # -------------------- CONDITIONAL ENCODER ------------------------
        cond_embedding = self.cond_embedding(cond_sequence)
        packed_input = rnn_utils.pack_padded_sequence(
            cond_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.cond_encoder_rnn(packed_input)
        hidden = self._reshape_hidden_for_bidirection(hidden, batch_size,
                                                      self.cond_hidden_size)
        hidden = hidden[reversed_idx]

        assert hidden.size(0) == batch_size
        assert hidden.size(1) == self.cond_hidden_size

        # REPARAMETERIZATION
        mean = self.cond_hidden2mean(hidden)
        logv = self.cond_hidden2logv(hidden)
        std = torch.exp(0.5 * logv)
        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean

        return hidden, mean, logv, z
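
A sketch of the sort/unsort pattern used above (toy lengths): pack_padded_sequence expects a length-descending batch, and reversed_idx is the inverse permutation that restores the original batch order afterwards.

import torch

lengths = torch.tensor([3, 5, 2])
sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
_, reversed_idx = torch.sort(sorted_idx)     # inverse permutation

# indexing with reversed_idx undoes the sort
assert torch.equal(sorted_lengths[reversed_idx], lengths)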
Example No. 3
    def encode(self, input_sequence, length, extra_hidden=None):
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]
        _, reversed_idx = torch.sort(sorted_idx)

        # -------------------- ENCODER ------------------------
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)
        hidden = self._reshape_hidden_for_bidirection(hidden, batch_size,
                                                      self.hidden_size)
        hidden = hidden[reversed_idx]

        assert hidden.size(0) == batch_size
        assert hidden.size(1) == self.hidden_size

        if extra_hidden is not None:
            assert self.is_conditional, 'extra_hidden was given but is_conditional is disabled'
            hidden = torch.cat([hidden, extra_hidden], dim=1)

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean
        return mean, logv, z
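
A hedged sketch of the conditional branch above, with hypothetical sizes: once extra_hidden is concatenated, hidden2mean and hidden2logv must be Linear layers sized for hidden_size + cond_hidden_size inputs.

import torch
import torch.nn as nn

hidden_size, cond_hidden_size, latent_size = 8, 4, 3   # hypothetical sizes
hidden = torch.randn(2, hidden_size)
extra_hidden = torch.randn(2, cond_hidden_size)

# the projection input grows by cond_hidden_size after the concat
hidden2mean = nn.Linear(hidden_size + cond_hidden_size, latent_size)
mean = hidden2mean(torch.cat([hidden, extra_hidden], dim=1))
print(mean.shape)  # torch.Size([2, 3])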
Example No. 4
    def hidden2latent(self, hidden):
        # --------------- REPARAMETERIZATION ------------------
        batch_size = hidden.size(0)
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean
        return mean, logv, z
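
A hedged companion sketch (not from this repo): the mean and logv that hidden2latent returns are typically fed into the analytic Gaussian KL term of the ELBO, KL(N(mean, std^2) || N(0, I)), summed here over batch and latent dimensions.

import torch

mean = torch.zeros(4, 16)
logv = torch.zeros(4, 16)

# closed-form KL between N(mean, exp(logv)) and a standard normal
kl_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
print(kl_loss.item())  # 0.0 when the posterior is exactly standard normal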
Example No. 5
    def inference(self, n=4, z=None):

        if z is None:
            batch_size = n
            z = to_var(torch.randn([batch_size, self.latent_size]))
        else:
            batch_size = z.size(0)

        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state:
            # (batch, factor * hidden) -> (factor, batch, hidden)
            hidden = hidden.view(batch_size, self.hidden_factor,
                                 self.hidden_size).transpose(0, 1).contiguous()
        else:
            hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        # all indices of the batch
        sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()
        # indices of batch elements that are still generating
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()
        sequence_mask = torch.ones(batch_size, out=self.tensor()).bool()

        # indices of still-generating sequences relative to the current pruned batch
        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(
            self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:
            if t == 0:
                input_sequence = to_var(
                    torch.Tensor(batch_size).fill_(self.sos_idx).long())

            input_sequence = input_sequence.unsqueeze(1)

            input_embedding = self.embedding(input_sequence)

            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)

            input_sequence = self._sample(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence,
                                            sequence_running, t)

            # update global running sequences
            sequence_mask[sequence_running] = (input_sequence !=
                                               self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                input_sequence = input_sequence.view(-1)[running_seqs]
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0,
                                            len(running_seqs),
                                            out=self.tensor()).long()

            t += 1

        return generations, z
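
A hedged sketch of the dynamic-stopping bookkeeping in inference(), with toy values: sequence_running holds the global batch indices that are still generating, and masked_select drops the ones that just emitted <eos>.

import torch

eos_idx = 2
sequence_idx = torch.arange(4)
sequence_mask = torch.ones(4, dtype=torch.bool)

input_sequence = torch.tensor([5, 2, 7, 2])   # samples 1 and 3 emit <eos>
sequence_mask[sequence_idx] = input_sequence != eos_idx
sequence_running = sequence_idx.masked_select(sequence_mask)
print(sequence_running)  # tensor([0, 2]) -- only these keep generating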
Example No. 6
    def forward(self, input_sequence, length):

        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state per batch element:
            # (factor, batch, hidden) -> (batch, factor * hidden)
            hidden = hidden.transpose(0, 1).contiguous().view(
                batch_size, self.hidden_size * self.hidden_factor)
        else:
            hidden = hidden.squeeze(0)

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean

        # DECODER
        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state:
            # (batch, factor * hidden) -> (factor, batch, hidden)
            hidden = hidden.view(batch_size, self.hidden_factor,
                                 self.hidden_size).transpose(0, 1).contiguous()
        else:
            hidden = hidden.unsqueeze(0)

        # decoder input
        if self.word_dropout_rate > 0:
            # randomly replace decoder input with <unk>
            prob = torch.rand(input_sequence.size())
            if torch.cuda.is_available():
                prob = prob.cuda()
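            # keep <sos> and <pad> positions intact: force their probability to 1
            # so they are never replaced by <unk>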
            prob[(input_sequence.data - self.sos_idx) *
                 (input_sequence.data - self.pad_idx) == 0] = 1
            decoder_input_sequence = input_sequence.clone()
            decoder_input_sequence[
                prob < self.word_dropout_rate] = self.unk_idx
            input_embedding = self.embedding(decoder_input_sequence)
        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = rnn_utils.pad_packed_sequence(outputs,
                                                       batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # project outputs to vocab
        logp = nn.functional.log_softmax(self.outputs2vocab(
            padded_outputs.view(-1, padded_outputs.size(2))),
                                         dim=-1)
        logp = logp.view(b, s, self.embedding.num_embeddings)

        return logp, mean, logv, z
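
A hedged, standalone sketch of the word-dropout trick in forward(), with toy indices: tokens are replaced by <unk> with probability word_dropout_rate, while <sos> and <pad> positions are forced to probability 1 so they always survive.

import torch

sos_idx, pad_idx, unk_idx = 1, 0, 2           # toy vocabulary indices
word_dropout_rate = 0.5
input_sequence = torch.tensor([[1, 5, 6, 7, 0, 0]])

prob = torch.rand(input_sequence.size())
prob[(input_sequence - sos_idx) * (input_sequence - pad_idx) == 0] = 1
decoder_input = input_sequence.clone()
decoder_input[prob < word_dropout_rate] = unk_idx
print(decoder_input)  # <sos> and <pad> positions are always preserved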