def np2var(self, inputs, dtype):
    """Convert array-like input into a Variable cast to ``dtype``.

    :param inputs: a numpy array, a (possibly nested) list/tuple of numbers,
        or None.
    :param dtype: type tag forwarded to ``cast_type`` (e.g. LONG / FLOAT).
    :return: the Variable produced by ``cast_type`` (which also receives
        ``self.use_gpu``), or None when ``inputs`` is None.
    """
    if inputs is None:
        return None
    if isinstance(inputs, (list, tuple)):
        # Plain Python sequences cannot go through torch.from_numpy; build a
        # (default float) tensor directly and let cast_type convert it.
        tensor = torch.Tensor(inputs)
    else:
        tensor = torch.from_numpy(inputs)
    return cast_type(Variable(tensor), dtype, self.use_gpu)
def sweep(self, data_feed, gen_type='greedy'):
    """Decode along a path in latent space between two encoded utterances.

    The utterances in ``data_feed['outputs']`` are encoded and mapped to
    hard latent code ids. Starting from the codes of the first example, the
    codes of the last example are substituted one latent position at a time
    (positions 0 .. y_size-1 in order), and every code vector on that path
    is decoded in GEN mode.

    :param data_feed: dict providing at least 'outputs' and 'output_lens'.
    :param gen_type: decoding strategy forwarded to the decoder.
    :return: (dec_ctx, labels, all_y_ids):
        - the decoder's result dict;
        - the input batch's targets shifted by one token;
        - a [y_size + 1, y_size] tensor of the decoded code ids.
    """
    output_lens = data_feed['output_lens']
    n_inputs = len(output_lens)
    out_utts = self.np2var(data_feed['outputs'], LONG)

    # encode the target utterances
    embedded_outs = self.embedding(out_utts)
    _, enc_last = self.x_encoder(embedded_outs)
    enc_last = enc_last.transpose(0, 1).contiguous().view(-1, self.enc_out_size)

    # posterior network -> hard code ids (the one-hot sample is not used here)
    posterior_logits = self.q_y(enc_last).view(-1, self.config.k)
    _, code_ids = self.cat_connector(posterior_logits, 1.0, self.use_gpu,
                                     hard=True, return_max_id=True)
    code_ids = code_ids.view(-1, self.config.y_size)
    source_codes = code_ids[0]
    target_codes = code_ids[n_inputs - 1]

    # step `pos` takes positions 0..pos from the target, the rest from the source
    path = [source_codes]
    for pos in range(self.config.y_size):
        take_target = torch.zeros(self.config.y_size)
        take_target[0:pos + 1] = 1.0
        take_source = 1 - take_target
        take_target = cast_type(Variable(take_target), LONG, self.use_gpu)
        take_source = cast_type(Variable(take_source), LONG, self.use_gpu)
        path.append(take_source * source_codes + take_target * target_codes)
    n_steps = len(path)
    all_y_ids = torch.cat(path, dim=0).view(n_steps, -1)

    # one-hot encode every code vector on the path
    one_hot = cast_type(
        Variable(torch.zeros((n_steps * self.config.y_size, self.config.k))),
        FLOAT, self.use_gpu)
    one_hot.scatter_(1, all_y_ids.view(-1, 1), 1.0)
    one_hot = one_hot.view(-1, self.config.k * self.config.y_size)

    # map codes to the decoder's initial state and decode one row per step
    dec_init_state = self.dec_init_connector(one_hot)
    labels = out_utts[:, 1:].contiguous()
    dec_inputs = out_utts[:, 0:-1]
    _, _, dec_ctx = self.decoder(n_steps, dec_inputs, dec_init_state,
                                 mode=GEN, gen_type=gen_type,
                                 beam_size=self.beam_size)
    return dec_ctx, labels, all_y_ids
