def forward_step(self, input_var, hidden, attn_ctxs, attn_words, ctx_embed=None):
    """Run the decoder RNN over ``input_var`` and optionally mix in a pointer distribution.

    With ``attn_ctxs`` set to None, only the RNN's projected vocabulary
    distribution is used. Otherwise ``self.attention`` scores
    ``attn_ctxs``: column 0 of the attention is used as a gate ``g`` on the
    RNN softmax, and columns 1: are pointer weights that are scattered onto
    the vocabulary ids given by ``attn_words``. The result is
    ``log(g * p_rnn + p_ptr)``.

    Both branches return log-probabilities shaped
    ``batch_size x seq_len x vocab``.

    :param input_var: batch_size x seq_len token ids.
    :param hidden: initial RNN hidden state, passed straight to ``self.rnn``.
    :param attn_ctxs: batch_size x attn_size+1 x ctx_size, or None to decode
        without any attention context.
    :param attn_words: batch_size x attn_size word ids aligned with
        ``attn_ctxs[:, 1:]`` (unused when ``attn_ctxs`` is None).
    :param ctx_embed: optional tensor added to the token embeddings; it must
        broadcast against ``self.embedding(input_var)``.
    :return: tuple ``(log_probs, ptr_softmax, hidden, ptr_attn, g)``.
        ``log_probs`` is batch_size x seq_len x vocab. ``ptr_softmax`` is
        batch_size x seq_len x vocab_size. ``ptr_attn`` is
        batch_size x seq_len x attn_size. ``g`` is batch_size x seq_len.
        The last three, and ``ptr_softmax``, are None when ``attn_ctxs`` is None.
    """
    batch_size = input_var.size(0)
    seq_len = input_var.size(1)

    embedded = self.embedding(input_var)
    if ctx_embed is not None:
        # Out-of-place add: don't mutate the embedding layer's output tensor.
        embedded = embedded + ctx_embed
    embedded = self.input_dropout(embedded)

    output, hidden = self.rnn(embedded, hidden)

    if attn_ctxs is None:
        # Empty attention context: plain RNN softmax, no pointer component.
        logits = self.project(output.contiguous().view(-1, self.hidden_size))
        predicted_softmax = F.log_softmax(logits, dim=1)
        # Restore batch_size x seq_len x vocab so callers see the same layout
        # as the pointer branch below.
        predicted_softmax = predicted_softmax.view(batch_size, seq_len, -1)
        return predicted_softmax, None, hidden, None, None

    attn_size = attn_words.size(1)
    # attn: batch_size x seq_len x (attn_size + 1)
    _, attn = self.attention(output, attn_ctxs)

    rnn_softmax = F.softmax(
        self.project(output.contiguous().view(-1, self.hidden_size)), dim=1)
    g = attn[:, :, 0].contiguous()
    ptr_attn = attn[:, :, 1:].contiguous()

    # One row per (batch, step, context word). Each pointer weight is
    # scattered into the column of its word id, then the rows are summed over
    # context words, so repeated words accumulate probability mass.
    ptr_softmax = Variable(torch.zeros((batch_size * seq_len * attn_size, self.vocab_size)))
    ptr_softmax = cast_type(ptr_softmax, FLOAT, self.use_gpu)
    flat_attn_words = attn_words.unsqueeze(1).repeat(1, seq_len, 1).view(-1, 1)
    flat_attn = ptr_attn.view(-1, 1)
    ptr_softmax = ptr_softmax.scatter_(1, flat_attn_words, flat_attn)
    ptr_softmax = ptr_softmax.view(batch_size * seq_len, attn_size, self.vocab_size)
    ptr_softmax = torch.sum(ptr_softmax, dim=1)

    # Mix the gated RNN distribution with the pointer distribution.
    mixture_softmax = rnn_softmax * g.view(-1, 1) + ptr_softmax
    # The clamp keeps log() finite for tokens with zero mixture probability.
    logits = torch.log(mixture_softmax.clamp(min=1e-8))
    predicted_softmax = logits.view(batch_size, seq_len, -1)
    ptr_softmax = ptr_softmax.view(batch_size, seq_len, -1)
    return predicted_softmax, ptr_softmax, hidden, ptr_attn, g
