Example #1
    def np2var(self, inputs, dtype):
        if inputs is None:
            return None
        if isinstance(inputs, list):
            return cast_type(Variable(torch.Tensor(inputs)), dtype,
                             self.use_gpu)
        return cast_type(Variable(torch.from_numpy(inputs)), dtype,
                         self.use_gpu)
    def sweep(self, data_feed, gen_type='greedy'):
        ctx_lens = data_feed['output_lens']
        batch_size = len(ctx_lens)
        out_utts = self.np2var(data_feed['outputs'], LONG)

        # output encoder
        output_embedding = self.embedding(out_utts)
        x_outs, x_last = self.x_encoder(output_embedding)
        x_last = x_last.transpose(0, 1).contiguous().view(-1, self.enc_out_size)

        # posterior network
        qy_logits = self.q_y(x_last).view(-1, self.config.k)

        # switch that controls the sampling
        sample_y, y_id = self.cat_connector(qy_logits, 1.0, self.use_gpu,
                                            hard=True, return_max_id=True)
        y_id = y_id.view(-1, self.config.y_size)
        start_y_id = y_id[0]
        end_y_id = y_id[batch_size-1]

        # start sweeping
        all_y_ids = [start_y_id]
        for idx in range(self.config.y_size):
            mask = torch.zeros(self.config.y_size)
            mask[0:idx+1] = 1.0
            neg_mask = 1 - mask
            mask = cast_type(Variable(mask), LONG, self.use_gpu)
            neg_mask = cast_type(Variable(neg_mask), LONG, self.use_gpu)
            temp_y = neg_mask * start_y_id + mask * end_y_id
            all_y_ids.append(temp_y)
        num_steps = len(all_y_ids)
        all_y_ids = torch.cat(all_y_ids, dim=0).view(num_steps, -1)

        sample_y = cast_type(Variable(torch.zeros((num_steps * self.config.y_size,
                                                    self.config.k))),
                             FLOAT, self.use_gpu)
        sample_y.scatter_(1, all_y_ids.view(-1, 1), 1.0)
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        batch_size = num_steps

        # map sample to initial state of decoder
        dec_init_state = self.dec_init_connector(sample_y)

        # get decoder inputs
        labels = out_utts[:, 1:].contiguous()
        dec_inputs = out_utts[:, 0:-1]

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   dec_inputs, dec_init_state,
                                                   mode=GEN, gen_type=gen_type,
                                                   beam_size=self.beam_size)
        # compute loss or return results
        return dec_ctx, labels, all_y_ids
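
The loop in sweep interpolates between the latent codes of the first and last batch items by flipping one discrete variable at a time. A minimal standalone sketch of that masking step, with hypothetical y_size and codes:

import torch

# A sketch of the sweep interpolation, assuming y_size = 4 discrete latent
# variables; the start/end codes below are made-up examples.
y_size = 4
start_y_id = torch.tensor([2, 0, 1, 3])
end_y_id = torch.tensor([0, 1, 1, 2])

all_y_ids = [start_y_id]
for idx in range(y_size):
    mask = torch.zeros(y_size, dtype=torch.long)
    mask[0:idx + 1] = 1          # positions already switched to the end code
    temp_y = (1 - mask) * start_y_id + mask * end_y_id
    all_y_ids.append(temp_y)

# each row replaces one more variable of start_y_id with end_y_id
print(torch.stack(all_y_ids))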
    def exp_enumerate(self, repeat=1, gen_type='greedy'):

        # Enumerate every joint assignment of the y_size discrete latent
        # variables (k^y_size one-hot combinations) and decode each one.
        batch_size = np.power(self.config.k, self.config.y_size) * repeat
        sample_y = cast_type(Variable(torch.zeros((batch_size*self.config.y_size,
                                                   self.config.k))),
                             FLOAT, self.use_gpu)
        d = dict((str(i), range(self.config.k)) for i in range(self.config.y_size))
        all_y_ids = []
        for combo in itertools.product(*[d[k] for k in sorted(d.keys())]):
            all_y_ids.append(list(combo))
        np_y_ids = np.array(all_y_ids)
        np_y_ids = self.np2var(np_y_ids, LONG)
        # map sample to initial state of decoder
        sample_y.scatter_(1, np_y_ids.view(-1, 1), 1.0)
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        dec_init_state = self.dec_init_connector(sample_y)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   None, dec_init_state,
                                                   mode=GEN, gen_type=gen_type,
                                                   beam_size=self.beam_size)
        return dec_ctx, all_y_ids
    def enumerate(self, repeat=1, gen_type='greedy'):

        # For each latent variable y, enumerate its k values one at a time
        # while the remaining variables are held at the uniform
        # distribution 1/k (i.e. their expectation).
        batch_size = self.config.y_size * self.config.k * repeat
        sample_y = cast_type(Variable(torch.zeros((batch_size,
                                                   self.config.y_size,
                                                   self.config.k))),
                             FLOAT, self.use_gpu)
        sample_y += 1.0/self.config.k

        for y_id in range(self.config.y_size):
            for k_id in range(self.config.k):
                for r_id in range(repeat):
                    # row index: repeat copies for each (y_id, k_id) pair
                    idx = (y_id*self.config.k + k_id)*repeat + r_id
                    sample_y[idx, y_id] = 0.0
                    sample_y[idx, y_id, k_id] = 1.0

        # map sample to initial state of decoder
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)
        dec_init_state = self.dec_init_connector(sample_y)

        # decode
        dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                   None, dec_init_state,
                                                   mode=GEN, gen_type=gen_type,
                                                   beam_size=self.beam_size)
        # compute loss or return results
        return dec_ctx
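
For intuition, the soft assignment that enumerate builds can be reproduced in isolation: every row starts as a uniform 1/k distribution over each latent variable, and one (variable, value) pair per row is forced to a one-hot. A small sketch with hypothetical sizes:

import torch

# Sketch of the enumerate layout, assuming y_size = 2, k = 3, repeat = 1.
y_size, k, repeat = 2, 3, 1
batch_size = y_size * k * repeat
sample_y = torch.full((batch_size, y_size, k), 1.0 / k)

for y_id in range(y_size):
    for k_id in range(k):
        for r_id in range(repeat):
            idx = (y_id * k + k_id) * repeat + r_id
            sample_y[idx, y_id] = 0.0        # zero out the uniform row
            sample_y[idx, y_id, k_id] = 1.0  # make this variable one-hot

print(sample_y.view(batch_size, -1))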
Example #5
    def extract_name(self, action_id):
        if str(action_id) in self.action2name:
            action_name = self.action2name[str(action_id)]
            # list() is needed in Python 3, where map returns an iterator
            name = torch.Tensor(list(map(int, action_name.strip().split('-'))))
            name = cast_type(name, LONG, self.use_gpu)
            return name
        else:
            return 'empty'
Example #6
    def __init__(self, padding_idx, config, rev_vocab=None, key_vocab=None):
        super(NLLEntropy, self).__init__()
        self.padding_idx = padding_idx
        self.avg_type = config.avg_type

        if rev_vocab is None or key_vocab is None:
            self.weight = None
        else:
            self.logger.info("Use extra cost for key words")
            weight = np.ones(len(rev_vocab))
            for key in key_vocab:
                weight[rev_vocab[key]] = 10.0
            self.weight = cast_type(torch.from_numpy(weight), FLOAT,
                                    config.use_gpu)
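
The keyword weighting above amounts to a per-class weight vector passed to nll_loss. A standalone sketch with a hypothetical three-word vocabulary (the names are illustrative, not from the source):

import numpy as np
import torch
import torch.nn.functional as F

rev_vocab = {'<pad>': 0, 'hello': 1, 'price': 2}   # hypothetical vocab
key_vocab = ['price']                               # words to up-weight

weight = np.ones(len(rev_vocab))
for key in key_vocab:
    weight[rev_vocab[key]] = 10.0
weight = torch.from_numpy(weight).float()

log_probs = F.log_softmax(torch.randn(4, 3), dim=-1)
target = torch.tensor([1, 2, 2, 0])
# keyword tokens contribute 10x to the loss; padding is ignored
print(F.nll_loss(log_probs, target, weight=weight, ignore_index=0))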
    def gumbel_max(self, log_probs):
        """
        Obtain a sample from the Gumbel max. Not this is not differentibale.
        :param log_probs: [batch_size x vocab_size]
        :return: [batch_size x 1] selected token IDs
        """
        sample = torch.Tensor(log_probs.size()).uniform_(0, 1)
        sample = cast_type(Variable(sample), FLOAT, self.use_gpu)

        # compute the gumbel sample
        matrix_u = -1.0 * torch.log(-1.0 * torch.log(sample))
        gumbel_log_probs = log_probs + matrix_u
        max_val, max_ids = torch.max(gumbel_log_probs, dim=-1, keepdim=True)
        return max_ids
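
gumbel_max implements the Gumbel-max trick: adding Gumbel(0, 1) noise to log-probabilities and taking the argmax yields an exact categorical sample. A quick standalone check, not part of the source, that the empirical frequencies match the target distribution:

import torch

torch.manual_seed(0)
log_probs = torch.log(torch.tensor([0.5, 0.3, 0.2]))

# Gumbel-max: argmax(log p + g), with g ~ Gumbel(0, 1)
n = 100000
u = torch.rand(n, 3)
gumbel = -torch.log(-torch.log(u))
samples = torch.argmax(log_probs + gumbel, dim=-1)

# empirical frequencies should be close to [0.5, 0.3, 0.2]
print(torch.bincount(samples, minlength=3).float() / n)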
Example #8
    def forward(self, mu, logvar, use_gpu):
        """
        Sample a sample from a multivariate Gaussian distribution with a diagonal covariance matrix using the
        reparametrization trick.

        TODO: this should be better be a instance method in a Gaussian class.

        :param mu: a tensor of size [batch_size, variable_dim]. Batch_size can be None to support dynamic batching
        :param logvar: a tensor of size [batch_size, variable_dim]. Batch_size can be None.
        :return:
        """
        epsilon = torch.randn(logvar.size())
        epsilon = cast_type(Variable(epsilon), FLOAT, use_gpu)
        std = torch.exp(0.5 * logvar)
        z = mu + std * epsilon
        return z
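
A typical use of this reparametrized sampler is inside a VAE training step, where gradients must flow through mu and logvar. A minimal sketch under the assumption that an encoder produced both tensors (the shapes are illustrative):

import torch

# Hypothetical encoder outputs for a batch of 8, latent dim 16.
mu = torch.randn(8, 16, requires_grad=True)
logvar = torch.randn(8, 16, requires_grad=True)

# reparametrization trick: z = mu + std * eps, with eps ~ N(0, I)
eps = torch.randn_like(logvar)
z = mu + torch.exp(0.5 * logvar) * eps

# gradients reach mu and logvar because eps carries no parameters
z.sum().backward()
print(mu.grad.shape, logvar.grad.shape)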
Example #9
    def forward(self, logits, use_gpu, return_max_id=False):
        """
        Hard argmax: convert the highest-scoring class into a one-hot vector.
        :param logits: [batch_size, n_class] unnormalized log-prob
        :param use_gpu: whether to place the output on the GPU
        :param return_max_id: if True, also return the argmax indices
        :return: [batch_size, n_class] one-hot sample
        """
        _, y_hard = torch.max(logits, dim=1, keepdim=True)
        y_onehot = cast_type(Variable(torch.zeros(logits.size())), FLOAT,
                             use_gpu)
        y_onehot.scatter_(1, y_hard, 1.0)
        y = y_onehot
        if return_max_id:
            return y, y_hard
        else:
            return y
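
The scatter_ call is the standard way to turn index tensors into one-hot vectors; a standalone illustration with made-up logits:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2]])
_, y_hard = torch.max(logits, dim=1, keepdim=True)   # [[1], [0]]

y_onehot = torch.zeros(logits.size())
y_onehot.scatter_(1, y_hard, 1.0)  # write 1.0 at the argmax column of each row
print(y_onehot)  # [[0., 1., 0.], [1., 0., 0.]]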
Example #10
    def forward(self,
                batch_size,
                inputs=None,
                init_state=None,
                attn_context=None,
                mode=TEACH_FORCE,
                gen_type='greedy',
                beam_size=4):

        # sanity checks
        ret_dict = dict()

        if self.use_attention:
            # calculate initial attention
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        if mode == GEN:
            inputs = None

        if gen_type != 'beam':
            beam_size = 1

        if inputs is not None:
            decoder_input = inputs
        else:
            # prepare the BOS inputs
            bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True)
            bos_var = cast_type(bos_var, LONG, self.use_gpu)
            decoder_input = bos_var.expand(batch_size * beam_size, 1)

        if mode == GEN and gen_type == 'beam':
            # if beam search, repeat the initial states of the RNN
            if self.rnn_cell is nn.LSTM:
                h, c = init_state
                decoder_hidden = (self.repeat_state(h, batch_size, beam_size),
                                  self.repeat_state(c, batch_size, beam_size))
            else:
                decoder_hidden = self.repeat_state(init_state, batch_size,
                                                   beam_size)
        else:
            decoder_hidden = init_state

        decoder_outputs = []  # a list of log-probs
        sequence_symbols = []  # a list of word IDs
        back_pointers = []  # a list of parent beam IDs
        lengths = np.array([self.max_length] * batch_size * beam_size)

        def decode(step, cum_sum, step_output, step_attn):
            decoder_outputs.append(step_output)
            step_output_slice = step_output.squeeze(1)

            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

            if gen_type == 'greedy':
                symbols = step_output_slice.topk(1)[1]
            elif gen_type == 'sample':
                symbols = self.gumbel_max(step_output_slice)
            elif gen_type == 'beam':
                if step == 0:
                    seq_score = step_output_slice.view(batch_size, -1)
                    seq_score = seq_score[:, 0:self.output_size]
                else:
                    seq_score = cum_sum + step_output_slice
                    seq_score = seq_score.view(batch_size, -1)

                top_v, top_id = seq_score.topk(beam_size)

                back_ptr = top_id.div(self.output_size).view(-1, 1)
                symbols = top_id.fmod(self.output_size).view(-1, 1)
                cum_sum = top_v.view(-1, 1)
                back_pointers.append(back_ptr)
            else:
                raise ValueError("Unsupported decoding mode")

            sequence_symbols.append(symbols)

            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return cum_sum, symbols

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio were fixed at True or False instead of a
        # probability, the unrolling could be done statically in the graph.
        if mode == TEACH_FORCE:
            decoder_output, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, attn_context)

            # in teach forcing mode, we don't need symbols.
            decoder_outputs = decoder_output

        else:
            # do free running here
            cum_sum = None
            for di in range(self.max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(
                    decoder_input, decoder_hidden, attn_context)

                cum_sum, symbols = decode(di, cum_sum, decoder_output,
                                          step_attn)
                decoder_input = symbols

            decoder_outputs = torch.cat(decoder_outputs, dim=1)

            if gen_type == 'beam':
                # do back tracking here to recover the 1-best according to
                # beam search.
                final_seq_symbols = []
                cum_sum = cum_sum.view(-1, beam_size)
                max_seq_id = cum_sum.topk(1)[1].data.cpu().view(-1).numpy()
                rev_seq_symbols = sequence_symbols[::-1]
                rev_back_ptrs = back_pointers[::-1]

                for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs):
                    symbol2ds = symbols.view(-1, beam_size)
                    back2ds = back_ptrs.view(-1, beam_size)

                    selected_symbols = []
                    selected_parents = []
                    for b_id in range(batch_size):
                        selected_parents.append(back2ds[b_id,
                                                        max_seq_id[b_id]])
                        selected_symbols.append(symbol2ds[b_id,
                                                          max_seq_id[b_id]])

                    final_seq_symbols.append(
                        torch.cat(selected_symbols).unsqueeze(1))
                    max_seq_id = torch.cat(selected_parents).data.cpu().numpy()
                sequence_symbols = final_seq_symbols[::-1]

        # save the decoded sequence symbols and sequence length
        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict
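
The beam back-tracking at the end recovers the 1-best hypothesis by walking parent pointers from the final step backwards. A toy walk with hypothetical per-step symbols and back-pointers for batch_size 1 and beam_size 2:

import torch

# hypothetical 3-step beam: at step t, each beam stores the symbol it
# emitted and the parent beam at step t-1 it extended
sequence_symbols = [torch.tensor([[5], [7]]),   # step 0
                    torch.tensor([[2], [9]]),   # step 1
                    torch.tensor([[4], [1]])]   # step 2
back_pointers = [torch.tensor([[0], [0]]),
                 torch.tensor([[1], [0]]),
                 torch.tensor([[0], [1]])]

best = 0  # index of the highest-scoring final beam
hyp = []
for symbols, back_ptrs in zip(sequence_symbols[::-1], back_pointers[::-1]):
    hyp.append(symbols[best].item())
    best = back_ptrs[best].item()
print(hyp[::-1])  # tokens of the 1-best hypothesis in order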
Example #11
    def exp_forward(self, data_feed):
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)

        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)
        output_lens = self.np2var(data_feed['output_lens'], FLOAT)

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        c_last = c_last.squeeze(0)

        # prior network
        py_logits = self.p_y(c_last).view(-1, self.config.k)
        log_py = F.log_softmax(py_logits, dim=py_logits.dim()-1)

        exp_size = np.power(self.config.k, self.config.y_size)
        sample_y = cast_type(
            Variable(torch.zeros((exp_size * self.config.y_size, self.config.k))), FLOAT, self.use_gpu)
        d = dict((str(i), range(self.config.k)) for i in range(self.config.y_size))
        all_y_ids = []
        for combo in itertools.product(*[d[k] for k in sorted(d.keys())]):
            all_y_ids.append(list(combo))
        np_y_ids = np.array(all_y_ids)
        np_y_ids = self.np2var(np_y_ids, LONG)
        # map sample to initial state of decoder
        sample_y.scatter_(1, np_y_ids.view(-1, 1), 1.0)
        sample_y = sample_y.view(-1, self.config.k * self.config.y_size)

        # pack attention context
        attn_inputs = None
        labels = out_utts[:, 1:].contiguous()
        c_last = c_last.unsqueeze(0)

        nll_xcz = 0.0
        cum_pcs = 0.0
        all_words = torch.sum(output_lens-1)
        for exp_id in range(exp_size):
            cur_sample_y = sample_y[exp_id:exp_id+1]
            cur_sample_y = cur_sample_y.expand(batch_size, self.config.k*self.config.y_size)

            # find out logp(z|c)
            log_pyc = torch.sum(log_py.view(-1, self.config.k*self.config.y_size) * cur_sample_y, dim=1)
            # map sample to initial state of decoder
            dec_init_state = self.c_init_connector(cur_sample_y) + c_last

            # decode
            dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                       out_utts[:, 0:-1],
                                                       dec_init_state,
                                                       attn_context=attn_inputs,
                                                       mode=TEACH_FORCE, gen_type="greedy",
                                                       beam_size=self.config.beam_size)

            output = dec_outs.view(-1, dec_outs.size(-1))
            target = labels.view(-1)
            enc_dec_nll = F.nll_loss(output, target, size_average=False,
                                     ignore_index=self.nll_loss.padding_idx,
                                     weight=self.nll_loss.weight, reduce=False)

            enc_dec_nll = enc_dec_nll.view(-1, dec_outs.size(1))
            enc_dec_nll = torch.sum(enc_dec_nll, dim=1)
            py_c = torch.exp(log_pyc)
            cum_pcs += py_c
            nll_xcz += py_c * enc_dec_nll

        nll_xcz = torch.sum(nll_xcz) / all_words
        return Pack(nll=nll_xcz)
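
exp_forward weights each reconstruction NLL by the prior p(z|c) and sums over every joint latent assignment, so the enumeration is just a Cartesian product over the k values of each variable. A standalone sketch of that product, assuming k = 3 and y_size = 2:

import itertools
import numpy as np

k, y_size = 3, 2
d = dict((str(i), range(k)) for i in range(y_size))
all_y_ids = [list(combo)
             for combo in itertools.product(*[d[key] for key in sorted(d.keys())])]

# k ** y_size rows, one per joint assignment of the latent variables
print(np.array(all_y_ids))  # shape (9, 2): [0, 0], [0, 1], ..., [2, 2]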
Example #12
    def sample_gumbel(self, logits, use_gpu, eps=1e-20):
        u = torch.rand(logits.size())
        sample = Variable(-torch.log(-torch.log(u + eps) + eps))
        sample = cast_type(sample, FLOAT, use_gpu)
        return sample
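
sample_gumbel supplies the noise for a full Gumbel-softmax relaxation: add the noise to the logits, divide by a temperature, and apply softmax. A hedged sketch of that composition (the temperature value is illustrative):

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=1.0, eps=1e-20):
    # g ~ Gumbel(0, 1), computed from uniform noise as in sample_gumbel above
    u = torch.rand(logits.size())
    g = -torch.log(-torch.log(u + eps) + eps)
    return F.softmax((logits + g) / temperature, dim=-1)

y = gumbel_softmax_sample(torch.randn(4, 10), temperature=0.8)
print(y.sum(dim=-1))  # each row is a valid distribution summing to 1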
Example #13
    def np2var(self, inputs, dtype):
        if inputs is None:
            return None
        return cast_type(Variable(torch.from_numpy(inputs)), dtype,
                         self.use_gpu)
Example #14
    def forward_rl(self, data_feed, max_words, temp=0.1):
        ctx_lens = data_feed['context_lens']
        batch_size = len(ctx_lens)

        ctx_utts = self.np2var(data_feed['contexts'], LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)

        # context encoder
        c_inputs = self.utt_encoder(ctx_utts)
        c_outs, c_last = self.ctx_encoder(c_inputs, ctx_lens)
        c_last = c_last.squeeze(0)

        # get decoder inputs
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        # create decoder initial states
        # enc_last = torch.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
        # DB info is not fed here
        enc_last = c_last
        logits_py, log_py = self.c2z(enc_last)

        qy = F.softmax(logits_py / temp, dim=1)
        log_qy = F.log_softmax(logits_py, dim=1)
        idx = torch.multinomial(qy, 1).detach()

        logprob_sample_z = log_qy.gather(1, idx).view(-1)
        joint_logpz = torch.sum(logprob_sample_z)
        sample_y = cast_type(Variable(torch.zeros(log_qy.size())), FLOAT,
                             self.use_gpu)
        sample_y.scatter_(1, idx, 1.0)

        # pack attention context
        if self.config.dec_use_attn:
            z_embeddings = torch.t(self.z_embedding.weight).split(self.k_size,
                                                                  dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size,
                                          self.config.k_size)
            for z_id in range(self.y_size):
                attn_context.append(
                    torch.mm(temp_sample_y[:, z_id],
                             z_embeddings[z_id]).unsqueeze(1))
            attn_context = torch.cat(attn_context, dim=1)
            dec_init_state = torch.sum(attn_context, dim=1).unsqueeze(0)
        else:
            dec_init_state = self.z_embedding(
                sample_y.view(1, -1, self.config.y_size * self.config.k_size))
            attn_context = None

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        # decode
        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
                                                 dec_init_state=dec_init_state,
                                                 attn_context=attn_context,
                                                 vocab=self.vocab,
                                                 max_words=max_words,
                                                 temp=temp)
        return logprobs, outs, joint_logpz, sample_y
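
forward_rl returns per-step word log-probabilities and the joint log-probability of the sampled latent code; a REINFORCE-style update multiplies their sum by a reward. A runnable toy surrogate of that update (the reward value is a hypothetical stand-in for a task metric):

import torch

# stand-ins for forward_rl's outputs: per-step word logprobs and the
# log-probability of the sampled latent code
logprobs = [torch.tensor(-0.5, requires_grad=True),
            torch.tensor(-1.2, requires_grad=True)]
joint_logpz = torch.tensor(-0.9, requires_grad=True)
reward = 0.7  # hypothetical task reward

# maximize reward-weighted log-likelihood == minimize its negation
loss = -reward * (torch.stack(logprobs).sum() + joint_logpz)
loss.backward()
print(loss.item())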