def exp_enumerate(self, repeat=1, gen_type='greedy'):
    """Decode from every joint assignment of the discrete latent variables.

    All ``k ** y_size`` combinations of latent codes are enumerated in
    ``itertools.product`` order, and each one is decoded ``repeat`` times.

    :param repeat: number of decoder rows per latent combination.
    :param gen_type: decoding strategy forwarded to the decoder.
    :return: (dec_ctx, all_y_ids):
        - ``all_y_ids[i]`` is the list of ``y_size`` code ids of combination ``i``;
        - that combination occupies decoder rows ``i * repeat`` .. ``i * repeat + repeat - 1``.
    """
    k = self.config.k
    y_size = self.config.y_size
    all_y_ids = [list(combo) for combo in itertools.product(range(k), repeat=y_size)]
    batch_size = len(all_y_ids) * repeat

    # Give every one of the batch_size rows its code ids. Without repeating,
    # the scatter index covered only the first k**y_size rows and the
    # remaining rows stayed all-zero whenever repeat > 1.
    np_y_ids = np.repeat(np.array(all_y_ids), repeat, axis=0)
    y_ids_var = self.np2var(np_y_ids, LONG)

    # one-hot encode: one row per (decoder row, latent variable)
    sample_y = cast_type(Variable(torch.zeros((batch_size * y_size, k))),
                         FLOAT, self.use_gpu)
    sample_y.scatter_(1, y_ids_var.view(-1, 1), 1.0)
    sample_y = sample_y.view(-1, k * y_size)

    # map sample to initial state of decoder, then decode
    dec_init_state = self.dec_init_connector(sample_y)
    _, _, dec_ctx = self.decoder(batch_size, None, dec_init_state, mode=GEN,
                                 gen_type=gen_type, beam_size=self.beam_size)
    return dec_ctx, all_y_ids
def enumerate(self, repeat=1, gen_type='greedy'):
    """Decode with one latent variable pinned to each of its values.

    Each (y_id, k_id) configuration is built as follows:
    - latent variable ``y_id`` gets the one-hot code ``k_id``;
    - every other variable is set to the uniform vector 1/k (the
      expectation over its values).

    Each configuration is decoded ``repeat`` times, giving
    ``y_size * k * repeat`` rows. Configuration (y_id, k_id), repetition
    r_id, is row ``(y_id * k + k_id) * repeat + r_id``.

    :param repeat: number of decoder rows per configuration.
    :param gen_type: decoding strategy forwarded to the decoder.
    :return: the decoder's result dict (``dec_ctx``).
    """
    k = self.config.k
    y_size = self.config.y_size
    batch_size = y_size * k * repeat
    sample_y = cast_type(Variable(torch.zeros((batch_size, y_size, k))),
                         FLOAT, self.use_gpu)
    sample_y += 1.0 / k
    for y_id in range(y_size):
        for k_id in range(k):
            # Each configuration owns a contiguous run of `repeat` rows. The
            # former index y_id*k + k_id*repeat + r_id overlapped between
            # variables (and left some rows untouched) when repeat > 1.
            base = (y_id * k + k_id) * repeat
            for r_id in range(repeat):
                sample_y[base + r_id, y_id] = 0.0
                sample_y[base + r_id, y_id, k_id] = 1.0

    # map sample to initial state of decoder
    sample_y = sample_y.view(-1, k * y_size)
    dec_init_state = self.dec_init_connector(sample_y)
    _, _, dec_ctx = self.decoder(batch_size, None, dec_init_state, mode=GEN,
                                 gen_type=gen_type, beam_size=self.beam_size)
    return dec_ctx
def extract_name(self, action_id):
    """Look up an action's name and return it as a tensor of integers.

    ``self.action2name`` is keyed by ``str(action_id)``. The stored name is a
    '-'-separated string of integers (e.g. "3-0-7").

    :param action_id: action identifier; converted with ``str`` for lookup.
    :return: the parsed integers cast to LONG via ``cast_type``, or the
        string 'empty' when the id is not in ``self.action2name``.
    """
    key = str(action_id)
    if key not in self.action2name:
        return 'empty'
    action_name = self.action2name[key]
    # list(...) is required: on Python 3 `map` is a lazy iterator, which
    # torch.Tensor cannot consume.
    name = torch.Tensor(list(map(int, action_name.strip().split('-'))))
    return cast_type(name, LONG, self.use_gpu)
