Example #1
    def reorder_encoder_states(self, encoder_states, indices):
        """
        Reorder encoder states according to a new set of indices.
        """
        enc_out, hidden, attn_mask, context_vec = encoder_states
        # make sure we swap the hidden state around for multi-GPU settings
        hidden = _transpose_hidden_state(hidden)

        # LSTM or GRU/RNN hidden state?
        if isinstance(hidden, torch.Tensor):
            hid, cell = hidden, None
        else:
            hid, cell = hidden

        if not torch.is_tensor(indices):
            # cast indices to a tensor if needed
            indices = torch.LongTensor(indices).to(hid.device)

        hid = hid.index_select(1, indices)
        if cell is None:
            hidden = hid
        else:
            cell = cell.index_select(1, indices)
            hidden = (hid, cell)

        # and bring it back to multi-GPU friendliness
        hidden = _transpose_hidden_state(hidden)
        context_vec = context_vec.index_select(0, indices)
        return enc_out, hidden, attn_mask, context_vec
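
A minimal sketch of the reordering above on a toy hidden state. It assumes, as the code implies, that `_transpose_hidden_state` leaves the state laid out as (num_layers, bsz, hidden_size), so reordering the batch is an `index_select` along dim 1:

    import torch

    # toy GRU-style hidden state: (num_layers=2, bsz=4, hidden_size=3)
    hidden = torch.arange(2 * 4 * 3, dtype=torch.float).view(2, 4, 3)

    # suppose beam search asks us to keep batch rows 2, 0, 0, 3
    indices = torch.LongTensor([2, 0, 0, 3])

    # reorder along the batch dimension
    reordered = hidden.index_select(1, indices)
    assert torch.equal(reordered[:, 0], hidden[:, 2])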
Example #2
    def forward(self, xs, encoder_output, incremental_state=None):
        """
        Decode from input tokens.

        :param xs: (bsz x seqlen) LongTensor of input token indices
        :param encoder_output: output from HredEncoder: a tuple of
            (enc_out, enc_hidden, attn_mask, context_hidden).
        :param incremental_state: most recent hidden state to the decoder.
        :returns: (output, hidden_state) pair from the RNN.
            - output is a bsz x time x latentdim tensor. This value must be passed to
                the model's OutputLayer for a final softmax.
            - hidden_state depends on the choice of RNN
        """
        (
            enc_state,
            (hidden_state, cell_state),
            attn_mask,
            context_hidden,
        ) = encoder_output

        # sequence indices => sequence embeddings
        seqlen = xs.size(1)
        xes = self.dropout(self.lt(xs))

        # concatenate the context LSTM hidden state
        context_hidden_final_layer = context_hidden[:, -1, :].unsqueeze(1)
        resized_context_h = context_hidden_final_layer.expand(-1, seqlen, -1)
        xes = torch.cat((xes, resized_context_h), dim=-1).to(xes.device)

        # run through rnn with None as initial decoder state
        # source for zeroes hidden state: http://www.cs.toronto.edu/~lcharlin/papers/vhred_aaai17.pdf
        output, new_hidden = self.rnn(xes, None)

        return output, _transpose_hidden_state(new_hidden)
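
The context-concatenation step above can be shown in isolation. A hypothetical sketch with made-up sizes: the final-layer context hidden state is expanded across the decoder timesteps and concatenated onto every token embedding.

    import torch

    bsz, seqlen, emb_size, ctx_size = 2, 5, 8, 4
    xes = torch.randn(bsz, seqlen, emb_size)         # token embeddings
    context_hidden = torch.randn(bsz, 3, ctx_size)   # (bsz, num_layers, ctx_size)

    # take the final layer and broadcast it over every decoder timestep
    ctx = context_hidden[:, -1, :].unsqueeze(1).expand(-1, seqlen, -1)
    decoder_input = torch.cat((xes, ctx), dim=-1)
    assert decoder_input.shape == (bsz, seqlen, emb_size + ctx_size)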
Example #3
    def forward(self,
                init_hidden,
                context=None,
                inputs=None,
                lens=None,
                context_encoder_states=None):
        batch_size, maxlen = inputs.size()
        if self.embedding is not None:
            inputs = self.embedding(inputs)

        if context is not None:
            repeated_context = context.unsqueeze(1).repeat(1, maxlen, 1)
            inputs = torch.cat([inputs, repeated_context], 2)

        inputs = self.dropout(inputs)
        self.rnn.flatten_parameters()
        if context_encoder_states is not None:
            # attention on the context encoder outputs
            context_enc_state, context_enc_hidden, context_attn_mask = context_encoder_states
            context_attn_params = (context_enc_state, context_attn_mask)
            context_hidden = _transpose_hidden_state(context_enc_hidden)
            if isinstance(context_hidden, tuple):
                context_hidden = tuple(x.contiguous() for x in context_hidden)
            else:
                context_hidden = context_hidden.contiguous()
            new_hidden = context_hidden
            output = []

            for i in range(maxlen):
                o, new_hidden = self.rnn(inputs[:, i, :].unsqueeze(1),
                                         new_hidden)
                o, _ = self.context_attention(o, new_hidden,
                                              context_attn_params)
                output.append(o)
            output = torch.cat(output, dim=1).to(inputs.device)
        else:
            init_hidden = init_hidden.view(batch_size, self.n_layers,
                                           self.hidden_size)
            init_hidden = init_hidden.transpose(0, 1).contiguous()
            if self.rnn_class == 'lstm':
                init_hidden = (init_hidden, init_hidden)
            output, _ = self.rnn(inputs, init_hidden)

        decoded = self.out(output)
        decoded = decoded.view(batch_size, maxlen, self.vocab_size)
        return decoded
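
When `context_encoder_states` is provided, the decoder above unrolls the RNN one timestep at a time so attention can be applied to each step's output. A small sketch (plain nn.GRU, no attention) showing that the per-step loop with a carried hidden state is equivalent to a single batched call:

    import torch
    import torch.nn as nn

    rnn = nn.GRU(input_size=6, hidden_size=5, num_layers=1, batch_first=True)
    inputs = torch.randn(2, 4, 6)  # (bsz, maxlen, input_size)

    # single batched call
    full_out, full_hidden = rnn(inputs, None)

    # equivalent per-step loop, carrying the hidden state forward
    hidden, outputs = None, []
    for i in range(inputs.size(1)):
        o, hidden = rnn(inputs[:, i, :].unsqueeze(1), hidden)
        outputs.append(o)
    stepped_out = torch.cat(outputs, dim=1)

    assert torch.allclose(full_out, stepped_out, atol=1e-6)
    assert torch.allclose(full_hidden, hidden, atol=1e-6)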
Example #4
    def forward(self, xs, context_vec, hist_lens):
        # encode the current utterance
        (enc_state, (hidden_state, cell_state),
         attn_mask) = super().forward(xs)
        # if every utterance in the context vec has length 1, unsqueeze to prevent loss of dimensionality
        if len(context_vec.shape) < 2:
            context_vec = context_vec.unsqueeze(1)
        # get utt lengths of each utt in context vector
        utt_lens = torch.sum(context_vec.ne(0).int(), dim=1)
        # sort by lengths descending for utterance encoder
        sorted_lens, sorted_idx = utt_lens.sort(descending=True)
        sorted_context_vec = context_vec[sorted_idx]
        (_, (sorted_hidden_state, _), _) = super().forward(sorted_context_vec)
        sorted_final_hidden_states = sorted_hidden_state[:, -1, :]

        # reshape and pad hidden states to bsz x max_hist_len x hidden_size using hist_lens
        original_order_final_hidden = torch.zeros_like(
            sorted_final_hidden_states).scatter_(
                0,
                sorted_idx.unsqueeze(1).expand(
                    -1, sorted_final_hidden_states.shape[1]),
                sorted_final_hidden_states,
            )
        # pad to max hist_len
        original_size_final_hidden = self.sequence_to_padding(
            original_order_final_hidden, hist_lens)
        # pack padded sequence so that we ignore padding
        original_size_final_hidden_packed = nn.utils.rnn.pack_padded_sequence(
            original_size_final_hidden,
            hist_lens.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        # pass through context lstm
        _, (context_h_n,
            _) = self.context_lstm(original_size_final_hidden_packed)
        return (
            enc_state,
            (hidden_state, cell_state),
            attn_mask,
            _transpose_hidden_state(context_h_n),
        )
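
The `scatter_` call above undoes the length-based sort so each utterance's final hidden state returns to its original row. A toy sketch of that unsort step, relying only on the fact that `sort` returns `sorted_idx` with `sorted_vals = vals[sorted_idx]`:

    import torch

    vals = torch.tensor([[10.], [20.], [30.], [40.]])  # original order
    lens = torch.tensor([2, 5, 1, 3])

    sorted_lens, sorted_idx = lens.sort(descending=True)
    sorted_vals = vals[sorted_idx]                      # reordered by length

    # scatter each sorted row back to its original position
    restored = torch.zeros_like(sorted_vals).scatter_(
        0, sorted_idx.unsqueeze(1).expand(-1, sorted_vals.shape[1]), sorted_vals)
    assert torch.equal(restored, vals)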
Example #5
    def forward(self, inputs, input_lens=None, noise=False):
        inputs = self.input_dropout(inputs)
        attn_mask = inputs.ne(0)
        if self.embedding is not None:
            inputs = self.embedding(inputs)

        batch_size, seq_len, emb_size = inputs.size()
        inputs = self.dropout(inputs)

        self.rnn.flatten_parameters()
        encoder_output, hidden = self.rnn(inputs)

        h_n = hidden[0] if self.rnn_class == 'lstm' else hidden
        h_n = h_n.view(self.n_layers, self.dirs, batch_size, self.hidden_size)
        enc = h_n[-1].transpose(1, 0).contiguous().view(
            batch_size, -1)  # bsz, num_dirs*hidden_size

        if isinstance(self.rnn, nn.LSTM):
            hidden = (
                hidden[0].view(-1, self.dirs, batch_size,
                               self.hidden_size).sum(1),
                hidden[1].view(-1, self.dirs, batch_size,
                               self.hidden_size).sum(1),
            )
        else:
            hidden = hidden.view(-1, self.dirs, batch_size,
                                 self.hidden_size).sum(1)

        hidden = _transpose_hidden_state(hidden)

        if noise and self.noise_radius > 0:
            gauss_noise = gVar(
                torch.normal(mean=torch.zeros(enc.size()),
                             std=self.noise_radius), self.use_cuda)
            enc = enc + gauss_noise

        utt_encoder_states = (encoder_output, hidden, attn_mask)
        return enc, utt_encoder_states
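
A sketch of the direction-summing step above, assuming a bidirectional GRU: PyTorch stacks the hidden state as (num_layers * num_dirs, bsz, hidden_size), so viewing it as (num_layers, dirs, bsz, hidden_size) and summing over the direction axis leaves one state per layer.

    import torch
    import torch.nn as nn

    n_layers, dirs, bsz, hidden_size = 2, 2, 3, 4
    rnn = nn.GRU(input_size=5, hidden_size=hidden_size, num_layers=n_layers,
                 bidirectional=True, batch_first=True)
    _, hidden = rnn(torch.randn(bsz, 7, 5))  # hidden: (n_layers * dirs, bsz, hidden_size)

    # group forward/backward states per layer, then sum over the direction axis
    summed = hidden.view(n_layers, dirs, bsz, hidden_size).sum(1)
    assert summed.shape == (n_layers, bsz, hidden_size)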
Example #6
    def sampling(self,
                 init_hidden,
                 context,
                 maxlen,
                 SOS_tok,
                 EOS_tok,
                 mode='greedy',
                 context_encoder_states=None):
        batch_size = init_hidden.size(0)
        decoded_words = np.zeros((batch_size, maxlen), dtype=np.int64)
        sample_lens = np.zeros(batch_size, dtype=np.int64)

        # noinspection PyArgumentList
        decoder_input = gVar(
            torch.LongTensor([[SOS_tok] * batch_size]).view(batch_size, 1),
            self.use_cuda)
        decoder_input = self.embedding(
            decoder_input) if self.embedding is not None else decoder_input
        decoder_input = torch.cat(
            [decoder_input, context.unsqueeze(1)],
            2) if context is not None else decoder_input

        if context_encoder_states is not None:
            context_enc_state, context_enc_hidden, context_attn_mask = context_encoder_states
            context_attn_params = (context_enc_state, context_attn_mask)
            context_hidden = _transpose_hidden_state(context_enc_hidden)
            if isinstance(context_hidden, tuple):
                context_hidden = tuple(x.contiguous() for x in context_hidden)
            else:
                context_hidden = context_hidden.contiguous()
            decoder_hidden = context_hidden
        else:
            decoder_hidden = init_hidden.view(batch_size, self.n_layers,
                                              self.hidden_size)
            decoder_hidden = decoder_hidden.transpose(0, 1).contiguous()
            if self.rnn_class == 'lstm':
                decoder_hidden = (decoder_hidden, decoder_hidden)

        for di in range(maxlen):
            decoder_output, decoder_hidden = self.rnn(decoder_input,
                                                      decoder_hidden)
            if context_encoder_states is not None:
                # apply attention
                decoder_output, _ = self.context_attention(
                    decoder_output, decoder_hidden, context_attn_params)

            decoder_output = self.out(decoder_output)

            if mode == 'greedy':
                topi = decoder_output[:, -1].max(1, keepdim=True)[1]
            elif mode == 'nucleus':
                # Nucleus, aka top-p sampling (Holtzman et al., 2019).
                logprobs = decoder_output[:, -1]
                probs = torch.softmax(logprobs, dim=-1)
                sprobs, sinds = probs.sort(dim=-1, descending=True)
                mask = (sprobs.cumsum(dim=-1) - sprobs[:, :1]) >= self.topp
                sprobs[mask] = 0
                sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
                choices = torch.multinomial(sprobs, 1)[:, 0]
                hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
                topi = sinds[hyp_ids, choices].unsqueeze(dim=1)
            else:
                raise RuntimeError('inference method: {} not supported yet!'.format(mode))

            decoder_input = self.embedding(
                topi) if self.embedding is not None else topi
            decoder_input = torch.cat(
                [decoder_input, context.unsqueeze(1)],
                2) if context is not None else decoder_input
            ni = topi.squeeze().data.cpu().numpy()
            decoded_words[:, di] = ni

        for i in range(batch_size):
            for word in decoded_words[i]:
                if word == EOS_tok:
                    break
                sample_lens[i] += 1
        return decoded_words, sample_lens
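
The nucleus branch above is easier to follow in isolation. Below is a standalone sketch of the common top-p formulation, which zeroes every token whose exclusive prefix sum of sorted probabilities already reaches the threshold (the code above subtracts only the top probability, which keeps at least as many tokens); `topp` and the toy distribution are made up:

    import torch

    topp = 0.9
    probs = torch.tensor([[0.05, 0.50, 0.02, 0.30, 0.13]])  # toy softmax output

    sprobs, sinds = probs.sort(dim=-1, descending=True)
    # mask tokens once the probability mass collected *before* them reaches topp
    exclusive_cumsum = sprobs.cumsum(dim=-1) - sprobs
    sprobs = sprobs.masked_fill(exclusive_cumsum >= topp, 0)
    sprobs = sprobs / sprobs.sum(dim=-1, keepdim=True)

    # sample among the surviving tokens and map back to vocabulary ids
    choice = torch.multinomial(sprobs, 1)[:, 0]
    token = sinds[torch.arange(sprobs.size(0)), choice]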
Example #7
    def forward(self, context, context_lens, utt_lens, floors, noise=False):
        batch_size, max_context_len, max_utt_len = context.size()
        utts = context.view(-1, max_utt_len)
        batch_max_lens = torch.arange(max_context_len).expand(
            batch_size, max_context_len)
        if self.use_cuda:
            batch_max_lens = batch_max_lens.cuda()
        context_mask = batch_max_lens < context_lens.unsqueeze(1)
        utt_lens = utt_lens.view(-1)
        utt_encs, utt_encoder_states = self.utt_encoder(utts, utt_lens)
        utt_encs = utt_encs.view(batch_size, max_context_len, -1)
        utt_encoder_output, utt_hidden, utt_attn_mask = utt_encoder_states
        utt_encoder_output = utt_encoder_output.view(
            batch_size, max_context_len, max_utt_len,
            self.utt_encoder.dirs * self.utt_encoder.hidden_size)
        utt_hidden = _transpose_hidden_state(utt_hidden)
        if isinstance(utt_hidden, tuple):
            utt_hidden = tuple(
                x.view(self.utt_encoder.n_layers, batch_size, max_context_len,
                       self.utt_encoder.hidden_size).contiguous()
                for x in utt_hidden)
        else:
            utt_hidden = utt_hidden.view(
                self.utt_encoder.n_layers, batch_size, max_context_len,
                self.utt_encoder.hidden_size).contiguous()
        utt_attn_mask = utt_attn_mask.view(batch_size, max_context_len,
                                           max_utt_len)

        floor_one_hot = gVar(torch.zeros(floors.numel(), 2), self.use_cuda)
        floor_one_hot.data.scatter_(1, floors.view(-1, 1), 1)
        floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
        utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)

        utt_floor_encs = self.dropout(utt_floor_encs)
        self.rnn.flatten_parameters()

        if self.rnn_class == 'lstm':
            new_hidden = tuple(x[:, :, -1, :].contiguous() for x in utt_hidden)
        else:
            new_hidden = utt_hidden[:, :, -1, :].contiguous()

        if self.attn_type != 'none':
            output = []
            for i in range(max_context_len):
                o, new_hidden = self.rnn(utt_floor_encs[:, i, :].unsqueeze(1),
                                         new_hidden)
                o, _ = self.word_attention(
                    o, new_hidden,
                    (utt_encoder_output[:, i, :, :], utt_attn_mask[:, i, :]))
                output.append(o)

            context_encoder_output = torch.cat(output,
                                               dim=1).to(utt_floor_encs.device)
        else:
            utt_floor_encs = pack_padded_sequence(utt_floor_encs,
                                                  context_lens.cpu(),
                                                  batch_first=True,
                                                  enforce_sorted=False)

            context_encoder_output, new_hidden = self.rnn(
                utt_floor_encs, new_hidden)
            context_encoder_output, _ = pad_packed_sequence(
                context_encoder_output,
                batch_first=True,
                total_length=max_context_len)

        new_hidden = _transpose_hidden_state(new_hidden)
        if self.rnn_class == 'lstm':
            enc = new_hidden[0]
        else:
            enc = new_hidden
        enc = enc.contiguous().view(batch_size, -1)

        if noise and self.noise_radius > 0:
            gauss_noise = gVar(
                torch.normal(mean=torch.zeros(enc.size()),
                             std=self.noise_radius), self.use_cuda)
            enc = enc + gauss_noise
        return enc, (context_encoder_output, new_hidden, context_mask)
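
The floor features above are one-hot encoded with an in-place scatter_. A tiny sketch of that encoding, assuming `floors` holds 0/1 speaker ids per context utterance:

    import torch

    floors = torch.tensor([[0, 1, 1], [1, 0, 0]])  # (bsz, max_context_len)
    one_hot = torch.zeros(floors.numel(), 2)
    one_hot.scatter_(1, floors.view(-1, 1), 1)     # set column <speaker id> to 1
    one_hot = one_hot.view(-1, floors.size(1), 2)  # (bsz, max_context_len, 2)
    assert one_hot[0, 1].tolist() == [0.0, 1.0]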