Example #1
    def __init__(self, corpus, config):
        super(SysPerfectBD2Gauss, self).__init__(config)
        self.vocab = corpus.vocab
        self.vocab_dict = corpus.vocab_dict
        self.vocab_size = len(self.vocab)
        self.bos_id = self.vocab_dict[BOS]
        self.eos_id = self.vocab_dict[EOS]
        self.pad_id = self.vocab_dict[PAD]
        self.bs_size = corpus.bs_size
        self.db_size = corpus.db_size
        self.y_size = config.y_size
        self.simple_posterior = config.simple_posterior

        self.embedding = None
        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
                                         embedding_dim=config.embed_size,
                                         feat_size=0,
                                         goal_nhid=0,
                                         rnn_cell=config.utt_rnn_cell,
                                         utt_cell_size=config.utt_cell_size,
                                         num_layers=config.num_layers,
                                         input_dropout_p=config.dropout,
                                         output_dropout_p=config.dropout,
                                         bidirectional=config.bi_utt_cell,
                                         variable_lengths=False,
                                         use_attn=config.enc_use_attn,
                                         embedding=self.embedding)

        self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size +
                                          self.db_size + self.bs_size,
                                          config.y_size,
                                          is_lstm=False)
        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
        self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
        if not self.simple_posterior:
            self.xc2z = nn_lib.Hidden2Gaussian(
                self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
                config.y_size,
                is_lstm=False)

        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
                                  rnn_cell=config.dec_rnn_cell,
                                  input_size=config.embed_size,
                                  hidden_size=config.dec_cell_size,
                                  num_layers=config.num_layers,
                                  output_dropout_p=config.dropout,
                                  bidirectional=False,
                                  vocab_size=self.vocab_size,
                                  use_attn=config.dec_use_attn,
                                  ctx_cell_size=config.dec_cell_size,
                                  attn_mode=config.dec_attn_mode,
                                  sys_id=self.bos_id,
                                  eos_id=self.eos_id,
                                  use_gpu=config.use_gpu,
                                  max_dec_len=config.max_dec_len,
                                  embedding=self.embedding)

        self.nll = NLLEntropy(self.pad_id, config.avg_type)
        self.gauss_kl = NormKLLoss(unit_average=True)
        self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
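
For orientation, a minimal sketch of the kind of config object this constructor reads; every value below is an illustrative assumption rather than the project's defaults, and corpus must still supply vocab, vocab_dict, bs_size, and db_size:

# Hypothetical config for illustration only -- the field values are assumptions,
# not the project's actual settings.
from argparse import Namespace

config = Namespace(
    y_size=10,                 # dimensionality of the latent variable z
    simple_posterior=True,     # use the simple posterior q(z|c)
    embed_size=256,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    num_layers=1,
    dropout=0.5,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_rnn_cell='lstm',
    dec_cell_size=300,
    dec_use_attn=False,
    dec_attn_mode='cat',
    use_gpu=False,
    max_dec_len=50,
    avg_type='word',
)
# model = SysPerfectBD2Gauss(corpus, config)  # corpus provides vocab, bs_size, db_size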
Example #2
    def z2dec(self, last_h, requires_grad):
        logits, log_qy = self.c2z(last_h)

        if requires_grad:
            sample_y = self.gumbel_connector(logits)
            logprob_z = None
        else:
            idx = th.multinomial(th.exp(log_qy), 1).detach()
            logprob_z = th.sum(log_qy.gather(1, idx))
            sample_y = utils.cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
            sample_y.scatter_(1, idx, 1.0)

        if self.config.dec_use_attn:
            z_embeddings = th.t(self.z_embedding.weight).split(self.config.k_size, dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
            for z_id in range(self.config.y_size):
                attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
            attn_context = th.cat(attn_context, dim=1)
            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
        else:
            attn_context = None
            dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))

        return dec_init_state, attn_context, logprob_z
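
The non-differentiable branch above draws a hard one-hot latent sample; a standalone sketch of that multinomial-plus-scatter_ pattern with toy shapes:

# Toy version of the sampling pattern in the else-branch: draw an index per row
# from q(z|c), record its log-probability, and expand it to a one-hot vector.
import torch as th

log_qy = th.log_softmax(th.randn(4, 10), dim=1)          # toy posterior over 10 classes
idx = th.multinomial(th.exp(log_qy), 1).detach()         # (4, 1) sampled class ids
logprob_z = th.sum(log_qy.gather(1, idx))                # log-prob of the sampled assignment
sample_y = th.zeros_like(log_qy).scatter_(1, idx, 1.0)   # (4, 10) one-hot samples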
Example #3
    def forward_rl(self, data_feed, max_words, temp=0.1):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(
            self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, bs_size)
        db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, db_size)
        batch_size = len(ctx_lens)

        utt_summary, _, enc_outs = self.utt_encoder(
            short_ctx_utts.unsqueeze(1))

        # create decoder initial states
        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
        # with or without simple_posterior, z is sampled from c2z(enc_last) at RL time
        logits_py, log_qy = self.c2z(enc_last)

        qy = F.softmax(logits_py / temp, dim=1)  # (batch_size * y_size, k_size)
        log_qy = F.log_softmax(logits_py, dim=1)  # (batch_size * y_size, k_size)
        idx = th.multinomial(qy, 1).detach()
        logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
        joint_logpz = th.sum(logprob_sample_z, dim=1)
        sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT,
                             self.use_gpu)
        sample_y.scatter_(1, idx, 1.0)

        # pack attention context
        if self.config.dec_use_attn:
            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size,
                                                               dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size,
                                          self.config.k_size)
            for z_id in range(self.y_size):
                attn_context.append(
                    th.mm(temp_sample_y[:, z_id],
                          z_embeddings[z_id]).unsqueeze(1))
            attn_context = th.cat(attn_context, dim=1)
            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
        else:
            dec_init_state = self.z_embedding(
                sample_y.view(1, -1, self.config.y_size * self.config.k_size))
            attn_context = None

        # decode
        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
                                                 dec_init_state=dec_init_state,
                                                 attn_context=attn_context,
                                                 vocab=self.vocab,
                                                 max_words=max_words,
                                                 temp=temp)
        return logprobs, outs, joint_logpz, sample_y
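
The returned joint_logpz is what a policy-gradient update would typically weight; a minimal sketch assuming a per-dialogue scalar reward tensor (hypothetical, not the project's training loop):

# REINFORCE-style loss over the sampled latent action z, assuming `reward` is a
# (batch_size,) tensor of returns; shown only as a sketch of how joint_logpz is used.
import torch as th

def policy_gradient_loss(joint_logpz, reward):
    # joint_logpz: (batch_size,) log-probability of the sampled latent action
    # reward:      (batch_size,) scalar return for each sampled response
    return -(reward * joint_logpz).mean()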
Example #4
 def forward(self, mu, logvar):
     """
     Sample a sample from a multivariate Gaussian distribution with a diagonal covariance matrix using the
     reparametrization trick.
     TODO: this should be better be a instance method in a Gaussian class.
     :param mu: a tensor of size [batch_size, variable_dim]. Batch_size can be None to support dynamic batching
     :param logvar: a tensor of size [batch_size, variable_dim]. Batch_size can be None.
     :return:
     """
     epsilon = th.randn(logvar.size())
     epsilon = cast_type(Variable(epsilon), FLOAT, self.use_gpu)
     std = th.exp(0.5 * logvar)
     z = mu + std * epsilon
     return z
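
A minimal standalone illustration of the reparametrization trick implemented above, with toy shapes:

# z = mu + sigma * epsilon with epsilon ~ N(0, I), so gradients flow through mu and logvar.
import torch as th

mu = th.zeros(8, 16, requires_grad=True)
logvar = th.zeros(8, 16, requires_grad=True)
epsilon = th.randn_like(logvar)
z = mu + th.exp(0.5 * logvar) * epsilon   # (8, 16) differentiable sample
z.sum().backward()                        # gradients reach mu and logvar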
Example #5
File: nn_lib.py Project: qywu/ARDM
 def forward(self, logits, temperature=1.0, hard=False,
             return_max_id=False):
     """
     :param logits: [batch_size, n_class] unnormalized log-prob
     :param temperature: non-negative scalar
     :param hard: if True take argmax
     :param return_max_id: if True, also return the argmax index of the sample
     :return: [batch_size, n_class] sample from gumbel softmax
     """
     y = self.gumbel_softmax_sample(logits, temperature, self.use_gpu)
     _, y_hard = th.max(y, dim=1, keepdim=True)
     if hard:
         y_onehot = cast_type(Variable(th.zeros(y.size())), FLOAT, self.use_gpu)
         y_onehot.scatter_(1, y_hard, 1.0)
         y = y_onehot
     if return_max_id:
         return y, y_hard
     else:
         return y
Example #6
 def sample_gumbel(self, logits, use_gpu, eps=1e-20):
     u = th.rand(logits.size())
     sample = Variable(-th.log(-th.log(u + eps) + eps))
     sample = cast_type(sample, FLOAT, use_gpu)
     return sample
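
sample_gumbel only produces the Gumbel(0, 1) noise; the project's gumbel_softmax_sample is not shown here, so the following is a generic sketch of the standard Gumbel-Softmax formulation such noise is typically plugged into:

# Generic Gumbel-Softmax sketch (not the project's gumbel_softmax_sample itself):
# perturb the logits with Gumbel(0, 1) noise and apply a temperature-scaled softmax.
import torch as th
import torch.nn.functional as F

def gumbel_softmax_sample_sketch(logits, temperature=1.0, eps=1e-20):
    u = th.rand_like(logits)
    gumbel_noise = -th.log(-th.log(u + eps) + eps)
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)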
Example #7
    def __init__(self, corpus, config):
        super(GaussHRED, self).__init__(config)

        self.vocab = corpus.vocab
        self.vocab_dict = corpus.vocab_dict
        self.vocab_size = len(self.vocab)
        self.goal_vocab = corpus.goal_vocab
        self.goal_vocab_dict = corpus.goal_vocab_dict
        self.goal_vocab_size = len(self.goal_vocab)
        self.outcome_vocab = corpus.outcome_vocab
        self.outcome_vocab_dict = corpus.outcome_vocab_dict
        self.outcome_vocab_size = len(self.outcome_vocab)
        self.sys_id = self.vocab_dict[SYS]
        self.eos_id = self.vocab_dict[EOS]
        self.pad_id = self.vocab_dict[PAD]
        self.simple_posterior = config.simple_posterior

        self.goal_encoder = MlpGoalEncoder(goal_vocab_size=self.goal_vocab_size,
                                           k=config.k,
                                           nembed=config.goal_embed_size,
                                           nhid=config.goal_nhid,
                                           init_range=config.init_range)

        self.embedding = nn.Embedding(self.vocab_size, config.embed_size, padding_idx=self.pad_id)
        self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
                                         embedding_dim=config.embed_size,
                                         feat_size=0,
                                         goal_nhid=config.goal_nhid,
                                         rnn_cell=config.utt_rnn_cell,
                                         utt_cell_size=config.utt_cell_size,
                                         num_layers=config.num_layers,
                                         input_dropout_p=config.dropout,
                                         output_dropout_p=config.dropout,
                                         bidirectional=config.bi_utt_cell,
                                         variable_lengths=False,
                                         use_attn=config.enc_use_attn,
                                         embedding=self.embedding)

        self.ctx_encoder = EncoderRNN(input_dropout_p=0.0,
                                      rnn_cell=config.ctx_rnn_cell,
                                      # input_size=self.utt_encoder.output_size+config.goal_nhid,
                                      input_size=self.utt_encoder.output_size,
                                      hidden_size=config.ctx_cell_size,
                                      num_layers=config.num_layers,
                                      output_dropout_p=config.dropout,
                                      bidirectional=config.bi_ctx_cell,
                                      variable_lengths=False)
        # mu and logvar projector
        self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
        self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
        self.z_embedding = nn.Linear(config.y_size, config.dec_cell_size)
        if not self.simple_posterior:
            self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size+self.ctx_encoder.output_size, config.y_size, is_lstm=False)

        self.decoder = DecoderRNN(input_dropout_p=config.dropout,
                                  rnn_cell=config.dec_rnn_cell,
                                  input_size=config.embed_size + config.goal_nhid,
                                  hidden_size=config.dec_cell_size,
                                  num_layers=config.num_layers,
                                  output_dropout_p=config.dropout,
                                  bidirectional=False,
                                  vocab_size=self.vocab_size,
                                  use_attn=config.dec_use_attn,
                                  ctx_cell_size=self.ctx_encoder.output_size,
                                  attn_mode=config.dec_attn_mode,
                                  sys_id=self.sys_id,
                                  eos_id=self.eos_id,
                                  use_gpu=config.use_gpu,
                                  max_dec_len=config.max_dec_len,
                                  embedding=self.embedding)

        self.nll = NLLEntropy(self.pad_id, config.avg_type)
        self.gauss_kl = criterions.NormKLLoss(unit_average=True)
        self.zero = utils.cast_type(th.zeros(1), FLOAT, self.use_gpu)
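
The c2z module above projects a hidden state onto the mean and log-variance of the latent Gaussian; nn_lib.Hidden2Gaussian itself is not reproduced in these examples, so the following is only a generic sketch of such a layer:

# Generic hidden-to-Gaussian projection sketch (an assumption about the typical
# structure, not the project's implementation): two linear heads for mu and logvar.
import torch.nn as nn

class Hidden2GaussianSketch(nn.Module):
    def __init__(self, input_size, y_size):
        super().__init__()
        self.mu = nn.Linear(input_size, y_size)
        self.logvar = nn.Linear(input_size, y_size)

    def forward(self, inputs):
        return self.mu(inputs), self.logvar(inputs)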
Example #8
 def np2var(self, inputs, dtype):
     if inputs is None:
         return None
     return cast_type(Variable(th.from_numpy(inputs)), dtype, self.use_gpu)
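
A usage sketch with a toy numpy batch; the shape is illustrative, and on CPU the call reduces to a dtype cast of th.from_numpy:

# On CPU, np2var(bs, FLOAT) amounts to a float tensor built from the numpy array;
# with use_gpu=True the result is additionally moved to the GPU.
import numpy as np
import torch as th

bs = np.zeros((16, 94), dtype=np.float32)  # toy belief-state batch (illustrative shape)
bs_var = th.from_numpy(bs).float()         # equivalent of np2var(bs, FLOAT) on CPU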
Example #9
    def forward_rl(self,
                   batch_size,
                   dec_init_state,
                   attn_context,
                   vocab,
                   max_words,
                   goal_hid=None,
                   mask=True,
                   temp=0.1):
        # prepare the BOS inputs
        with th.no_grad():
            bos_var = Variable(th.LongTensor([self.sys_id]))
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size, 1)  # (batch_size, 1)
        decoder_hidden_state = dec_init_state  # tuple: (h, c)
        encoder_outputs = attn_context  # (batch_size, ctx_len, ctx_cell_size) or None

        logprob_outputs = []  # list of logprob | max_dec_len*(batch_size, 1)
        symbol_outputs = []  # list of word ids | max_dec_len*(batch_size, 1)

        if mask:
            special_token_mask = Variable(
                th.FloatTensor([
                    -999. if token in DECODING_MASKED_TOKENS else 0.
                    for token in vocab
                ]))
            special_token_mask = cast_type(special_token_mask, FLOAT,
                                           self.use_gpu)  # (vocab_size, )

        def _sample(dec_output, num_i):
            # dec_output: (1, 1, vocab_size), need to softmax and log_softmax
            dec_output = dec_output.view(batch_size,
                                         -1)  # (batch_size, vocab_size, )
            prob = F.softmax(dec_output / temp,
                             dim=1)  # (batch_size, vocab_size, )
            logprob = F.log_softmax(dec_output,
                                    dim=1)  # (batch_size, vocab_size, )
            symbol = prob.multinomial(
                num_samples=1).detach()  # (batch_size, 1)
            logprob = logprob.gather(1, symbol)  # (batch_size, 1)
            return logprob, symbol

        stopped_samples = set()
        for i in range(max_words):
            decoder_output, decoder_hidden_state = self._step(
                decoder_input, decoder_hidden_state, encoder_outputs, goal_hid)
            # disable special tokens from being generated in a normal turn
            if mask:
                decoder_output += special_token_mask.expand(1, 1, -1)
            logprob, symbol = _sample(decoder_output, i)
            logprob_outputs.append(logprob)
            symbol_outputs.append(symbol)
            decoder_input = symbol.view(batch_size, -1)
            for b_id in range(batch_size):
                if vocab[symbol[b_id].item()] == EOS:
                    stopped_samples.add(b_id)

            if len(stopped_samples) == batch_size:
                break

        assert len(logprob_outputs) == len(symbol_outputs)
        symbol_outputs = th.cat(symbol_outputs,
                                dim=1).cpu().data.numpy().tolist()
        logprob_outputs = th.cat(logprob_outputs, dim=1)
        logprob_list = []
        symbol_list = []
        for b_id in range(batch_size):
            b_logprob = []
            b_symbol = []
            for t_id in range(logprob_outputs.shape[1]):
                symbol = symbol_outputs[b_id][t_id]
                if vocab[symbol] == EOS and t_id != 0:
                    break

                b_symbol.append(symbol_outputs[b_id][t_id])
                b_logprob.append(logprob_outputs[b_id][t_id])

            logprob_list.append(b_logprob)
            symbol_list.append(b_symbol)

        # TODO backward compatible, if batch_size == 1, we remove the nested structure
        if batch_size == 1:
            logprob_list = logprob_list[0]
            symbol_list = symbol_list[0]

        return logprob_list, symbol_list
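
A toy, self-contained illustration of the special-token masking used above; the vocabulary and masked-token set are assumptions for the example only:

# Masked tokens receive a large negative bias, so the temperature-scaled softmax
# assigns them (near-)zero probability and they are effectively never sampled.
import torch as th
import torch.nn.functional as F

vocab = ['<pad>', '<s>', '</s>', 'hello', 'world']   # toy vocabulary
masked = {'<pad>', '<s>'}                            # illustrative masked-token set
special_token_mask = th.tensor([-999. if tok in masked else 0. for tok in vocab])
decoder_output = th.randn(1, 1, len(vocab)) + special_token_mask.expand(1, 1, -1)
probs = F.softmax(decoder_output.view(-1) / 0.1, dim=0)  # masked tokens ~ zero mass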
Example #10
    def write(self,
              input_var,
              hidden_state,
              encoder_outputs,
              max_words,
              vocab,
              stop_tokens,
              goal_hid=None,
              mask=True,
              decoding_masked_tokens=DECODING_MASKED_TOKENS):
        # input_var: (1, 1)
        # hidden_state: tuple: (h, c)
        # encoder_outputs: max_dlg_len*(1, 1, dlg_cell_size)
        # goal_hid: (1, goal_nhid)
        logprob_outputs = []  # list of logprob | max_dec_len*(1, )
        symbol_outputs = []  # list of word ids | max_dec_len*(1, )
        decoder_input = input_var
        decoder_hidden_state = hidden_state
        if type(encoder_outputs) is list:
            encoder_outputs = th.cat(encoder_outputs,
                                     1)  # (1, max_dlg_len, dlg_cell_size)
        # print('encoder_outputs.size() = {}'.format(encoder_outputs.size()))

        if mask:
            special_token_mask = Variable(
                th.FloatTensor([
                    -999. if token in decoding_masked_tokens else 0.
                    for token in vocab
                ]))
            special_token_mask = cast_type(special_token_mask, FLOAT,
                                           self.use_gpu)  # (vocab_size, )

        def _sample(dec_output, num_i):
            # dec_output: (1, 1, vocab_size), need to softmax and log_softmax
            dec_output = dec_output.view(-1)  # (vocab_size, )
            # TODO temperature
            prob = F.softmax(dec_output / 0.6, dim=0)  # (vocab_size, )
            logprob = F.log_softmax(dec_output, dim=0)  # (vocab_size, )
            symbol = prob.multinomial(num_samples=1).detach()  # (1, )
            logprob = logprob.gather(0, symbol)  # (1, )
            return logprob, symbol

        for i in range(max_words):
            decoder_output, decoder_hidden_state = self._step(
                decoder_input, decoder_hidden_state, encoder_outputs, goal_hid)
            # disable special tokens from being generated in a normal turn
            if mask:
                decoder_output += special_token_mask.expand(1, 1, -1)
            logprob, symbol = _sample(decoder_output, i)
            logprob_outputs.append(logprob)
            symbol_outputs.append(symbol)
            decoder_input = symbol.view(1, -1)

            if vocab[symbol.item()] in stop_tokens:
                break

        assert len(logprob_outputs) == len(symbol_outputs)
        # logprob_list = [t.item() for t in logprob_outputs]
        logprob_list = logprob_outputs
        symbol_list = [t.item() for t in symbol_outputs]
        return logprob_list, symbol_list
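
The _sample helper above is the same temperature-scaled sampling step in single-example form; a standalone sketch with a toy score vector:

# Sharpen the distribution with a temperature below 1, draw one token, and record
# the (untempered) log-probability of the sampled token, as _sample() does above.
import torch as th
import torch.nn.functional as F

dec_output = th.randn(100)                      # toy (vocab_size,) scores for one step
prob = F.softmax(dec_output / 0.6, dim=0)       # temperature 0.6, as hardcoded above
logprob = F.log_softmax(dec_output, dim=0)
symbol = prob.multinomial(num_samples=1)        # (1,) sampled token id
logprob_of_symbol = logprob.gather(0, symbol)   # (1,) log-prob of the sampled token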
Example #11
    def forward(self,
                batch_size,
                dec_inputs,
                dec_init_state,
                attn_context,
                mode,
                gen_type,
                beam_size,
                goal_hid=None):
        # dec_inputs: (batch_size, response_size-1)
        # attn_context: (batch_size, max_ctx_len, ctx_cell_size)
        # goal_hid: (batch_size, goal_nhid)

        ret_dict = dict()

        if self.use_attn:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        if mode == GEN:
            dec_inputs = None

        if gen_type != 'beam':
            beam_size = 1

        if dec_inputs is not None:
            decoder_input = dec_inputs
        else:
            # prepare the BOS inputs
            with th.no_grad():
                bos_var = Variable(th.LongTensor([self.sys_id]))
            bos_var = cast_type(bos_var, LONG, self.use_gpu)
            decoder_input = bos_var.expand(batch_size * beam_size, 1)  # (batch_size*beam_size, 1)

        if mode == GEN and gen_type == 'beam':
            # TODO if beam search, repeat the initial states of the RNN
            pass
        else:
            decoder_hidden_state = dec_init_state

        prob_outputs = []  # list of logprob | max_dec_len*(batch_size, 1, vocab_size)
        symbol_outputs = []  # list of word ids | max_dec_len*(batch_size, 1)

        # back_pointers = []
        # lengths = blabla...

        def decode(step, cum_sum, step_output, step_attn):
            prob_outputs.append(step_output)
            step_output_slice = step_output.squeeze(
                1)  # (batch_size, vocab_size)
            if self.use_attn:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

            if gen_type == 'greedy':
                _, symbols = step_output_slice.topk(1)  # (batch_size, 1)
            elif gen_type == 'sample':
                # TODO FIXME
                # symbols = self.gumbel_max(step_output_slice)
                pass
            elif gen_type == 'beam':
                # TODO
                pass
            else:
                raise ValueError('Unsupported decoding mode')

            symbol_outputs.append(symbols)

            return cum_sum, symbols

        if mode == TEACH_FORCE:
            prob_outputs, decoder_hidden_state, attn = self.forward_step(
                input_var=decoder_input,
                hidden_state=decoder_hidden_state,
                encoder_outputs=attn_context,
                goal_hid=goal_hid)
        else:
            # do free running here
            cum_sum = None
            for step in range(self.max_dec_len):
                # Input:
                #   decoder_input: (batch_size, 1)
                #   decoder_hidden_state: tuple: (h, c)
                #   attn_context: (batch_size, max_ctx_len, ctx_cell_size)
                #   goal_hid: (batch_size, goal_nhid)
                # Output:
                #   decoder_output: (batch_size, 1, vocab_size)
                #   decoder_hidden_state: tuple: (h, c)
                #   step_attn: (batch_size, 1, max_ctx_len)
                decoder_output, decoder_hidden_state, step_attn = self.forward_step(
                    decoder_input,
                    decoder_hidden_state,
                    attn_context,
                    goal_hid=goal_hid)
                cum_sum, symbols = decode(step, cum_sum, decoder_output,
                                          step_attn)
                decoder_input = symbols

            prob_outputs = th.cat(
                prob_outputs, dim=1)  # (batch_size, max_dec_len, vocab_size)

            # back tracking to recover the 1-best in beam search
            # if gen_type == 'beam':

        ret_dict[DecoderRNN.KEY_SEQUENCE] = symbol_outputs

        # prob_outputs: (batch_size, max_dec_len, vocab_size)
        # decoder_hidden_state: tuple: (h, c)
        # ret_dict[DecoderRNN.KEY_ATTN_SCORE]: max_dec_len*(batch_size, 1, max_ctx_len)
        # ret_dict[DecoderRNN.KEY_SEQUENCE]: max_dec_len*(batch_size, 1)
        return prob_outputs, decoder_hidden_state, ret_dict
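
A minimal standalone sketch of the greedy branch of decode(), with toy shapes:

# Greedy decoding step: take the top-1 token id from each step's output distribution.
import torch as th

step_output = th.randn(4, 1, 100)            # toy (batch_size, 1, vocab_size) decoder output
step_output_slice = step_output.squeeze(1)   # (batch_size, vocab_size)
_, symbols = step_output_slice.topk(1)       # (batch_size, 1) greedy token ids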