def __init__(self, padding_idx, config, rev_vocab=None, key_vocab=None):
    """Configure the NLL loss, optionally up-weighting key words.

    :param padding_idx: stored as ``self.padding_idx``.
    :param config: must provide ``avg_type`` and ``use_gpu``.
    :param rev_vocab: mapping word -> vocabulary index (only needed for
        key-word weighting).
    :param key_vocab: iterable of words whose class weight is set to 10.0.

    When either vocabulary is missing, ``self.weight`` is None. Otherwise it
    is a FLOAT tensor of per-word weights (1.0 everywhere, 10.0 for key
    words).
    """
    super(NLLEntropy, self).__init__()
    self.padding_idx = padding_idx
    self.avg_type = config.avg_type
    if rev_vocab is None or key_vocab is None:
        self.weight = None
        return
    # NOTE(review): relies on a `logger` attribute not defined in this
    # method (presumably class-level) - confirm it exists.
    self.logger.info("Use extra cost for key words")
    weight = np.ones(len(rev_vocab))
    key_ids = [rev_vocab[word] for word in key_vocab]
    weight[key_ids] = 10.0
    self.weight = cast_type(torch.from_numpy(weight), FLOAT, config.use_gpu)
def gumbel_max(self, log_probs, eps=1e-20):
    """Draw one sample per row using the Gumbel-max trick.

    Note: the argmax makes this sampling step non-differentiable.

    :param log_probs: [batch_size x vocab_size] log-probabilities.
    :param eps: small constant keeping both logarithms finite (same default
        as ``sample_gumbel``).
    :return: [batch_size x 1] selected token IDs
    """
    sample = torch.Tensor(log_probs.size()).uniform_(0, 1)
    sample = cast_type(Variable(sample), FLOAT, self.use_gpu)
    # Gumbel noise -log(-log(u)). Without eps, u == 0 (or u rounding to 1)
    # turns the noise into +/-inf.
    matrix_u = -1.0 * torch.log(-1.0 * torch.log(sample + eps) + eps)
    gumbel_log_probs = log_probs + matrix_u
    _, max_ids = torch.max(gumbel_log_probs, dim=-1, keepdim=True)
    return max_ids
def forward(self, mu, logvar, use_gpu):
    """Draw z ~ N(mu, diag(exp(logvar))) with the reparametrization trick.

    Computes ``z = mu + exp(0.5 * logvar) * eps``, where ``eps`` is
    standard-normal noise with the same size as ``logvar``.

    :param mu: mean, [batch_size, variable_dim].
    :param logvar: log-variance, [batch_size, variable_dim].
    :param use_gpu: forwarded to ``cast_type`` for the noise tensor.
    :return: the sample z.
    """
    noise = torch.randn(logvar.size())
    noise = cast_type(Variable(noise), FLOAT, use_gpu)
    std_dev = (0.5 * logvar).exp()
    return mu + std_dev * noise
def forward(self, logits, use_gpu, return_max_id=False):
    """Return a deterministic one-hot vector of each row's argmax class.

    :param logits: [batch_size, n_class] unnormalized log-probs.
    :param use_gpu: forwarded to ``cast_type`` for the one-hot tensor.
    :param return_max_id: when truthy, also return the argmax indices.
    :return: the [batch_size, n_class] one-hot tensor. When
        ``return_max_id`` is set, returns (one_hot, ids) with ids of shape
        [batch_size, 1].
    """
    max_ids = torch.max(logits, dim=1, keepdim=True)[1]
    one_hot = cast_type(Variable(torch.zeros(logits.size())), FLOAT, use_gpu)
    one_hot.scatter_(1, max_ids, 1.0)
    return (one_hot, max_ids) if return_max_id else one_hot