def __init__(self, padding_idx, config, rev_vocab=None, key_vocab=None):
    """Set up the NLL criterion, optionally up-weighting a set of key words.

    :param padding_idx: token id stored on ``self.padding_idx``.
    :param config: object providing ``avg_type`` and ``use_gpu``.
    :param rev_vocab: mapping from word to vocabulary id; its length sets the
        size of the weight vector.
    :param key_vocab: iterable of words whose ids get weight 10.0; all other
        ids get weight 1.0. ``self.weight`` is None unless both ``rev_vocab``
        and ``key_vocab`` are given.
    """
    super(NLLEntropy, self).__init__()
    self.padding_idx = padding_idx
    self.avg_type = config.avg_type

    if rev_vocab is not None and key_vocab is not None:
        self.logger.info("Use extra cost for key words")
        per_token_weight = np.ones(len(rev_vocab))
        for word in key_vocab:
            per_token_weight[rev_vocab[word]] = 10.0
        self.weight = cast_type(torch.from_numpy(per_token_weight), FLOAT, config.use_gpu)
    else:
        self.weight = None
def gumbel_max(self, log_probs, eps=1e-20):
    """Draw one token id per row with the Gumbel-max trick.

    Adds i.i.d. Gumbel(0, 1) noise ``-log(-log(u))``, with ``u ~ U[0, 1)``,
    to ``log_probs`` and takes the argmax of each row. The result is a
    discrete sample, so this is NOT differentiable.

    :param log_probs: [batch_size x vocab_size] log-probabilities.
    :param eps: small constant added inside both logarithms so the noise
        stays finite when ``u`` is exactly 0 or rounds to 1. Without it, a
        draw of 1 yields ``+inf`` noise and forces that token.
    :return: [batch_size x 1] selected token IDs.
    """
    # Draw the uniform noise with the same tensor type and device as
    # log_probs, rather than relying on a separate GPU flag.
    uniform = Variable(log_probs.data.new(log_probs.size()).uniform_(0, 1))
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
    _, max_ids = torch.max(log_probs + gumbel_noise, dim=-1, keepdim=True)
    return max_ids
def np2var(self, inputs, dtype):
    """Wrap a numpy array in a Variable and pass it through ``cast_type``.

    The Variable is cast with ``dtype`` and ``self.use_gpu``. A None input
    is returned as None, so optional arrays can be forwarded unchanged.
    """
    if inputs is not None:
        as_tensor = torch.from_numpy(inputs)
        return cast_type(Variable(as_tensor), dtype, self.use_gpu)
    return None
