Example #1
    def forward(self, context):
        batch_size, _ = context.size()
        context = self.fc(context)

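        # Mixture weights: Gumbel-softmax with hard=True gives a (straight-
        # through) one-hot selection over the n_components mixture components.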
        pi = self.pi_net(context)
        pi = F.gumbel_softmax(pi, tau=self.gumbel_temp, hard=True, eps=1e-10)
        pi = pi.unsqueeze(1)

        mus = self.context_to_mu(context)
        logsigmas = self.context_to_logsigma(context)

        # mus = torch.clamp(mus, -30, 30)
        logsigmas = torch.clamp(logsigmas, -20, 20)

        stds = torch.exp(0.5 * logsigmas)

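        # Reparameterization trick: sample every component's latent as
        # mu + sigma * eps with eps ~ N(0, I).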
        epsilons = gVar(
            torch.randn([batch_size, self.n_components * self.z_size]),
            self.use_cuda)

        zi = (epsilons * stds + mus).view(batch_size, self.n_components,
                                          self.z_size)
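        # bmm with the one-hot pi selects the chosen component's z, mu and
        # logsigma for each batch element.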
        z = torch.bmm(pi, zi).squeeze(1)  # [batch_sz x z_sz]
        mu = torch.bmm(pi, mus.view(batch_size, self.n_components,
                                    self.z_size))
        logsigma = torch.bmm(
            pi, logsigmas.view(batch_size, self.n_components, self.z_size))
        return z, mu, logsigma
Example #2
    def forward(self, context):
        batch_size, _ = context.size()
        context = self.fc(context)
        mu = self.context_to_mu(context)
        logsigma = self.context_to_logsigma(context)

        # mu = torch.clamp(mu, -30, 30)
        logsigma = torch.clamp(logsigma, -20, 20)
        std = torch.exp(0.5 * logsigma)
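        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I), keeps
        # the sampling step differentiable w.r.t. mu and logsigma.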
        epsilon = gVar(torch.randn([batch_size, self.z_size]),
                       use_cuda=self.use_cuda)
        z = epsilon * std + mu
        return z, mu, logsigma
Example #3
    def forward(self, inputs, input_lens=None, noise=False):
        inputs = self.input_dropout(inputs)
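        # Positions with non-zero token ids are attendable (assumes pad id 0).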
        attn_mask = inputs.ne(0)
        if self.embedding is not None:
            inputs = self.embedding(inputs)

        batch_size, seq_len, emb_size = inputs.size()
        inputs = self.dropout(inputs)

        self.rnn.flatten_parameters()
        encoder_output, hidden = self.rnn(inputs)

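        # Take the final layer's hidden state and merge the direction states
        # into one utterance encoding of size dirs * hidden_size.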
        h_n = hidden[0] if self.rnn_class == 'lstm' else hidden
        h_n = h_n.view(self.n_layers, self.dirs, batch_size, self.hidden_size)
        enc = h_n[-1].transpose(1, 0).contiguous().view(
            batch_size, -1)  # bsz, num_dirs*hidden_size

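        # Collapse the direction axis of the returned hidden state by summing,
        # leaving one state per layer.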
        if isinstance(self.rnn, nn.LSTM):
            hidden = (
                hidden[0].view(-1, self.dirs, batch_size,
                               self.hidden_size).sum(1),
                hidden[1].view(-1, self.dirs, batch_size,
                               self.hidden_size).sum(1),
            )
        else:
            hidden = hidden.view(-1, self.dirs, batch_size,
                                 self.hidden_size).sum(1)

        hidden = _transpose_hidden_state(hidden)

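        # Optionally perturb the encoding with zero-mean Gaussian noise of
        # standard deviation noise_radius.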
        if noise and self.noise_radius > 0:
            gauss_noise = gVar(
                torch.normal(mean=torch.zeros(enc.size()),
                             std=self.noise_radius), self.use_cuda)
            enc = enc + gauss_noise

        utt_encoder_states = (encoder_output, hidden, attn_mask)
        return enc, utt_encoder_states
Example #4
    def sampling(self,
                 init_hidden,
                 context,
                 maxlen,
                 SOS_tok,
                 EOS_tok,
                 mode='greedy',
                 context_encoder_states=None):
        batch_size = init_hidden.size(0)
        decoded_words = np.zeros((batch_size, maxlen), dtype=np.int64)
        sample_lens = np.zeros(batch_size, dtype=np.int64)

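        # Start every sequence with SOS; the (optional) context vector is
        # concatenated to the decoder input at every step.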
        # noinspection PyArgumentList
        decoder_input = gVar(
            torch.LongTensor([[SOS_tok] * batch_size]).view(batch_size, 1),
            self.use_cuda)
        decoder_input = self.embedding(
            decoder_input) if self.embedding is not None else decoder_input
        decoder_input = torch.cat(
            [decoder_input, context.unsqueeze(1)],
            2) if context is not None else decoder_input

        if context_encoder_states is not None:
            context_enc_state, context_enc_hidden, context_attn_mask = context_encoder_states
            context_attn_params = (context_enc_state, context_attn_mask)
            context_hidden = _transpose_hidden_state(context_enc_hidden)
            if isinstance(context_hidden, tuple):
                context_hidden = tuple(x.contiguous() for x in context_hidden)
            else:
                context_hidden = context_hidden.contiguous()
            decoder_hidden = context_hidden
        else:
            decoder_hidden = init_hidden.view(batch_size, self.n_layers,
                                              self.hidden_size)
            decoder_hidden = decoder_hidden.transpose(0, 1).contiguous()
            if self.rnn_class == 'lstm':
                decoder_hidden = (decoder_hidden, decoder_hidden)

        for di in range(maxlen):
            decoder_output, decoder_hidden = self.rnn(decoder_input,
                                                      decoder_hidden)
            if context_encoder_states is not None:
                # apply attention
                decoder_output, _ = self.context_attention(
                    decoder_output, decoder_hidden, context_attn_params)

            decoder_output = self.out(decoder_output)

            if mode == 'greedy':
                topi = decoder_output[:, -1].max(1, keepdim=True)[1]
            elif mode == 'nucleus':
                # Nucleus, aka top-p sampling (Holtzman et al., 2019).
                logprobs = decoder_output[:, -1]
                probs = torch.softmax(logprobs, dim=-1)
                sprobs, sinds = probs.sort(dim=-1, descending=True)
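                # Drop the tail of the sorted distribution once the cumulative
                # mass reaches topp (the top-ranked token is always kept),
                # renormalize, and sample from the remaining nucleus.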
                mask = (sprobs.cumsum(dim=-1) - sprobs[:, :1]) >= self.topp
                sprobs[mask] = 0
                sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
                choices = torch.multinomial(sprobs, 1)[:, 0]
                hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
                topi = sinds[hyp_ids, choices].unsqueeze(dim=1)
            else:
                raise RuntimeError(
                    'inference method: {} not supported yet!'.format(mode))

            decoder_input = self.embedding(
                topi) if self.embedding is not None else topi
            decoder_input = torch.cat(
                [decoder_input, context.unsqueeze(1)],
                2) if context is not None else decoder_input
            ni = topi.squeeze().data.cpu().numpy()
            decoded_words[:, di] = ni

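        # A sample's length is the number of tokens before its first EOS.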
        for i in range(batch_size):
            for word in decoded_words[i]:
                if word == EOS_tok:
                    break
                sample_lens[i] += 1
        return decoded_words, sample_lens
Example #5
    def forward(self, context, context_lens, utt_lens, floors, noise=False):
        batch_size, max_context_len, max_utt_len = context.size()
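        # Flatten the dialogue so every utterance is encoded independently and
        # build a mask over the valid (non-padded) context positions.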
        utts = context.view(-1, max_utt_len)
        batch_max_lens = torch.arange(max_context_len).expand(
            batch_size, max_context_len)
        if self.use_cuda:
            batch_max_lens = batch_max_lens.cuda()
        context_mask = batch_max_lens < context_lens.unsqueeze(1)
        utt_lens = utt_lens.view(-1)
        utt_encs, utt_encoder_states = self.utt_encoder(utts, utt_lens)
        utt_encs = utt_encs.view(batch_size, max_context_len, -1)
        utt_encoder_output, utt_hidden, utt_attn_mask = utt_encoder_states
        utt_encoder_output = utt_encoder_output.view(
            batch_size, max_context_len, max_utt_len,
            self.utt_encoder.dirs * self.utt_encoder.hidden_size)
        utt_hidden = _transpose_hidden_state(utt_hidden)
        if isinstance(utt_hidden, tuple):
            utt_hidden = tuple(
                x.view(self.utt_encoder.n_layers, batch_size, max_context_len,
                       self.utt_encoder.hidden_size).contiguous()
                for x in utt_hidden)
        else:
            utt_hidden = utt_hidden.view(
                self.utt_encoder.n_layers, batch_size, max_context_len,
                self.utt_encoder.hidden_size).contiguous()
        utt_attn_mask = utt_attn_mask.view(batch_size, max_context_len,
                                           max_utt_len)

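        # Append a 2-way one-hot floor (speaker-turn) indicator to every
        # utterance encoding.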
        floor_one_hot = gVar(torch.zeros(floors.numel(), 2), self.use_cuda)
        floor_one_hot.data.scatter_(1, floors.view(-1, 1), 1)
        floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
        utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)

        utt_floor_encs = self.dropout(utt_floor_encs)
        self.rnn.flatten_parameters()

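        # Seed the context-level RNN with the word-level hidden state of the
        # last utterance position.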
        if self.rnn_class == 'lstm':
            new_hidden = tuple(x[:, :, -1, :].contiguous() for x in utt_hidden)
        else:
            new_hidden = utt_hidden[:, :, -1, :].contiguous()

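        # With attention, run the context RNN one utterance at a time and attend
        # over that utterance's word-level encoder states; otherwise pack the
        # sequence and run it through the RNN in a single call.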
        if self.attn_type != 'none':
            output = []
            for i in range(max_context_len):
                o, new_hidden = self.rnn(utt_floor_encs[:, i, :].unsqueeze(1),
                                         new_hidden)
                o, _ = self.word_attention(
                    o, new_hidden,
                    (utt_encoder_output[:, i, :, :], utt_attn_mask[:, i, :]))
                output.append(o)

            context_encoder_output = torch.cat(output,
                                               dim=1).to(utt_floor_encs.device)
        else:
            utt_floor_encs = pack_padded_sequence(utt_floor_encs,
                                                  context_lens,
                                                  batch_first=True,
                                                  enforce_sorted=False)

            context_encoder_output, new_hidden = self.rnn(
                utt_floor_encs, new_hidden)
            context_encoder_output, _ = pad_packed_sequence(
                context_encoder_output,
                batch_first=True,
                total_length=max_context_len)

        new_hidden = _transpose_hidden_state(new_hidden)
        if self.rnn_class == 'lstm':
            enc = new_hidden[0]
        else:
            enc = new_hidden
        enc = enc.contiguous().view(batch_size, -1)

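        # Optionally perturb the context encoding with zero-mean Gaussian noise
        # of standard deviation noise_radius.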
        if noise and self.noise_radius > 0:
            gauss_noise = gVar(
                torch.normal(mean=torch.zeros(enc.size()),
                             std=self.noise_radius), self.use_cuda)
            enc = enc + gauss_noise
        return enc, (context_encoder_output, new_hidden, context_mask)