def forward(self, batch_size, inputs=None, init_state=None, attn_context=None,
            mode=TEACH_FORCE, gen_type='greedy', beam_size=4):
    """Run the decoder with teacher forcing or in free-running generation.

    :param batch_size: number of sequences in the batch.
    :param inputs: decoder input token ids. Used only outside GEN mode (it
        is forced to None in GEN mode, where decoding starts from BOS).
    :param init_state: initial RNN state; an (h, c) tuple when
        ``self.rnn_cell is nn.LSTM``.
    :param attn_context: passed unchanged to ``forward_step``.
    :param mode: TEACH_FORCE (single ``forward_step`` over ``inputs``) or
        GEN (step-by-step for ``self.max_length`` steps).
    :param gen_type: 'greedy', 'sample' or 'beam'. Any other value raises
        ValueError in GEN mode.
    :param beam_size: beam width; forced to 1 unless gen_type == 'beam'.
    :return: (decoder_outputs, decoder_hidden, ret_dict). ret_dict holds:
        - KEY_SEQUENCE: per-step symbol tensors;
        - KEY_LENGTH: list of lengths;
        - KEY_ATTN_SCORE: per-step attention, when ``self.use_attention``.
    """
    ret_dict = dict()
    if self.use_attention:
        ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

    # sanity checks on the decoding configuration
    if mode == GEN:
        inputs = None
    if gen_type != 'beam':
        beam_size = 1

    if inputs is not None:
        decoder_input = inputs
    else:
        # prepare the BOS inputs (one per batch row and beam)
        bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True)
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size * beam_size, 1)

    if mode == GEN and gen_type == 'beam':
        # beam search: repeat the initial states of the RNN for every beam
        if self.rnn_cell is nn.LSTM:
            h, c = init_state
            decoder_hidden = (self.repeat_state(h, batch_size, beam_size),
                              self.repeat_state(c, batch_size, beam_size))
        else:
            decoder_hidden = self.repeat_state(init_state, batch_size, beam_size)
    else:
        decoder_hidden = init_state

    decoder_outputs = []  # a list of logprob
    sequence_symbols = []  # a list word ids
    back_pointers = []  # a list of parent beam ID
    lengths = np.array([self.max_length] * batch_size * beam_size)

    def decode(step, cum_sum, step_output, step_attn):
        """Choose the next symbols for one decoding step and record them."""
        decoder_outputs.append(step_output)
        step_output_slice = step_output.squeeze(1)
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

        if gen_type == 'greedy':
            symbols = step_output_slice.topk(1)[1]
        elif gen_type == 'sample':
            symbols = self.gumbel_max(step_output_slice)
        elif gen_type == 'beam':
            if step == 0:
                # all beams start identical: only score the first one
                seq_score = step_output_slice.view(batch_size, -1)
                seq_score = seq_score[:, 0:self.output_size]
            else:
                seq_score = cum_sum + step_output_slice
                seq_score = seq_score.view(batch_size, -1)

            top_v, top_id = seq_score.topk(beam_size)
            symbols = top_id.fmod(self.output_size)
            # Parent beam = top_id // output_size. Computed as an exact
            # division of a multiple, cast back to long, so it stays integral
            # on PyTorch versions where `div` is true division.
            back_ptr = (top_id - symbols).div(self.output_size).long()
            symbols = symbols.view(-1, 1)
            back_ptr = back_ptr.view(-1, 1)
            cum_sum = top_v.view(-1, 1)
            back_pointers.append(back_ptr)
        else:
            raise ValueError("Unsupported decoding mode")

        sequence_symbols.append(symbols)

        eos_batches = symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            # NOTE(review): in beam mode `lengths` is tracked per beam slot,
            # not per hypothesis; it is not reordered by back-pointers.
            update_idx = ((lengths > step) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)
        return cum_sum, symbols

    def reorder_hidden(hidden, index):
        """Select rows ``index`` along dim 1 of an RNN state (or (h, c) tuple)."""
        if isinstance(hidden, tuple):
            return tuple(h.index_select(1, index) for h in hidden)
        return hidden.index_select(1, index)

    # Manual unrolling is used to support random teacher forcing.
    # If teacher_forcing_ratio is True or False instead of a probability,
    # the unrolling can be done in graph
    if mode == TEACH_FORCE:
        decoder_output, decoder_hidden, _ = self.forward_step(
            decoder_input, decoder_hidden, attn_context)
        # in teach forcing mode, we don't need symbols.
        decoder_outputs = decoder_output
    else:
        # do free running here
        cum_sum = None
        beam_offsets = None
        if gen_type == 'beam':
            # Row of each batch element's first beam in the flattened
            # (batch_size * beam_size) layout. Assumes repeat_state lays beams
            # out batch-major, matching the view(batch_size, -1) in decode.
            beam_offsets = cast_type(
                Variable(torch.LongTensor(
                    [b_id * beam_size for b_id in range(batch_size)]).view(-1, 1)),
                LONG, self.use_gpu)
        for di in range(self.max_length):
            decoder_output, decoder_hidden, step_attn = self.forward_step(
                decoder_input, decoder_hidden, attn_context)
            cum_sum, symbols = decode(di, cum_sum, decoder_output, step_attn)
            decoder_input = symbols
            if gen_type == 'beam':
                # Each new hypothesis extends its parent beam, so it must
                # continue from the parent's hidden state. Assumes the
                # standard (layers, batch*beam, hidden) RNN state layout -
                # TODO confirm for forward_step.
                parent_rows = (back_pointers[-1].view(batch_size, beam_size)
                               + beam_offsets).view(-1)
                decoder_hidden = reorder_hidden(decoder_hidden, parent_rows)

        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        if gen_type == 'beam':
            # do back tracking here to recover the 1-best according to
            # beam search.
            final_seq_symbols = []
            cum_sum = cum_sum.view(-1, beam_size)
            max_seq_id = cum_sum.topk(1)[1].data.cpu().view(-1).numpy()
            rev_seq_symbols = sequence_symbols[::-1]
            rev_back_ptrs = back_pointers[::-1]
            for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs):
                symbol2ds = symbols.view(-1, beam_size)
                back2ds = back_ptrs.view(-1, beam_size)
                selected_symbols = []
                selected_parents = []
                for b_id in range(batch_size):
                    selected_parents.append(back2ds[b_id, max_seq_id[b_id]])
                    selected_symbols.append(symbol2ds[b_id, max_seq_id[b_id]])
                final_seq_symbols.append(
                    torch.cat(selected_symbols).unsqueeze(1))
                max_seq_id = torch.cat(selected_parents).data.cpu().numpy()
            sequence_symbols = final_seq_symbols[::-1]

    # save the decoded sequence symbols and sequence length
    ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
    ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
    return decoder_outputs, decoder_hidden, ret_dict