def forward(self, batch_size, attn_context, attn_words, inputs=None, init_state=None,
            mode=TEACH_FORCE, gen_type='greedy', ctx_embed=None):
    """Run the pointer decoder with teacher forcing or in free-running mode.

    A learned sentinel (``self.sentinel``) is prepended to ``attn_context``
    along dim 1 before decoding, when a context is given.

    :param batch_size: number of sequences; used to build BOS inputs and to
        expand the sentinel.
    :param attn_context: attention context (or None); the sentinel is
        concatenated in front of it along dim 1.
    :param attn_words: word ids the pointer can copy; forwarded to
        ``forward_step`` and returned under ``KEY_PTR_CTX``.
    :param inputs: teacher-forcing token ids; ignored when ``mode == GEN``.
    :param init_state: initial RNN hidden state.
    :param mode: ``TEACH_FORCE`` runs a single ``forward_step`` over
        ``inputs``. Any other mode free-runs for ``self.max_length`` steps,
        starting from BOS.
    :param gen_type: ``'greedy'`` (top-1) or ``'sample'`` (Gumbel-max) in
        free-running mode.
    :param ctx_embed: optional tensor added to token embeddings.
    :return: ``(decoder_outputs, decoder_hidden, ret_dict)``. ``ret_dict``
        holds attentions, decoded symbols, lengths, pointer gates, pointer
        softmaxes and ``attn_words``. The pointer entries are None when
        ``attn_context`` is None.
    :raises ValueError: for an unsupported ``gen_type`` in free-running mode.
    """
    ret_dict = dict()

    if mode == GEN:
        inputs = None

    if inputs is not None:
        decoder_input = inputs
    else:
        # Free running starts every sequence from BOS.
        bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True)
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size, 1)

    # Prepend the sentinel so attention column 0 can act as the RNN/pointer gate.
    if attn_context is not None:
        attn_context = torch.cat(
            [self.sentinel.expand(batch_size, 1, self.attn_size), attn_context], dim=1)

    decoder_hidden = init_state
    decoder_outputs = []   # per-step log-probabilities
    sequence_symbols = []  # per-step chosen word ids
    attentions = []
    pointer_gs = []
    pointer_outputs = []
    lengths = np.array([self.max_length] * batch_size)

    def decode(step, step_output):
        """Pick the next symbols from one step's output and record EOS lengths."""
        decoder_outputs.append(step_output)
        step_output_slice = step_output.squeeze(1)
        if gen_type == 'greedy':
            symbols = step_output_slice.topk(1)[1]
        elif gen_type == 'sample':
            symbols = self.gumbel_max(step_output_slice)
        else:
            raise ValueError("Unsupported decoding mode")
        sequence_symbols.append(symbols)

        eos_batches = symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            # Only record the first EOS: later ones leave lengths <= step untouched.
            update_idx = ((lengths > step) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)
        return symbols

    def cat_steps(steps):
        """Concatenate per-step tensors along dim 1; None if forward_step gave no values."""
        # forward_step returns None for pointer outputs and gates when there is
        # no attention context; torch.cat would fail on a list of None.
        if all(s is None for s in steps):
            return None
        return torch.cat(steps, dim=1)

    if mode == TEACH_FORCE:
        pred_softmax, ptr_softmax, decoder_hidden, attn, step_g = self.forward_step(
            decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
        # In teacher-forcing mode no symbols are decoded.
        attentions = attn
        decoder_outputs = pred_softmax
        pointer_gs = step_g
        pointer_outputs = ptr_softmax
    else:
        for di in range(self.max_length):
            pred_softmax, ptr_softmax, decoder_hidden, step_attn, step_g = self.forward_step(
                decoder_input, decoder_hidden, attn_context, attn_words, ctx_embed)
            symbols = decode(di, pred_softmax)
            attentions.append(step_attn)
            pointer_gs.append(step_g)
            pointer_outputs.append(ptr_softmax)
            decoder_input = symbols

        # Turn the per-step lists into tensors.
        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        pointer_outputs = cat_steps(pointer_outputs)
        pointer_gs = cat_steps(pointer_gs)

    ret_dict[self.KEY_ATTN_SCORE] = attentions
    ret_dict[self.KEY_SEQUENCE] = sequence_symbols
    ret_dict[self.KEY_LENGTH] = lengths
    ret_dict[self.KEY_G] = pointer_gs
    ret_dict[self.KEY_PTR_SOFTMAX] = pointer_outputs
    ret_dict[self.KEY_PTR_CTX] = attn_words
    return decoder_outputs, decoder_hidden, ret_dict
def forward(self, batch_size, inputs=None, init_state=None, attn_context=None,
            mode=TEACH_FORCE, gen_type='greedy', beam_size=4):
    """Run the decoder with teacher forcing or free-running greedy/sample/beam decoding.

    :param batch_size: number of sequences.
    :param inputs: teacher-forcing token ids; ignored when ``mode == GEN``.
    :param init_state: initial RNN hidden state (a ``(h, c)`` tuple when
        ``self.rnn_cell is nn.LSTM``).
    :param attn_context: forwarded unchanged to ``self.forward_step``.
    :param mode: ``TEACH_FORCE`` runs a single ``forward_step``. Any other
        mode free-runs for ``self.max_length`` steps.
    :param gen_type: ``'greedy'``, ``'sample'`` (Gumbel-max) or ``'beam'``.
    :param beam_size: beam width; forced to 1 unless ``gen_type == 'beam'``.
    :return: ``(decoder_outputs, decoder_hidden, ret_dict)``. For beam
        search, ``ret_dict[KEY_SEQUENCE]`` holds the back-tracked 1-best
        sequence per batch element.
    :raises ValueError: for an unsupported ``gen_type`` in free-running mode.
    """
    ret_dict = dict()
    if self.use_attention:
        ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

    if mode == GEN:
        inputs = None
    if gen_type != 'beam':
        beam_size = 1

    if inputs is not None:
        decoder_input = inputs
    else:
        # One BOS per (batch, beam) slot.
        bos_var = Variable(torch.LongTensor([self.sos_id]), volatile=True)
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size * beam_size, 1)

    if mode == GEN and gen_type == 'beam':
        # Beam search: replicate the initial RNN state once per beam.
        if self.rnn_cell is nn.LSTM:
            h, c = init_state
            decoder_hidden = (self.repeat_state(h, batch_size, beam_size),
                              self.repeat_state(c, batch_size, beam_size))
        else:
            decoder_hidden = self.repeat_state(init_state, batch_size, beam_size)
    else:
        decoder_hidden = init_state
    # NOTE(review): attn_context is not replicated per beam here; with
    # use_attention and beam search, confirm forward_step / the caller handles
    # a batch_size vs batch_size*beam_size mismatch.

    decoder_outputs = []   # per-step log-probabilities
    sequence_symbols = []  # per-step word ids
    back_pointers = []     # per-step parent beam index (within each batch element)
    # NOTE(review): lengths are tracked per flattened (batch, beam) slot and
    # are not re-aligned to the back-tracked beam sequence.
    lengths = np.array([self.max_length] * batch_size * beam_size)

    def decode(step, cum_sum, step_output, step_attn):
        """Choose the next symbols for one step and update scores and EOS lengths."""
        decoder_outputs.append(step_output)
        step_output_slice = step_output.squeeze(1)
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

        if gen_type == 'greedy':
            symbols = step_output_slice.topk(1)[1]
        elif gen_type == 'sample':
            symbols = self.gumbel_max(step_output_slice)
        elif gen_type == 'beam':
            if step == 0:
                # All beams are identical at step 0; score only the first one.
                seq_score = step_output_slice.view(batch_size, -1)
                seq_score = seq_score[:, 0:self.output_size]
            else:
                seq_score = cum_sum + step_output_slice
                seq_score = seq_score.view(batch_size, -1)
            # Top-k over (beam, vocab) pairs flattened per batch element.
            top_v, top_id = seq_score.topk(beam_size)
            back_ptr = top_id.div(self.output_size).view(-1, 1)
            symbols = top_id.fmod(self.output_size).view(-1, 1)
            cum_sum = top_v.view(-1, 1)
            back_pointers.append(back_ptr)
        else:
            raise ValueError("Unsupported decoding mode")
        sequence_symbols.append(symbols)

        eos_batches = symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            update_idx = ((lengths > step) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)
        return cum_sum, symbols

    def reorder_hidden(hidden, parent_ids):
        """Gather each beam slot's hidden state from its parent slot (batch is dim 1)."""
        if isinstance(hidden, tuple):
            return tuple(h.index_select(1, parent_ids) for h in hidden)
        return hidden.index_select(1, parent_ids)

    if mode == TEACH_FORCE:
        decoder_output, decoder_hidden, attn = self.forward_step(
            decoder_input, decoder_hidden, attn_context)
        # In teacher-forcing mode no symbols are decoded.
        decoder_outputs = decoder_output
    else:
        cum_sum = None
        beam_offsets = None
        if gen_type == 'beam':
            # Flat index of each batch element's first beam slot, per slot.
            offsets = np.repeat(np.arange(batch_size, dtype=np.int64) * beam_size, beam_size)
            beam_offsets = cast_type(Variable(torch.from_numpy(offsets)), LONG, self.use_gpu)

        for di in range(self.max_length):
            decoder_output, decoder_hidden, step_attn = self.forward_step(
                decoder_input, decoder_hidden, attn_context)
            cum_sum, symbols = decode(di, cum_sum, decoder_output, step_attn)
            if gen_type == 'beam':
                # New slot j continues parent beam back_ptr[j], so it must
                # carry that parent's hidden state into the next step.
                parent_ids = back_pointers[-1].view(-1) + beam_offsets
                decoder_hidden = reorder_hidden(decoder_hidden, parent_ids)
            decoder_input = symbols

        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        if gen_type == 'beam':
            # Back-track from the best final beam to recover the 1-best sequence.
            final_seq_symbols = []
            cum_sum = cum_sum.view(-1, beam_size)
            max_seq_id = cum_sum.topk(1)[1].data.cpu().view(-1).numpy()
            rev_seq_symbols = sequence_symbols[::-1]
            rev_back_ptrs = back_pointers[::-1]
            for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs):
                symbol2ds = symbols.view(-1, beam_size)
                back2ds = back_ptrs.view(-1, beam_size)
                selected_symbols = []
                selected_parents = []
                for b_id in range(batch_size):
                    selected_parents.append(back2ds[b_id, max_seq_id[b_id]])
                    selected_symbols.append(symbol2ds[b_id, max_seq_id[b_id]])
                final_seq_symbols.append(torch.cat(selected_symbols).unsqueeze(1))
                max_seq_id = torch.cat(selected_parents).data.cpu().numpy()
            sequence_symbols = final_seq_symbols[::-1]

    ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
    ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
    return decoder_outputs, decoder_hidden, ret_dict