def __init__(self, corpus, config):
    super(SysPerfectBD2Gauss, self).__init__(config)
    self.vocab = corpus.vocab
    self.vocab_dict = corpus.vocab_dict
    self.vocab_size = len(self.vocab)
    self.bos_id = self.vocab_dict[BOS]
    self.eos_id = self.vocab_dict[EOS]
    self.pad_id = self.vocab_dict[PAD]
    self.bs_size = corpus.bs_size
    self.db_size = corpus.db_size
    self.y_size = config.y_size
    self.simple_posterior = config.simple_posterior

    self.embedding = None
    self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
                                     embedding_dim=config.embed_size,
                                     feat_size=0,
                                     goal_nhid=0,
                                     rnn_cell=config.utt_rnn_cell,
                                     utt_cell_size=config.utt_cell_size,
                                     num_layers=config.num_layers,
                                     input_dropout_p=config.dropout,
                                     output_dropout_p=config.dropout,
                                     bidirectional=config.bi_utt_cell,
                                     variable_lengths=False,
                                     use_attn=config.enc_use_attn,
                                     embedding=self.embedding)

    self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.db_size + self.bs_size,
                                      config.y_size, is_lstm=False)
    self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
    self.z_embedding = nn.Linear(self.y_size, config.dec_cell_size)
    if not self.simple_posterior:
        self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size * 2 + self.db_size + self.bs_size,
                                           config.y_size, is_lstm=False)

    self.decoder = DecoderRNN(input_dropout_p=config.dropout,
                              rnn_cell=config.dec_rnn_cell,
                              input_size=config.embed_size,
                              hidden_size=config.dec_cell_size,
                              num_layers=config.num_layers,
                              output_dropout_p=config.dropout,
                              bidirectional=False,
                              vocab_size=self.vocab_size,
                              use_attn=config.dec_use_attn,
                              ctx_cell_size=config.dec_cell_size,
                              attn_mode=config.dec_attn_mode,
                              sys_id=self.bos_id,
                              eos_id=self.eos_id,
                              use_gpu=config.use_gpu,
                              max_dec_len=config.max_dec_len,
                              embedding=self.embedding)

    self.nll = NLLEntropy(self.pad_id, config.avg_type)
    self.gauss_kl = NormKLLoss(unit_average=True)
    self.zero = cast_type(th.zeros(1), FLOAT, self.use_gpu)
def z2dec(self, last_h, requires_grad):
    logits, log_qy = self.c2z(last_h)
    if requires_grad:
        # differentiable sample via the Gumbel-Softmax connector
        sample_y = self.gumbel_connector(logits)
        logprob_z = None
    else:
        # hard sample: draw one index per latent variable and one-hot encode it
        idx = th.multinomial(th.exp(log_qy), 1).detach()
        logprob_z = th.sum(log_qy.gather(1, idx))
        sample_y = utils.cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
        sample_y.scatter_(1, idx, 1.0)

    if self.config.dec_use_attn:
        # build one context vector per latent variable for decoder attention
        z_embeddings = th.t(self.z_embedding.weight).split(self.config.k_size, dim=0)
        attn_context = []
        temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
        for z_id in range(self.config.y_size):
            attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
        attn_context = th.cat(attn_context, dim=1)
        dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
    else:
        attn_context = None
        dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))

    return dec_init_state, attn_context, logprob_z
def forward_rl(self, data_feed, max_words, temp=0.1):
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, bs_size)
    db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, db_size)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # create decoder initial states from the context, belief state and DB features
    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
    logits_py, log_qy = self.c2z(enc_last)

    # sample the latent action with a sharpened distribution, keep untempered log-probs for RL
    qy = F.softmax(logits_py / temp, dim=1)
    log_qy = F.log_softmax(logits_py, dim=1)
    idx = th.multinomial(qy, 1).detach()
    logprob_sample_z = log_qy.gather(1, idx).view(-1, self.y_size)
    joint_logpz = th.sum(logprob_sample_z, dim=1)
    sample_y = cast_type(Variable(th.zeros(log_qy.size())), FLOAT, self.use_gpu)
    sample_y.scatter_(1, idx, 1.0)

    # pack attention context
    if self.config.dec_use_attn:
        z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
        attn_context = []
        temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
        for z_id in range(self.y_size):
            attn_context.append(th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
        attn_context = th.cat(attn_context, dim=1)
        dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
    else:
        dec_init_state = self.z_embedding(sample_y.view(1, -1, self.config.y_size * self.config.k_size))
        attn_context = None

    # decode
    if self.config.dec_rnn_cell == 'lstm':
        dec_init_state = tuple([dec_init_state, dec_init_state])

    logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
                                             dec_init_state=dec_init_state,
                                             attn_context=attn_context,
                                             vocab=self.vocab,
                                             max_words=max_words,
                                             temp=temp)
    return logprobs, outs, joint_logpz, sample_y
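# A self-contained toy illustration of the attention-context construction used in
# z2dec and forward_rl above. The sizes below are invented for the example: the
# (dec_cell_size, y_size*k_size) weight of z_embedding is transposed, split into
# y_size chunks of k_size rows, and each one-hot latent variable selects its own
# embedding row as a per-variable context vector. This is a sketch of the tensor
# manipulation only, not the project's model code.
import torch

y_size, k_size, dec_cell_size, batch = 3, 4, 6, 2
z_embedding_weight = torch.randn(dec_cell_size, y_size * k_size)  # same layout as nn.Linear(y*k, dec).weight
z_chunks = torch.t(z_embedding_weight).split(k_size, dim=0)       # y_size tensors of (k_size, dec_cell_size)

sample_y = torch.zeros(batch, y_size, k_size)
sample_y[:, torch.arange(y_size), 0] = 1.0                        # pretend every latent variable picked class 0

attn_context = torch.cat([torch.mm(sample_y[:, i], z_chunks[i]).unsqueeze(1) for i in range(y_size)], dim=1)
dec_init_state = attn_context.sum(dim=1).unsqueeze(0)             # summed contexts init the decoder
print(attn_context.shape, dec_init_state.shape)                   # (2, 3, 6) and (1, 2, 6)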
def forward(self, mu, logvar):
    """
    Draw a sample from a multivariate Gaussian distribution with a diagonal
    covariance matrix using the reparametrization trick.

    TODO: this would be better as an instance method of a Gaussian class.

    :param mu: a tensor of size [batch_size, variable_dim]. batch_size can be None to support dynamic batching.
    :param logvar: a tensor of size [batch_size, variable_dim]. batch_size can be None.
    :return: a sample z of size [batch_size, variable_dim]
    """
    epsilon = th.randn(logvar.size())
    epsilon = cast_type(Variable(epsilon), FLOAT, self.use_gpu)
    std = th.exp(0.5 * logvar)
    z = mu + std * epsilon
    return z
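# A minimal, standalone sketch of the same reparametrization trick outside the
# model (tensor names and sizes here are illustrative, not part of the codebase):
# z = mu + exp(0.5 * logvar) * eps keeps the sample differentiable w.r.t. mu and
# logvar, while eps ~ N(0, I) carries all the randomness.
import torch

mu = torch.zeros(4, 8, requires_grad=True)      # hypothetical [batch, dim] mean
logvar = torch.zeros(4, 8, requires_grad=True)  # hypothetical [batch, dim] log-variance
eps = torch.randn_like(logvar)                  # noise drawn once per forward pass
z = mu + torch.exp(0.5 * logvar) * eps          # gradients flow to mu and logvar, not to eps
z.sum().backward()                              # works: the sample is differentiable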
def forward(self, logits, temperature=1.0, hard=False, return_max_id=False):
    """
    :param logits: [batch_size, n_class] unnormalized log-probs
    :param temperature: non-negative scalar
    :param hard: if True, replace the soft sample with its one-hot argmax
    :param return_max_id: if True, also return the argmax ids
    :return: [batch_size, n_class] sample from the Gumbel-Softmax distribution
    """
    y = self.gumbel_softmax_sample(logits, temperature, self.use_gpu)
    _, y_hard = th.max(y, dim=1, keepdim=True)
    if hard:
        y_onehot = cast_type(Variable(th.zeros(y.size())), FLOAT, self.use_gpu)
        y_onehot.scatter_(1, y_hard, 1.0)
        y = y_onehot
    if return_max_id:
        return y, y_hard
    return y
def sample_gumbel(self, logits, use_gpu, eps=1e-20):
    # Draw Gumbel(0, 1) noise with the same shape as logits: -log(-log(U)), U ~ Uniform(0, 1).
    u = th.rand(logits.size())
    sample = Variable(-th.log(-th.log(u + eps) + eps))
    sample = cast_type(sample, FLOAT, use_gpu)
    return sample
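# The gumbel_softmax_sample helper called in forward() above is not defined in
# this section. The standalone function below is a sketch under the assumption
# that it follows the standard Gumbel-Softmax recipe (perturb the logits with
# Gumbel noise, then apply a temperature-scaled softmax); the actual helper in
# the codebase may differ in details such as the softmax dimension.
import torch
import torch.nn.functional as F

def gumbel_softmax_sample_sketch(logits, temperature=1.0, eps=1e-20):
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1)
    u = torch.rand(logits.size())
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)
    # soft, differentiable sample over the class dimension
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)

# Usage: soft = gumbel_softmax_sample_sketch(torch.randn(4, 10), temperature=0.8)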
def __init__(self, corpus, config):
    super(GaussHRED, self).__init__(config)
    self.vocab = corpus.vocab
    self.vocab_dict = corpus.vocab_dict
    self.vocab_size = len(self.vocab)
    self.goal_vocab = corpus.goal_vocab
    self.goal_vocab_dict = corpus.goal_vocab_dict
    self.goal_vocab_size = len(self.goal_vocab)
    self.outcome_vocab = corpus.outcome_vocab
    self.outcome_vocab_dict = corpus.outcome_vocab_dict
    self.outcome_vocab_size = len(self.outcome_vocab)
    self.sys_id = self.vocab_dict[SYS]
    self.eos_id = self.vocab_dict[EOS]
    self.pad_id = self.vocab_dict[PAD]
    self.simple_posterior = config.simple_posterior

    self.goal_encoder = MlpGoalEncoder(goal_vocab_size=self.goal_vocab_size,
                                       k=config.k,
                                       nembed=config.goal_embed_size,
                                       nhid=config.goal_nhid,
                                       init_range=config.init_range)

    self.embedding = nn.Embedding(self.vocab_size, config.embed_size, padding_idx=self.pad_id)
    self.utt_encoder = RnnUttEncoder(vocab_size=self.vocab_size,
                                     embedding_dim=config.embed_size,
                                     feat_size=0,
                                     goal_nhid=config.goal_nhid,
                                     rnn_cell=config.utt_rnn_cell,
                                     utt_cell_size=config.utt_cell_size,
                                     num_layers=config.num_layers,
                                     input_dropout_p=config.dropout,
                                     output_dropout_p=config.dropout,
                                     bidirectional=config.bi_utt_cell,
                                     variable_lengths=False,
                                     use_attn=config.enc_use_attn,
                                     embedding=self.embedding)

    self.ctx_encoder = EncoderRNN(input_dropout_p=0.0,
                                  rnn_cell=config.ctx_rnn_cell,
                                  # input_size=self.utt_encoder.output_size+config.goal_nhid,
                                  input_size=self.utt_encoder.output_size,
                                  hidden_size=config.ctx_cell_size,
                                  num_layers=config.num_layers,
                                  output_dropout_p=config.dropout,
                                  bidirectional=config.bi_ctx_cell,
                                  variable_lengths=False)

    # mu and logvar projector
    self.c2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size, config.y_size, is_lstm=False)
    self.gauss_connector = nn_lib.GaussianConnector(self.use_gpu)
    self.z_embedding = nn.Linear(config.y_size, config.dec_cell_size)
    if not self.simple_posterior:
        self.xc2z = nn_lib.Hidden2Gaussian(self.utt_encoder.output_size + self.ctx_encoder.output_size,
                                           config.y_size, is_lstm=False)

    self.decoder = DecoderRNN(input_dropout_p=config.dropout,
                              rnn_cell=config.dec_rnn_cell,
                              input_size=config.embed_size + config.goal_nhid,
                              hidden_size=config.dec_cell_size,
                              num_layers=config.num_layers,
                              output_dropout_p=config.dropout,
                              bidirectional=False,
                              vocab_size=self.vocab_size,
                              use_attn=config.dec_use_attn,
                              ctx_cell_size=self.ctx_encoder.output_size,
                              attn_mode=config.dec_attn_mode,
                              sys_id=self.sys_id,
                              eos_id=self.eos_id,
                              use_gpu=config.use_gpu,
                              max_dec_len=config.max_dec_len,
                              embedding=self.embedding)

    self.nll = NLLEntropy(self.pad_id, config.avg_type)
    self.gauss_kl = criterions.NormKLLoss(unit_average=True)
    self.zero = utils.cast_type(th.zeros(1), FLOAT, self.use_gpu)
def np2var(self, inputs, dtype):
    # Wrap a numpy array as a Variable of the requested dtype, moved to GPU if enabled.
    if inputs is None:
        return None
    return cast_type(Variable(th.from_numpy(inputs)), dtype, self.use_gpu)
def forward_rl(self, batch_size, dec_init_state, attn_context, vocab, max_words,
               goal_hid=None, mask=True, temp=0.1):
    # prepare the BOS inputs
    with th.no_grad():
        bos_var = Variable(th.LongTensor([self.sys_id]))
    bos_var = cast_type(bos_var, LONG, self.use_gpu)
    decoder_input = bos_var.expand(batch_size, 1)  # (batch_size, 1)
    decoder_hidden_state = dec_init_state  # tuple: (h, c)
    encoder_outputs = attn_context  # (batch_size, ctx_len, ctx_cell_size)

    logprob_outputs = []  # list of logprob | max_dec_len*(batch_size, 1)
    symbol_outputs = []  # list of word ids | max_dec_len*(batch_size, 1)

    if mask:
        special_token_mask = Variable(th.FloatTensor(
            [-999. if token in DECODING_MASKED_TOKENS else 0. for token in vocab]))
        special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu)  # (vocab_size, )

    def _sample(dec_output, num_i):
        # dec_output: (batch_size, 1, vocab_size)
        dec_output = dec_output.view(batch_size, -1)  # (batch_size, vocab_size)
        prob = F.softmax(dec_output / temp, dim=1)  # sharpened sampling distribution
        logprob = F.log_softmax(dec_output, dim=1)  # untempered log-probs for the RL objective
        symbol = prob.multinomial(num_samples=1).detach()  # (batch_size, 1)
        # _, symbol = prob.topk(1)  # greedy alternative
        logprob = logprob.gather(1, symbol)  # (batch_size, 1)
        return logprob, symbol

    stopped_samples = set()
    for i in range(max_words):
        decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state,
                                                          encoder_outputs, goal_hid)
        # disable special tokens from being generated in a normal turn
        if mask:
            decoder_output += special_token_mask.expand(1, 1, -1)
        logprob, symbol = _sample(decoder_output, i)
        logprob_outputs.append(logprob)
        symbol_outputs.append(symbol)
        decoder_input = symbol.view(batch_size, -1)
        for b_id in range(batch_size):
            if vocab[symbol[b_id].item()] == EOS:
                stopped_samples.add(b_id)
        if len(stopped_samples) == batch_size:
            break

    assert len(logprob_outputs) == len(symbol_outputs)
    symbol_outputs = th.cat(symbol_outputs, dim=1).cpu().data.numpy().tolist()
    logprob_outputs = th.cat(logprob_outputs, dim=1)

    # truncate each sample at its first EOS
    logprob_list = []
    symbol_list = []
    for b_id in range(batch_size):
        b_logprob = []
        b_symbol = []
        for t_id in range(logprob_outputs.shape[1]):
            symbol = symbol_outputs[b_id][t_id]
            if vocab[symbol] == EOS and t_id != 0:
                break
            b_symbol.append(symbol_outputs[b_id][t_id])
            b_logprob.append(logprob_outputs[b_id][t_id])
        logprob_list.append(b_logprob)
        symbol_list.append(b_symbol)

    # TODO backward compatible: if batch_size == 1, remove the nested structure
    if batch_size == 1:
        logprob_list = logprob_list[0]
        symbol_list = symbol_list[0]

    return logprob_list, symbol_list
def write(self, input_var, hidden_state, encoder_outputs, max_words, vocab, stop_tokens,
          goal_hid=None, mask=True, decoding_masked_tokens=DECODING_MASKED_TOKENS):
    # input_var: (1, 1)
    # hidden_state: tuple: (h, c)
    # encoder_outputs: max_dlg_len*(1, 1, dlg_cell_size)
    # goal_hid: (1, goal_nhid)
    logprob_outputs = []  # list of logprob | max_dec_len*(1, )
    symbol_outputs = []  # list of word ids | max_dec_len*(1, )
    decoder_input = input_var
    decoder_hidden_state = hidden_state
    if type(encoder_outputs) is list:
        encoder_outputs = th.cat(encoder_outputs, 1)  # (1, max_dlg_len, dlg_cell_size)

    if mask:
        special_token_mask = Variable(th.FloatTensor(
            [-999. if token in decoding_masked_tokens else 0. for token in vocab]))
        special_token_mask = cast_type(special_token_mask, FLOAT, self.use_gpu)  # (vocab_size, )

    def _sample(dec_output, num_i):
        # dec_output: (1, 1, vocab_size)
        dec_output = dec_output.view(-1)  # (vocab_size, )
        # TODO make the temperature configurable
        prob = F.softmax(dec_output / 0.6, dim=0)  # (vocab_size, )
        logprob = F.log_softmax(dec_output, dim=0)  # (vocab_size, )
        symbol = prob.multinomial(num_samples=1).detach()  # (1, )
        # _, symbol = prob.topk(1)  # greedy alternative
        logprob = logprob.gather(0, symbol)  # (1, )
        return logprob, symbol

    for i in range(max_words):
        decoder_output, decoder_hidden_state = self._step(decoder_input, decoder_hidden_state,
                                                          encoder_outputs, goal_hid)
        # disable special tokens from being generated in a normal turn
        if mask:
            decoder_output += special_token_mask.expand(1, 1, -1)
        logprob, symbol = _sample(decoder_output, i)
        logprob_outputs.append(logprob)
        symbol_outputs.append(symbol)
        decoder_input = symbol.view(1, -1)
        if vocab[symbol.item()] in stop_tokens:
            break

    assert len(logprob_outputs) == len(symbol_outputs)
    logprob_list = logprob_outputs
    symbol_list = [t.item() for t in symbol_outputs]
    return logprob_list, symbol_list
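# A small standalone sketch of the masked temperature sampling used in
# forward_rl and write above. The vocabulary and masked tokens below are
# invented for the example: masked tokens get a large negative logit offset so
# the softmax assigns them near-zero probability, the word is sampled from the
# temperature-sharpened distribution, and the returned log-prob comes from the
# unscaled distribution.
import torch
import torch.nn.functional as F

vocab = ['<pad>', '<unk>', 'hello', 'world', '<eos>']
masked = {'<pad>', '<unk>'}
mask = torch.tensor([-999. if tok in masked else 0. for tok in vocab])

logits = torch.randn(1, len(vocab)) + mask          # (1, vocab_size), special tokens suppressed
prob = F.softmax(logits / 0.6, dim=1)               # sharpened sampling distribution
logprob = F.log_softmax(logits, dim=1)              # log-probs at temperature 1
symbol = prob.multinomial(num_samples=1).detach()   # sampled word id, (1, 1)
chosen_logprob = logprob.gather(1, symbol)          # log-prob of the sampled word
print(vocab[symbol.item()], chosen_logprob.item())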
def forward(self, batch_size, dec_inputs, dec_init_state, attn_context, mode, gen_type, beam_size, goal_hid=None):
    # dec_inputs: (batch_size, response_size-1)
    # attn_context: (batch_size, max_ctx_len, ctx_cell_size)
    # goal_hid: (batch_size, goal_nhid)
    ret_dict = dict()

    if self.use_attn:
        ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

    if mode == GEN:
        dec_inputs = None
    if gen_type != 'beam':
        beam_size = 1

    if dec_inputs is not None:
        decoder_input = dec_inputs
    else:
        # prepare the BOS inputs
        with th.no_grad():
            bos_var = Variable(th.LongTensor([self.sys_id]))
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size * beam_size, 1)  # (batch_size, 1)

    if mode == GEN and gen_type == 'beam':
        # TODO if beam search, repeat the initial states of the RNN
        pass
    else:
        decoder_hidden_state = dec_init_state

    prob_outputs = []  # list of logprob | max_dec_len*(batch_size, 1, vocab_size)
    symbol_outputs = []  # list of word ids | max_dec_len*(batch_size, 1)
    # back_pointers = []
    # lengths = ...

    def decode(step, cum_sum, step_output, step_attn):
        prob_outputs.append(step_output)
        step_output_slice = step_output.squeeze(1)  # (batch_size, vocab_size)
        if self.use_attn:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

        if gen_type == 'greedy':
            _, symbols = step_output_slice.topk(1)  # (batch_size, 1)
        elif gen_type == 'sample':
            # TODO FIXME
            # symbols = self.gumbel_max(step_output_slice)
            raise NotImplementedError('sample decoding is not implemented')
        elif gen_type == 'beam':
            # TODO
            raise NotImplementedError('beam decoding is not implemented')
        else:
            raise ValueError('Unsupported decoding mode')

        symbol_outputs.append(symbols)
        return cum_sum, symbols

    if mode == TEACH_FORCE:
        prob_outputs, decoder_hidden_state, attn = self.forward_step(input_var=decoder_input,
                                                                     hidden_state=decoder_hidden_state,
                                                                     encoder_outputs=attn_context,
                                                                     goal_hid=goal_hid)
    else:
        # do free running here
        cum_sum = None
        for step in range(self.max_dec_len):
            # Input:
            #   decoder_input: (batch_size, 1)
            #   decoder_hidden_state: tuple: (h, c)
            #   attn_context: (batch_size, max_ctx_len, ctx_cell_size)
            #   goal_hid: (batch_size, goal_nhid)
            # Output:
            #   decoder_output: (batch_size, 1, vocab_size)
            #   decoder_hidden_state: tuple: (h, c)
            #   step_attn: (batch_size, 1, max_ctx_len)
            decoder_output, decoder_hidden_state, step_attn = self.forward_step(
                decoder_input, decoder_hidden_state, attn_context, goal_hid=goal_hid)
            cum_sum, symbols = decode(step, cum_sum, decoder_output, step_attn)
            decoder_input = symbols

        prob_outputs = th.cat(prob_outputs, dim=1)  # (batch_size, max_dec_len, vocab_size)

        # back tracking to recover the 1-best in beam search
        # if gen_type == 'beam':

    ret_dict[DecoderRNN.KEY_SEQUENCE] = symbol_outputs

    # prob_outputs: (batch_size, max_dec_len, vocab_size)
    # decoder_hidden_state: tuple: (h, c)
    # ret_dict[DecoderRNN.KEY_ATTN_SCORE]: max_dec_len*(batch_size, 1, max_ctx_len)
    # ret_dict[DecoderRNN.KEY_SEQUENCE]: max_dec_len*(batch_size, 1)
    return prob_outputs, decoder_hidden_state, ret_dict
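# Minimal toy of the greedy free-running loop in forward() above. The GRU cell,
# projection, and vocabulary sizes here are stand-ins, not the project's
# DecoderRNN: at each step the argmax token is fed back as the next input until
# max_dec_len steps have been generated.
import torch
import torch.nn as nn

vocab_size, embed, hidden, max_dec_len = 10, 8, 16, 5
emb = nn.Embedding(vocab_size, embed)
cell = nn.GRU(embed, hidden, batch_first=True)
proj = nn.Linear(hidden, vocab_size)

decoder_input = torch.full((2, 1), 1, dtype=torch.long)  # pretend id 1 is BOS
state = torch.zeros(1, 2, hidden)
symbols = []
for _ in range(max_dec_len):
    out, state = cell(emb(decoder_input), state)          # (batch, 1, hidden)
    step_output = proj(out)                               # (batch, 1, vocab_size)
    _, decoder_input = step_output.squeeze(1).topk(1)     # greedy pick, (batch, 1)
    symbols.append(decoder_input)
print(torch.cat(symbols, dim=1))                          # (batch, max_dec_len) word ids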