def exp_forward(self, data_feed):
    """Compute the exact expected NLL by enumerating every latent assignment.

    The y_size categorical latents (k values each) give k ** y_size joint
    assignments z. For each z:
    - the response is decoded with teacher forcing from
      ``c_init_connector(z) + context encoding``;
    - its per-example NLL is weighted by p(z | c) from the prior network
      ``self.p_y``.

    :param data_feed: dict with 'context_lens', 'contexts', 'outputs' and
        'output_lens'.
    :return: ``Pack(nll=...)``. The value is the batch sum of
        sum_z p(z|c) * NLL(x | c, z), divided by sum(output_lens - 1).
    """
    ctx_lens = data_feed['context_lens']
    batch_size = len(ctx_lens)
    ctx_utts = self.np2var(data_feed['contexts'], LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)
    output_lens = self.np2var(data_feed['output_lens'], FLOAT)
    k = self.config.k
    y_size = self.config.y_size

    # context encoder
    c_inputs = self.utt_encoder(ctx_utts)
    _, c_last = self.ctx_encoder(c_inputs, ctx_lens)
    c_last = c_last.squeeze(0)

    # prior network: log p(y_i | c) per latent variable
    py_logits = self.p_y(c_last).view(-1, k)
    log_py = F.log_softmax(py_logits, dim=py_logits.dim() - 1)
    flat_log_py = log_py.view(-1, k * y_size)

    # one-hot encoding of every joint assignment, one row per assignment
    np_y_ids = np.array(list(itertools.product(range(k), repeat=y_size)))
    exp_size = len(np_y_ids)
    y_ids = self.np2var(np_y_ids, LONG)
    sample_y = cast_type(Variable(torch.zeros((exp_size * y_size, k))),
                         FLOAT, self.use_gpu)
    sample_y.scatter_(1, y_ids.view(-1, 1), 1.0)
    sample_y = sample_y.view(-1, k * y_size)

    # loop-invariant decoder inputs / targets
    dec_inputs = out_utts[:, 0:-1]
    target = out_utts[:, 1:].contiguous().view(-1)
    c_last = c_last.unsqueeze(0)
    all_words = torch.sum(output_lens - 1)

    nll_xcz = 0.0
    for exp_id in range(exp_size):
        cur_sample_y = sample_y[exp_id:exp_id + 1].expand(batch_size, k * y_size)
        # log p(z | c) = sum_i log p(y_i = z_i | c)
        log_pyc = torch.sum(flat_log_py * cur_sample_y, dim=1)
        # map sample to initial state of decoder
        dec_init_state = self.c_init_connector(cur_sample_y) + c_last
        dec_outs, _, _ = self.decoder(batch_size, dec_inputs, dec_init_state,
                                      attn_context=None, mode=TEACH_FORCE,
                                      gen_type="greedy",
                                      beam_size=self.config.beam_size)
        output = dec_outs.view(-1, dec_outs.size(-1))
        enc_dec_nll = F.nll_loss(output, target, size_average=False,
                                 ignore_index=self.nll_loss.padding_idx,
                                 weight=self.nll_loss.weight, reduce=False)
        # per-example NLL: sum token losses over the time dimension
        enc_dec_nll = torch.sum(enc_dec_nll.view(-1, dec_outs.size(1)), dim=1)
        nll_xcz += torch.exp(log_pyc) * enc_dec_nll

    nll_xcz = torch.sum(nll_xcz) / all_words
    return Pack(nll=nll_xcz)
def sample_gumbel(self, logits, use_gpu, eps=1e-20):
    """Draw standard Gumbel noise with the same size as ``logits``.

    Computes ``-log(-log(u + eps) + eps)`` with ``u ~ Uniform[0, 1)``; the
    ``eps`` terms keep both logarithms finite.

    :param logits: tensor whose size determines the noise shape.
    :param use_gpu: forwarded to ``cast_type``.
    :param eps: numerical-stability constant.
    :return: the noise, cast to FLOAT.
    """
    uniform = torch.rand(logits.size())
    inner = -torch.log(uniform + eps) + eps
    noise = Variable(-torch.log(inner))
    return cast_type(noise, FLOAT, use_gpu)
def np2var(self, inputs, dtype):
    """Wrap a numpy array as a Variable cast to ``dtype``.

    :param inputs: numpy array, or None.
    :param dtype: type tag forwarded to ``cast_type`` (with ``self.use_gpu``).
    :return: the converted Variable, or None when ``inputs`` is None.
    """
    if inputs is not None:
        return cast_type(Variable(torch.from_numpy(inputs)), dtype, self.use_gpu)
    return None
def forward_rl(self, data_feed, max_words, temp=0.1):
    """Sample latent actions from the policy and decode a response (RL).

    :param data_feed: dict with 'context_lens' and 'contexts'.
    :param max_words: maximum number of words for the decoder.
    :param temp: softmax temperature used only when sampling the latent
        codes. The decoder is always called with temp=0.1.
    :return: (logprobs, outs, joint_logpz, sample_y):
        - logprobs, outs: the decoder's outputs;
        - joint_logpz: the summed untempered log-prob of the sampled codes;
        - sample_y: the one-hot sampled codes.
    """
    ctx_lens = data_feed['context_lens']
    batch_size = len(ctx_lens)
    ctx_utts = self.np2var(data_feed['contexts'], LONG)

    # context encoder
    c_inputs = self.utt_encoder(ctx_utts)
    _, c_last = self.ctx_encoder(c_inputs, ctx_lens)
    c_last = c_last.squeeze(0)

    # create decoder initial states (DB information is not fed here)
    enc_last = c_last
    logits_py, _ = self.c2z(enc_last)
    qy = F.softmax(logits_py / temp, dim=1)
    log_qy = F.log_softmax(logits_py, dim=1)
    idx = torch.multinomial(qy, 1).detach()
    # log-prob of the sampled codes under the untempered distribution
    # (was `gathcher`, which raised AttributeError)
    logprob_sample_z = log_qy.gather(1, idx).view(-1)
    joint_logpz = torch.sum(logprob_sample_z)
    sample_y = cast_type(Variable(torch.zeros(log_qy.size())), FLOAT, self.use_gpu)
    sample_y.scatter_(1, idx, 1.0)

    # pack attention context
    if self.config.dec_use_attn:
        # NOTE(review): mixes self.k_size/self.y_size with
        # self.config.k_size/self.config.y_size - confirm they agree.
        z_embeddings = torch.t(self.z_embedding.weight).split(self.k_size, dim=0)
        attn_context = []
        temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
        for z_id in range(self.y_size):
            attn_context.append(
                torch.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
        attn_context = torch.cat(attn_context, dim=1)
        dec_init_state = torch.sum(attn_context, dim=1).unsqueeze(0)
    else:
        dec_init_state = self.z_embedding(
            sample_y.view(1, -1, self.config.y_size * self.config.k_size))
        attn_context = None

    if self.config.dec_rnn_cell == 'lstm':
        dec_init_state = tuple([dec_init_state, dec_init_state])

    # decode
    logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
                                             dec_init_state=dec_init_state,
                                             attn_context=attn_context,
                                             vocab=self.vocab,
                                             max_words=max_words,
                                             temp=0.1)
    return logprobs, outs, joint_logpz, sample_y