def _score(self, src_codes, tgt_nls):
    """score examples sorted by code length"""
    args = self.args

    if args.tie_embed:
        src_code_var = self.to_input_variable_with_unk_handling(src_codes, cuda=args.cuda).t()
        tgt_nl_var = self.to_input_variable_with_unk_handling(tgt_nls, cuda=args.cuda).t()
    else:
        src_code_var = nn_utils.to_input_variable(src_codes, self.vocab.code, cuda=args.cuda).t()
        tgt_nl_var = nn_utils.to_input_variable(tgt_nls, self.vocab.source, cuda=args.cuda).t()

    src_code_mask = Variable(nn_utils.length_array_to_mask_tensor([len(x) for x in src_codes],
                                                                  cuda=args.cuda,
                                                                  valid_entry_has_mask_one=True).float(),
                             requires_grad=False)
    tgt_nl_mask = Variable(nn_utils.length_array_to_mask_tensor([len(x) for x in tgt_nls],
                                                                cuda=args.cuda,
                                                                valid_entry_has_mask_one=True).float(),
                           requires_grad=False)

    scores = self.pi_model(src_code_var, tgt_nl_var, src_code_mask, tgt_nl_mask)

    return scores
def _score(self, src_codes, tgt_nls):
    """score examples sorted by code length"""
    args = self.args

    # src_code = [self.tokenize_code(e.tgt_code) for e in examples]
    # tgt_nl = [e.src_sent for e in examples]

    src_code_var = nn_utils.to_input_variable(src_codes, self.vocab.code, cuda=args.cuda)
    tgt_nl_var = nn_utils.to_input_variable(tgt_nls, self.vocab.source, cuda=args.cuda,
                                            append_boundary_sym=True)

    tgt_token_copy_idx_mask, tgt_token_gen_mask = self.get_generate_and_copy_meta_tensor(src_codes, tgt_nls)

    if isinstance(self.seq2seq, Seq2SeqWithCopy):
        scores = self.seq2seq(src_code_var, [len(c) for c in src_codes],
                              tgt_nl_var,
                              tgt_token_copy_idx_mask, tgt_token_gen_mask)
    else:
        scores = self.seq2seq(src_code_var, [len(c) for c in src_codes], tgt_nl_var)

    return scores
def __call__(self, code_list):
    # we assume the code is generated by astor and therefore follows astor's style
    code_tokens = [self.transition_system.tokenize_code(code, mode='canonicalize') for code in code_list]
    code_var = nn_utils.to_input_variable(code_tokens, self.vocab,
                                          cuda=self.args.cuda, append_boundary_sym=True)

    return -self.forward(code_var)
def sample(self, src_sents, sample_size):
    src_sents_len = [len(src_sent) for src_sent in src_sents]
    # Variable: (src_sent_len, batch_size)
    src_sents_var = nn_utils.to_input_variable(src_sents, self.vocab.src,
                                               cuda=self.cuda, training=False)

    return self.sample_from_variable(src_sents_var, src_sents_len, sample_size)
def _score(self, src_codes, tgt_nls):
    """score examples sorted by code length"""
    args = self.args

    # src_code = [self.tokenize_code(e.tgt_code) for e in examples]
    # tgt_nl = [e.src_sent for e in examples]

    src_code_var = nn_utils.to_input_variable(src_codes, self.vocab.code, cuda=args.cuda)
    tgt_nl_var = nn_utils.to_input_variable(tgt_nls, self.vocab.source, cuda=args.cuda,
                                            append_boundary_sym=True)

    tgt_token_copy_pos, tgt_token_copy_mask, tgt_token_gen_mask = \
        self.get_generate_and_copy_meta_tensor(src_codes, tgt_nls)

    scores = self.seq2seq(src_code_var, [len(c) for c in src_codes],
                          tgt_nl_var,
                          tgt_token_copy_pos, tgt_token_copy_mask, tgt_token_gen_mask)

    return scores
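# The `_score` variants above assume the batch arrives already sorted by code length
# (as packed-sequence encoders typically require). A minimal sketch, assuming a caller
# that sorts, scores, and then restores the original example order; `scorer._score`
# is simply one of the methods above, everything else here is illustrative.
def score_in_original_order(scorer, src_codes, tgt_nls):
    # indices of examples, longest code first
    order = sorted(range(len(src_codes)), key=lambda i: len(src_codes[i]), reverse=True)
    sorted_codes = [src_codes[i] for i in order]
    sorted_nls = [tgt_nls[i] for i in order]

    scores = scorer._score(sorted_codes, sorted_nls)  # one score per sorted example

    # scatter the scores back into the caller's ordering
    restored = [None] * len(src_codes)
    for rank, original_idx in enumerate(order):
        restored[original_idx] = scores[rank]
    return restored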
def __call__(self, code_list):
    # we assume the code is generated by astor and therefore follows astor's style
    code_tokens = [self.transition_system.tokenize_code(code, mode='canonicalize') for code in code_list]
    code_var = nn_utils.to_input_variable(code_tokens, self.vocab,
                                          cuda=self.args.cuda, append_boundary_sym=True)

    return -self.forward(code_var)
def evaluate_ppl():
    model.eval()
    cum_loss = 0.
    cum_tgt_words = 0.
    for examples in nn_utils.batch_iter(dev_set.examples, args.batch_size):
        batch_tokens = [transition_system.tokenize_code(e.tgt_code) for e in examples]
        batch = nn_utils.to_input_variable(batch_tokens, vocab,
                                           cuda=args.cuda, append_boundary_sym=True)
        loss = model.forward(batch).sum()
        cum_loss += loss.data[0]
        cum_tgt_words += sum(len(tokens) + 1 for tokens in batch_tokens)  # add ending </s>

    ppl = np.exp(cum_loss / cum_tgt_words)
    model.train()
    return ppl
def evaluate_ppl():
    model.eval()
    cum_loss = 0.
    cum_tgt_words = 0.
    for batch in dev_set.batch_iter(args.batch_size):
        src_sents_var = nn_utils.to_input_variable([e.src_sent for e in batch], vocab.source,
                                                   cuda=args.cuda, append_boundary_sym=True)
        loss = model(src_sents_var).sum()
        cum_loss += loss.data[0]
        cum_tgt_words += sum(len(e.src_sent) + 1 for e in batch)  # add ending </s>

    ppl = np.exp(cum_loss / cum_tgt_words)
    model.train()
    return ppl
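# Both `evaluate_ppl` variants above compute corpus perplexity as
# exp(total token negative log-likelihood / total predicted token count),
# counting one extra </s> per sentence. A standalone sketch of that arithmetic
# with made-up numbers:
import numpy as np

def corpus_ppl(batch_nll_sums, batch_token_counts):
    # batch_nll_sums[i]: summed negative log-likelihood of batch i
    # batch_token_counts[i]: number of predicted tokens in batch i (incl. </s>)
    return np.exp(sum(batch_nll_sums) / sum(batch_token_counts))

# e.g. two dev batches with 120 and 80 predicted tokens
print(corpus_ppl([301.2, 195.7], [120, 80]))  # exp(2.4845) ~= 12.0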
def parse(self, src_sent, context=None, beam_size=5): """Perform beam search to infer the target AST given a source utterance Args: src_sent: list of source utterance tokens context: other context used for prediction beam_size: beam size Returns: A list of `DecodeHypothesis`, each representing an AST """ # print('!!!!!!') args = self.args primitive_vocab = self.vocab.primitive src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False) # Variable(1, src_sent_len, hidden_size * 2) src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)]) # (1, src_sent_len, hidden_size) src_encodings_att_linear = self.att_src_linear(src_encodings) dec_init_vec = self.init_decoder_state(last_state, last_cell) if args.lstm == 'parent_feed': h_tm1 = dec_init_vec[0], dec_init_vec[1], \ Variable(self.new_tensor(args.hidden_size).zero_()), \ Variable(self.new_tensor(args.hidden_size).zero_()) else: h_tm1 = dec_init_vec zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_()) hyp_scores = Variable(self.new_tensor([0.]), volatile=True) src_token_vocab_ids = [primitive_vocab[token] for token in src_sent] src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id] # sometimes a word may appear multi-times in the source, in this case, # we just copy its first appearing position. Therefore we mask the words # appearing second and onwards to -1 token_set = set() for i, tid in enumerate(src_token_vocab_ids): if tid in token_set: src_token_vocab_ids[i] = -1 else: token_set.add(tid) t = 0 hypotheses = [DecodeHypothesis()] hyp_states = [[]] completed_hypotheses = [] while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size * 2) exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2)) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) if t == 0: x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True) if args.no_parent_field_type_embed is False: offset = args.action_embed_size # prev_action offset += args.att_vec_size * (not args.no_input_feed) offset += args.action_embed_size * (not args.no_parent_production_embed) offset += args.field_embed_size * (not args.no_parent_field_embed) x[0, offset: offset + args.type_embed_size] = \ self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]] else: actions_tm1 = [hyp.actions[-1] for hyp in hypotheses] a_tm1_embeds = [] for a_tm1 in actions_tm1: if a_tm1: if isinstance(a_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]] elif isinstance(a_tm1, ReduceAction): a_tm1_embed = self.production_embed.weight[len(self.grammar)] else: a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]] a_tm1_embeds.append(a_tm1_embed) else: a_tm1_embeds.append(zero_action_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) inputs = [a_tm1_embeds] if args.no_input_feed is False: inputs.append(att_tm1) if args.no_parent_production_embed is False: # frontier production frontier_prods = [hyp.frontier_node.production for hyp in hypotheses] frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor( [self.grammar.prod2id[prod] for prod in frontier_prods]))) inputs.append(frontier_prod_embeds) 
if args.no_parent_field_embed is False: # frontier field frontier_fields = [hyp.frontier_field.field for hyp in hypotheses] frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([ self.grammar.field2id[field] for field in frontier_fields]))) inputs.append(frontier_field_embeds) if args.no_parent_field_type_embed is False: # frontier field type frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses] frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([ self.grammar.type2id[type] for type in frontier_field_types]))) inputs.append(frontier_field_type_embeds) # parent states if args.no_parent_state is False: p_ts = [hyp.frontier_node.created_time for hyp in hypotheses] parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)]) parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)]) if args.lstm == 'parent_feed': h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells) else: inputs.append(parent_states) x = torch.cat(inputs, dim=-1) if args.lstm == 'lstm_with_dropout': self.decoder_lstm.set_dropout_masks(hyp_num) (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None) # Variable(batch_size, grammar_size) # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1)) apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) # Variable(batch_size, src_sent_len) primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0) # Variable(batch_size, primitive_vocab_size) gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) # Variable(batch_size, 2) primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob if src_unk_pos_list: primitive_prob[:, primitive_vocab.unk_id] = 1.e-10 gentoken_prev_hyp_ids = [] gentoken_new_hyp_unks = [] gentoken_copy_infos = [] applyrule_new_hyp_scores = [] applyrule_new_hyp_prod_ids = [] applyrule_prev_hyp_ids = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types(hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions(hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0] # print(type(hyp.score),type(prod_score)) new_hyp_score = hyp.score + prod_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(prod_id) applyrule_prev_hyp_ids.append(hyp_id) elif action_type == ReduceAction: action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0] new_hyp_score = hyp.score + action_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(len(self.grammar)) applyrule_prev_hyp_ids.append(hyp_id) else: # GenToken action gentoken_prev_hyp_ids.append(hyp_id) hyp_copy_info = dict() # of (token_pos, copy_prob) # first, we compute copy probabilities for tokens in the source sentence for token_pos, token_vocab_id in enumerate(src_token_vocab_ids): if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id: p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos] primitive_prob[hyp_id, 
token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy token = src_sent[token_pos] hyp_copy_info[token] = (token_pos, p_copy.data[0]) # second, add the probability of copying the most probable unk word if args.no_copy is False and src_unk_pos_list: unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax() unk_pos = src_unk_pos_list[unk_pos] token = src_sent[unk_pos] gentoken_new_hyp_unks.append(token) unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos] primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0]) gentoken_copy_infos.append(hyp_copy_info) new_hyp_scores = None if applyrule_new_hyp_scores: new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores)) if gentoken_prev_hyp_ids: primitive_log_prob = torch.log(primitive_prob) gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1) if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores]) top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) live_hyp_ids = [] new_hypotheses = [] for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()): # for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cuda(), top_new_hyp_pos.data.cuda()): action_info = ActionInfo() if new_hyp_pos < len(applyrule_new_hyp_scores): # it's an ApplyRule or Reduce action prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos] prev_hyp = hypotheses[prev_hyp_id] prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos] # ApplyRule action if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() else: # it's a GenToken action token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1) k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1) # try: copy_info = gentoken_copy_infos[k] prev_hyp_id = gentoken_prev_hyp_ids[k] prev_hyp = hypotheses[prev_hyp_id] # except: # print('k=%d' % k, file=sys.stderr) # print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr) # print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr) # print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr) # print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr) # print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr) # print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr) # print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr) # print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr) # print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr) # # torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin') # # # exit(-1) # raise ValueError() if token_id == primitive_vocab.unk_id: if gentoken_new_hyp_unks: token = gentoken_new_hyp_unks[k] else: token = primitive_vocab.id2word[primitive_vocab.unk_id] else: token = primitive_vocab.id2word[token_id] action = GenTokenAction(token) if token in copy_info: action_info.copy_from_src = True action_info.src_token_position = copy_info[token][0] action_info.action = action action_info.t = t if t > 0: action_info.parent_t 
= prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = prev_hyp.frontier_field.field new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score #advoid none new hyp if new_hyp.completed: # print('What happened!') # print(new_hyp.actions) # print(len(new_hyp.actions)) if len(new_hyp.actions) != 2: completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids] h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hypotheses = new_hypotheses hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses])) t += 1 else: break completed_hypotheses.sort(key=lambda hyp: -hyp.score) return completed_hypotheses
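# The GenToken branch of the parse procedure above mixes a generation distribution
# and a copy distribution through a two-way gate:
#   P(token) = gate[0] * P_gen(token) + gate[1] * sum_{pos: src[pos] == token} P_copy(pos).
# A minimal single-hypothesis sketch of that mixture; vocabulary ids, source
# positions, and probabilities here are illustrative only, not taken from the model.
import torch

def mix_gen_and_copy(gen_prob, copy_prob, gate, src_token_ids):
    # gen_prob:  (vocab_size,) generation distribution over the primitive vocabulary
    # copy_prob: (src_len,)    copy distribution over source positions
    # gate:      (2,)          [P(generate), P(copy)]
    # src_token_ids[pos]: vocab id of the source token at `pos` (-1 if masked/unknown)
    mixed = gate[0] * gen_prob.clone()
    for pos, tok_id in enumerate(src_token_ids):
        if tok_id >= 0:  # only unmasked, in-vocabulary source tokens receive copy mass
            mixed[tok_id] = mixed[tok_id] + gate[1] * copy_prob[pos]
    return mixed

gen_prob = torch.tensor([0.7, 0.2, 0.1])
copy_prob = torch.tensor([0.6, 0.4])
gate = torch.tensor([0.8, 0.2])
print(mix_gen_and_copy(gen_prob, copy_prob, gate, src_token_ids=[2, 1]))
# tensor([0.5600, 0.2400, 0.2000])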
def parse(self, question, context, beam_size=5): table = context args = self.args src_sent_var = nn_utils.to_input_variable([question], self.vocab.source, cuda=self.args.cuda, training=False) utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)]) dec_init_vec = self.init_decoder_state(last_state, last_cell) column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table]) h_tm1 = dec_init_vec # (batch_size, query_len, hidden_size) utterance_encodings_att_linear = self.att_src_linear(utterance_encodings) zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_()) t = 0 hypotheses = [DecodeHypothesis()] hyp_states = [[]] completed_hypotheses = [] while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size * 2) exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2)) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num, utterance_encodings_att_linear.size(1), utterance_encodings_att_linear.size(2)) # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state] if t == 0: x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True) if args.no_parent_field_type_embed is False: offset = args.action_embed_size # prev_action offset += args.hidden_size * (not args.no_input_feed) offset += args.action_embed_size * (not args.no_parent_production_embed) offset += args.field_embed_size * (not args.no_parent_field_embed) x[0, offset: offset + args.type_embed_size] = \ self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]] else: a_tm1_embeds = [] for e_id, hyp in enumerate(hypotheses): action_tm1 = hyp.actions[-1] if action_tm1: if isinstance(action_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] elif isinstance(action_tm1, ReduceAction): a_tm1_embed = self.production_embed.weight[len(self.grammar)] elif isinstance(action_tm1, WikiSqlSelectColumnAction): a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id]) elif isinstance(action_tm1, GenTokenAction): a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]] else: raise ValueError('unknown action %s' % action_tm1) else: a_tm1_embed = zero_action_embed a_tm1_embeds.append(a_tm1_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) inputs = [a_tm1_embeds] if args.no_input_feed is False: inputs.append(att_tm1) if args.no_parent_production_embed is False: # frontier production frontier_prods = [hyp.frontier_node.production for hyp in hypotheses] frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor( [self.grammar.prod2id[prod] for prod in frontier_prods]))) inputs.append(frontier_prod_embeds) if args.no_parent_field_embed is False: # frontier field frontier_fields = [hyp.frontier_field.field for hyp in hypotheses] frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([ self.grammar.field2id[field] for field in frontier_fields]))) inputs.append(frontier_field_embeds) if args.no_parent_field_type_embed is False: # frontier field type frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses] frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([ self.grammar.type2id[type] for type in frontier_field_types]))) 
inputs.append(frontier_field_type_embeds) # parent states if args.no_parent_state is False: p_ts = [hyp.frontier_node.created_time for hyp in hypotheses] parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)]) parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)]) if args.lstm == 'parent_feed': h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells) else: inputs.append(parent_states) x = torch.cat(inputs, dim=-1) (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None) # ApplyRule action probability # (batch_size, grammar_size) apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) # column attention # (batch_size, max_head_num) column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask, att_t.unsqueeze(0)).squeeze(0) column_selection_log_prob = torch.log(column_attention_weights) # (batch_size, 2) primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1) # primitive copy prob # (batch_size, src_token_num) primitive_copy_prob = self.src_pointer_net(utterance_encodings, None, att_t.unsqueeze(0)).squeeze(0) # (batch_size, primitive_vocab_size) primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) new_hyp_meta = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types(hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions(hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[hyp_id, prod_id] new_hyp_score = hyp.score + prod_score meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id, 'score': prod_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id} new_hyp_meta.append(meta_entry) elif action_type == ReduceAction: action_score = apply_rule_log_prob[hyp_id, len(self.grammar)] new_hyp_score = hyp.score + action_score meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar), 'score': action_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id} new_hyp_meta.append(meta_entry) elif action_type == WikiSqlSelectColumnAction: for col_id, column in enumerate(table.header): col_sel_score = column_selection_log_prob[hyp_id, col_id] new_hyp_score = hyp.score + col_sel_score meta_entry = {'action_type': 'sel_col', 'col_id': col_id, 'score': col_sel_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id} new_hyp_meta.append(meta_entry) elif action_type == GenTokenAction: # remember that we can only copy stuff from the input! # we only copy tokens sequentially!! 
prev_action = hyp.action_infos[-1].action valid_token_pos_list = [] if type(prev_action) is GenTokenAction and \ not prev_action.is_stop_signal(): token_pos = hyp.action_infos[-1].src_token_position + 1 if token_pos < len(question): valid_token_pos_list = [token_pos] else: valid_token_pos_list = list(range(len(question))) col_id = hyp.frontier_node['col_idx'].value if table.header[col_id].type == 'real': valid_token_pos_list = [i for i in valid_token_pos_list if any(c.isdigit() for c in question[i]) or hyp._value_buffer and question[i] in (',', '.', '-', '%')] p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id] for token_pos in valid_token_pos_list: token = question[token_pos] p_copy = p_copies[token_pos] score_copy = torch.log(p_copy) meta_entry = {'action_type': 'gen_token', 'token': token, 'token_pos': token_pos, 'score': score_copy, 'new_hyp_score': score_copy + hyp.score, 'prev_hyp_id': hyp_id} new_hyp_meta.append(meta_entry) # add generation probability for </primitive> if hyp._value_buffer: eos_prob = primitive_predictor_prob[hyp_id, 0] * \ primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']] eos_score = torch.log(eos_prob) meta_entry = {'action_type': 'gen_token', 'token': '</primitive>', 'score': eos_score, 'new_hyp_score': eos_score + hyp.score, 'prev_hyp_id': hyp_id} new_hyp_meta.append(meta_entry) if not new_hyp_meta: break new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta]) top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) live_hyp_ids = [] new_hypotheses = [] for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()): action_info = ActionInfo() hyp_meta_entry = new_hyp_meta[meta_id] prev_hyp_id = hyp_meta_entry['prev_hyp_id'] prev_hyp = hypotheses[prev_hyp_id] action_type_str = hyp_meta_entry['action_type'] if action_type_str == 'apply_rule': # ApplyRule action prod_id = hyp_meta_entry['prod_id'] if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() elif action_type_str == 'sel_col': action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id']) else: action = GenTokenAction(hyp_meta_entry['token']) if 'token_pos' in hyp_meta_entry: action_info.copy_from_src = True action_info.src_token_position = hyp_meta_entry['token_pos'] action_info.action = action action_info.t = t if t > 0: action_info.parent_t = prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = prev_hyp.frontier_field.field new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score if new_hyp.completed: completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids] h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hypotheses = new_hypotheses t += 1 else: break completed_hypotheses.sort(key=lambda hyp: -hyp.score) return completed_hypotheses
def beam_search(self, src_sents, decode_max_time_step, beam_size=5, to_word=True):
    """
    Given an un-batched source sentence, perform beam search to find the n-best hypotheses
    :param src_sents: List[List[word_id]], encoded source sentence (batch of size 1)
    :return: list[list[word_id]], top-k predicted natural language sentences in the beam
    """
    src_sents_var = nn_utils.to_input_variable(src_sents, self.src_vocab,
                                               cuda=self.cuda, training=False,
                                               append_boundary_sym=False)

    # TODO(junxian): check if src_sents_var(src_seq_length, embed_size) is ok
    src_encodings, (last_state, last_cell) = self.encode(src_sents_var, [len(src_sents[0])])
    # (1, query_len, hidden_size * 2)
    src_encodings = src_encodings.permute(1, 0, 2)
    src_encodings_att_linear = self.att_src_linear(src_encodings)
    h_tm1 = self.init_decoder_state(last_state, last_cell)

    # tensor constructors
    new_float_tensor = src_encodings.data.new
    if self.cuda:
        new_long_tensor = torch.cuda.LongTensor
    else:
        new_long_tensor = torch.LongTensor

    att_tm1 = Variable(torch.zeros(1, self.hidden_size), volatile=True)
    hyp_scores = Variable(torch.zeros(1), volatile=True)
    if self.cuda:
        att_tm1 = att_tm1.cuda()
        hyp_scores = hyp_scores.cuda()

    eos_id = self.tgt_vocab['</s>']
    bos_id = self.tgt_vocab['<s>']
    tgt_vocab_size = len(self.tgt_vocab)

    hypotheses = [[bos_id]]
    completed_hypotheses = []
    completed_hypothesis_scores = []

    t = 0
    while len(completed_hypotheses) < beam_size and t < decode_max_time_step:
        t += 1
        hyp_num = len(hypotheses)

        expanded_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1),
                                                      src_encodings.size(2))
        expanded_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                            src_encodings_att_linear.size(1),
                                                                            src_encodings_att_linear.size(2))

        y_tm1 = Variable(new_long_tensor([hyp[-1] for hyp in hypotheses]), volatile=True)
        y_tm1_embed = self.tgt_embed(y_tm1)

        x = torch.cat([y_tm1_embed, att_tm1], 1)

        # h_t: (hyp_num, hidden_size)
        (h_t, cell_t), att_t, score_t = self.step(x, h_tm1,
                                                  expanded_src_encodings,
                                                  expanded_src_encodings_att_linear,
                                                  src_sent_masks=None)

        p_t = F.log_softmax(score_t)

        live_hyp_num = beam_size - len(completed_hypotheses)
        new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) + p_t).view(-1)
        top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=live_hyp_num)
        prev_hyp_ids = top_new_hyp_pos / tgt_vocab_size
        word_ids = top_new_hyp_pos % tgt_vocab_size

        new_hypotheses = []
        live_hyp_ids = []
        new_hyp_scores = []
        for prev_hyp_id, word_id, new_hyp_score in zip(prev_hyp_ids.cpu().data,
                                                       word_ids.cpu().data,
                                                       top_new_hyp_scores.cpu().data):
            hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id]
            if word_id == eos_id:
                # remove <s> and </s> in completed hypothesis
                completed_hypotheses.append(hyp_tgt_words[1:-1])
                completed_hypothesis_scores.append(new_hyp_score)
            else:
                new_hypotheses.append(hyp_tgt_words)
                live_hyp_ids.append(prev_hyp_id)
                new_hyp_scores.append(new_hyp_score)

        if len(completed_hypotheses) == beam_size:
            break

        live_hyp_ids = new_long_tensor(live_hyp_ids)
        h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
        att_tm1 = att_t[live_hyp_ids]

        hyp_scores = Variable(new_float_tensor(new_hyp_scores), volatile=True)  # new_hyp_scores[live_hyp_ids]
        hypotheses = new_hypotheses

    if len(completed_hypotheses) == 0:
        # remove <s> and </s> in completed hypothesis
        completed_hypotheses = [hypotheses[0][1:-1]]
        completed_hypothesis_scores = [0.0]

    if to_word:
        for i, hyp in enumerate(completed_hypotheses):
            completed_hypotheses[i] = [self.tgt_vocab.id2word[w] for w in hyp]

    ranked_hypotheses = sorted(zip(completed_hypotheses, completed_hypothesis_scores),
                               key=lambda x: x[1], reverse=True)

    return [hyp for hyp, score in ranked_hypotheses]
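# `beam_search` above flattens the (hyp_num x vocab_size) score matrix before
# torch.topk, then recovers which hypothesis and which word each selected entry
# refers to by dividing and taking the remainder with respect to the vocabulary
# size. A minimal sketch of that index arithmetic; note that on recent PyTorch
# versions, integer tensors need floor division `//` rather than `/`.
import torch

hyp_num, vocab_size = 3, 5
scores = torch.randn(hyp_num, vocab_size)

flat_scores = scores.view(-1)                      # (hyp_num * vocab_size,)
top_scores, top_pos = torch.topk(flat_scores, k=4)

prev_hyp_ids = top_pos // vocab_size               # which beam entry to extend
word_ids = top_pos % vocab_size                    # which word to append

assert torch.equal(scores[prev_hyp_ids, word_ids], top_scores)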
def parse(self, src_sent, context=None, beam_size=5, debug=False): """Perform beam search to infer the target AST given a source utterance Args: src_sent: list of source utterance tokens context: other context used for prediction beam_size: beam size Returns: A list of `DecodeHypothesis`, each representing an AST """ with torch.no_grad(): args = self.args primitive_vocab = self.vocab.primitive src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False) # Variable(1, src_sent_len, hidden_size) src_encodings = self.encode(src_sent_var) zero_action_embed = torch.zeros(args.action_embed_size) hyp_scores = torch.tensor([0.0]) # For computing copy probabilities, we marginalize over tokens with the same surface form # `aggregated_primitive_tokens` stores the position of occurrence of each source token aggregated_primitive_tokens = OrderedDict() for token_pos, token in enumerate(src_sent): aggregated_primitive_tokens.setdefault(token, []).append(token_pos) t = 0 hypotheses = [DecodeHypothesis()] completed_hypotheses = [] while len(completed_hypotheses ) < beam_size and t < args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings = src_encodings.expand( hyp_num, src_encodings.size(1), src_encodings.size(2)) if t == 0: x = torch.zeros(1, self.d_model) parent_ids = np.array([[0]]) if args.no_parent_field_type_embed is False: offset = self.args.action_embed_size # prev_action offset += self.args.action_embed_size * ( not self.args.no_parent_production_embed) offset += self.args.field_embed_size * ( not self.args.no_parent_field_embed) x[0, offset:offset + self.type_embed_size] = self.type_embed( torch.tensor(self.grammar.type2id[ self.grammar.root_type])) x = x.unsqueeze(-2) else: actions_tm1 = [hyp.actions[-1] for hyp in hypotheses] a_tm1_embeds = [] for a_tm1 in actions_tm1: if a_tm1: if isinstance(a_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed( torch.tensor(self.grammar.prod2id[ a_tm1.production])) elif isinstance(a_tm1, ReduceAction): a_tm1_embed = self.production_embed( torch.tensor(len(self.grammar))) else: a_tm1_embed = self.primitive_embed( torch.tensor( self.vocab.primitive[a_tm1.token])) a_tm1_embeds.append(a_tm1_embed) else: a_tm1_embeds.append(zero_action_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) inputs = [a_tm1_embeds] if args.no_parent_production_embed is False: # frontier production frontier_prods = [ hyp.frontier_node.production for hyp in hypotheses ] frontier_prod_embeds = self.production_embed( torch.tensor([ self.grammar.prod2id[prod] for prod in frontier_prods ], dtype=torch.long)) inputs.append(frontier_prod_embeds) if args.no_parent_field_embed is False: # frontier field frontier_fields = [ hyp.frontier_field.field for hyp in hypotheses ] frontier_field_embeds = self.field_embed( torch.tensor([ self.grammar.field2id[field] for field in frontier_fields ], dtype=torch.long)) inputs.append(frontier_field_embeds) if args.no_parent_field_type_embed is False: # frontier field type frontier_field_types = [ hyp.frontier_field.type for hyp in hypotheses ] frontier_field_type_embeds = self.type_embed( torch.tensor([ self.grammar.type2id[type] for type in frontier_field_types ], dtype=torch.long)) inputs.append(frontier_field_type_embeds) x = torch.cat( [x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1) recent_parents = np.array( [[hyp.frontier_node.created_time] if hyp.frontier_node else 0 for hyp in hypotheses]) parent_ids = np.hstack([parent_ids, recent_parents]) 
src_mask = torch.ones( exp_src_encodings.shape[:-1], dtype=torch.uint8, device=exp_src_encodings.device).unsqueeze(-2) tgt_mask = subsequent_mask(x.shape[-2]) att_t = self.decoder(x, exp_src_encodings, [parent_ids], src_mask, tgt_mask)[:, -1] # Variable(batch_size, grammar_size) # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1)) apply_rule_log_prob = F.log_softmax( self.production_readout(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) if args.no_copy: primitive_prob = gen_from_vocab_prob else: # Variable(batch_size, src_sent_len) primitive_copy_prob = self.src_pointer_net( src_encodings, None, att_t.unsqueeze(0)).squeeze(0) # Variable(batch_size, 2) primitive_predictor_prob = F.softmax( self.primitive_predictor(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) primitive_prob = primitive_predictor_prob[:, 0].unsqueeze( 1) * gen_from_vocab_prob # if src_unk_pos_list: # primitive_prob[:, primitive_vocab.unk_id] = 1.e-10 gentoken_prev_hyp_ids = [] gentoken_new_hyp_unks = [] applyrule_new_hyp_scores = [] applyrule_new_hyp_prod_ids = [] applyrule_prev_hyp_ids = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types( hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions( hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[ hyp_id, prod_id].item() new_hyp_score = hyp.score + prod_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(prod_id) applyrule_prev_hyp_ids.append(hyp_id) elif action_type == ReduceAction: action_score = apply_rule_log_prob[ hyp_id, len(self.grammar)].item() new_hyp_score = hyp.score + action_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(len( self.grammar)) applyrule_prev_hyp_ids.append(hyp_id) else: # GenToken action gentoken_prev_hyp_ids.append(hyp_id) hyp_copy_info = dict() # of (token_pos, copy_prob) hyp_unk_copy_info = [] if args.no_copy is False: for (token, token_pos_list ) in aggregated_primitive_tokens.items(): sum_copy_prob = torch.gather( primitive_copy_prob[hyp_id], 0, torch.tensor(token_pos_list, dtype=torch.long)).sum() gated_copy_prob = primitive_predictor_prob[ hyp_id, 1] * sum_copy_prob if token in primitive_vocab: token_id = primitive_vocab[token] primitive_prob[hyp_id, token_id] = ( primitive_prob[hyp_id, token_id] + gated_copy_prob) hyp_copy_info[token] = ( token_pos_list, gated_copy_prob.item()) else: hyp_unk_copy_info.append({ "token": token, "token_pos_list": token_pos_list, "copy_prob": gated_copy_prob.item(), }) if args.no_copy is False and len( hyp_unk_copy_info) > 0: unk_i = np.array([ x["copy_prob"] for x in hyp_unk_copy_info ]).argmax() token = hyp_unk_copy_info[unk_i]["token"] primitive_prob[hyp_id, primitive_vocab. 
unk_id] = hyp_unk_copy_info[ unk_i]["copy_prob"] gentoken_new_hyp_unks.append(token) hyp_copy_info[token] = ( hyp_unk_copy_info[unk_i]["token_pos_list"], hyp_unk_copy_info[unk_i]["copy_prob"], ) new_hyp_scores = None if applyrule_new_hyp_scores: new_hyp_scores = torch.tensor(applyrule_new_hyp_scores) if gentoken_prev_hyp_ids: primitive_log_prob = torch.log(primitive_prob) gen_token_new_hyp_scores = ( hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1) if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores else: new_hyp_scores = torch.cat( [new_hyp_scores, gen_token_new_hyp_scores]) top_new_hyp_scores, top_new_hyp_pos = torch.topk( new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) live_hyp_ids = [] new_hypotheses = [] for new_hyp_score, new_hyp_pos in zip( top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()): action_info = ActionInfo() if new_hyp_pos < len(applyrule_new_hyp_scores): # it's an ApplyRule or Reduce action prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos] prev_hyp = hypotheses[prev_hyp_id] prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos] # ApplyRule action if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() else: # it's a GenToken action token_id = int( (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)) k = (new_hyp_pos - len(applyrule_new_hyp_scores) ) // primitive_prob.size(1) # try: # copy_info = gentoken_copy_infos[k] prev_hyp_id = gentoken_prev_hyp_ids[k] prev_hyp = hypotheses[prev_hyp_id] # except: # print('k=%d' % k, file=sys.stderr) # print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr) # print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr) # print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr) # print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr) # print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr) # print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr) # print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr) # print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr) # print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr) # # torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin') # # # exit(-1) # raise ValueError() if token_id == int(primitive_vocab.unk_id): if gentoken_new_hyp_unks: token = gentoken_new_hyp_unks[k] else: token = primitive_vocab.id2word[ primitive_vocab.unk_id] else: token = primitive_vocab.id2word[token_id] action = GenTokenAction(token) if token in aggregated_primitive_tokens: action_info.copy_from_src = True action_info.src_token_position = aggregated_primitive_tokens[ token] if debug: action_info.gen_copy_switch = ( "n/a" if args.no_copy else primitive_predictor_prob[ prev_hyp_id, :].log().cpu().data.numpy()) action_info.in_vocab = token in primitive_vocab action_info.gen_token_prob = ( gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu(). 
item() if token in primitive_vocab else "n/a") action_info.copy_token_prob = ( torch.gather( primitive_copy_prob[prev_hyp_id], 0, torch.tensor( action_info.src_token_position, dtype=torch.long, device=self.device), ).sum().log().cpu().item() if args.no_copy is False and action_info.copy_from_src else "n/a") action_info.action = action action_info.t = t if t > 0: action_info.parent_t = prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = prev_hyp.frontier_field.field if debug: action_info.action_prob = new_hyp_score - prev_hyp.score new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score if new_hyp.completed: completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: x = x[live_hyp_ids] parent_ids = parent_ids[live_hyp_ids] hypotheses = new_hypotheses hyp_scores = torch.tensor( [hyp.score for hyp in hypotheses]) t += 1 else: break completed_hypotheses.sort(key=lambda hyp: -hyp.score) return completed_hypotheses
def parse(self, src_sent, context=None, beam_size=5): """Perform beam search to infer the target AST given a source utterance Args: src_sent: list of source utterance tokens context: other context used for prediction beam_size: beam size Returns: A list of `DecodeHypothesis`, each representing an AST """ args = self.args primitive_vocab = self.vocab.primitive src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False) # Variable(1, src_sent_len, hidden_size * 2) src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)]) # (1, src_sent_len, hidden_size) src_encodings_att_linear = self.att_src_linear(src_encodings) dec_init_vec = self.init_decoder_state(last_state, last_cell) if args.lstm == 'parent_feed': h_tm1 = dec_init_vec[0], dec_init_vec[1], \ Variable(self.new_tensor(args.hidden_size).zero_()), \ Variable(self.new_tensor(args.hidden_size).zero_()) else: h_tm1 = dec_init_vec zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_()) hyp_scores = Variable(self.new_tensor([0.]), volatile=True) src_token_vocab_ids = [primitive_vocab[token] for token in src_sent] src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id] # sometimes a word may appear multi-times in the source, in this case, # we just copy its first appearing position. Therefore we mask the words # appearing second and onwards to -1 token_set = set() for i, tid in enumerate(src_token_vocab_ids): if tid in token_set: src_token_vocab_ids[i] = -1 else: token_set.add(tid) t = 0 hypotheses = [DecodeHypothesis()] hyp_states = [[]] completed_hypotheses = [] while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size * 2) exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2)) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) if t == 0: x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True) if args.no_parent_field_type_embed is False: offset = args.action_embed_size # prev_action offset += args.att_vec_size * (not args.no_input_feed) offset += args.action_embed_size * (not args.no_parent_production_embed) offset += args.field_embed_size * (not args.no_parent_field_embed) x[0, offset: offset + args.type_embed_size] = \ self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]] else: actions_tm1 = [hyp.actions[-1] for hyp in hypotheses] a_tm1_embeds = [] for a_tm1 in actions_tm1: if a_tm1: if isinstance(a_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]] elif isinstance(a_tm1, ReduceAction): a_tm1_embed = self.production_embed.weight[len(self.grammar)] else: a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]] a_tm1_embeds.append(a_tm1_embed) else: a_tm1_embeds.append(zero_action_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) inputs = [a_tm1_embeds] if args.no_input_feed is False: inputs.append(att_tm1) if args.no_parent_production_embed is False: # frontier production frontier_prods = [hyp.frontier_node.production for hyp in hypotheses] frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor( [self.grammar.prod2id[prod] for prod in frontier_prods]))) inputs.append(frontier_prod_embeds) if 
args.no_parent_field_embed is False: # frontier field frontier_fields = [hyp.frontier_field.field for hyp in hypotheses] frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([ self.grammar.field2id[field] for field in frontier_fields]))) inputs.append(frontier_field_embeds) if args.no_parent_field_type_embed is False: # frontier field type frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses] frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([ self.grammar.type2id[type] for type in frontier_field_types]))) inputs.append(frontier_field_type_embeds) # parent states if args.no_parent_state is False: p_ts = [hyp.frontier_node.created_time for hyp in hypotheses] parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)]) parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)]) if args.lstm == 'parent_feed': h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells) else: inputs.append(parent_states) x = torch.cat(inputs, dim=-1) if args.lstm == 'lstm_with_dropout': self.decoder_lstm.set_dropout_masks(hyp_num) (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None) # Variable(batch_size, grammar_size) # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1)) apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) # Variable(batch_size, src_sent_len) primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0) # Variable(batch_size, primitive_vocab_size) gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) # Variable(batch_size, 2) primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob if src_unk_pos_list: primitive_prob[:, primitive_vocab.unk_id] = 1.e-10 gentoken_prev_hyp_ids = [] gentoken_new_hyp_unks = [] gentoken_copy_infos = [] applyrule_new_hyp_scores = [] applyrule_new_hyp_prod_ids = [] applyrule_prev_hyp_ids = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types(hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions(hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0] new_hyp_score = hyp.score + prod_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(prod_id) applyrule_prev_hyp_ids.append(hyp_id) elif action_type == ReduceAction: action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0] new_hyp_score = hyp.score + action_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(len(self.grammar)) applyrule_prev_hyp_ids.append(hyp_id) else: # GenToken action gentoken_prev_hyp_ids.append(hyp_id) hyp_copy_info = dict() # of (token_pos, copy_prob) # first, we compute copy probabilities for tokens in the source sentence for token_pos, token_vocab_id in enumerate(src_token_vocab_ids): if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id: p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos] primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, 
token_vocab_id] + p_copy token = src_sent[token_pos] hyp_copy_info[token] = (token_pos, p_copy.data[0]) # second, add the probability of copying the most probable unk word if args.no_copy is False and src_unk_pos_list: unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax() unk_pos = src_unk_pos_list[unk_pos] token = src_sent[unk_pos] gentoken_new_hyp_unks.append(token) unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos] primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0]) gentoken_copy_infos.append(hyp_copy_info) new_hyp_scores = None if applyrule_new_hyp_scores: new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores)) if gentoken_prev_hyp_ids: primitive_log_prob = torch.log(primitive_prob) gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1) if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores]) top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) live_hyp_ids = [] new_hypotheses = [] for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()): action_info = ActionInfo() if new_hyp_pos < len(applyrule_new_hyp_scores): # it's an ApplyRule or Reduce action prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos] prev_hyp = hypotheses[prev_hyp_id] prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos] # ApplyRule action if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() else: # it's a GenToken action token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1) k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1) # try: copy_info = gentoken_copy_infos[k] prev_hyp_id = gentoken_prev_hyp_ids[k] prev_hyp = hypotheses[prev_hyp_id] # except: # print('k=%d' % k, file=sys.stderr) # print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr) # print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr) # print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr) # print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr) # print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr) # print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr) # print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr) # print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr) # print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr) # # torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin') # # # exit(-1) # raise ValueError() if token_id == primitive_vocab.unk_id: if gentoken_new_hyp_unks: token = gentoken_new_hyp_unks[k] else: token = primitive_vocab.id2word[primitive_vocab.unk_id] else: token = primitive_vocab.id2word[token_id] action = GenTokenAction(token) if token in copy_info: action_info.copy_from_src = True action_info.src_token_position = copy_info[token][0] action_info.action = action action_info.t = t if t > 0: action_info.parent_t = prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = 
prev_hyp.frontier_field.field new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score if new_hyp.completed: completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids] h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hypotheses = new_hypotheses hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses])) t += 1 else: break completed_hypotheses.sort(key=lambda hyp: -hyp.score) return completed_hypotheses
def train(args):
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    vocab = pickle.load(open(args.vocab, 'rb'))

    model = LSTMLanguageModel(vocab.source, args.embed_size, args.hidden_size, dropout=args.dropout)
    model.train()
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def evaluate_ppl():
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            src_sents_var = nn_utils.to_input_variable([e.src_sent for e in batch], vocab.source,
                                                       cuda=args.cuda, append_boundary_sym=True)
            loss = model(src_sents_var).sum()
            cum_loss += loss.data[0]
            cum_tgt_words += sum(len(e.src_sent) + 1 for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' % (len(train_set), len(dev_set)))
    print('vocab size: %d' % len(vocab.source))

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= 100]
            src_sents = [e.src_sent for e in batch_examples]
            src_sents_var = nn_utils.to_input_variable(src_sents, vocab.source,
                                                       cuda=args.cuda, append_boundary_sym=True)

            train_iter += 1
            optimizer.zero_grad()

            loss = model(src_sents_var)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples),
                      file=sys.stderr)
                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' % (epoch, ppl, time.time() - eval_start), file=sys.stderr)
        dev_acc = -ppl
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)

            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

                # load model
                params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    optimizer = torch.optim.Adam(model.inference_model.parameters(), lr=lr)
                else:
                    print('restore parameters of the optimizers', file=sys.stderr)
                    optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
def parse_sample(self, src_sent): # get samples of q_{\phi}(z|x) by sampling args = self.args sample_size = self.sample_size primitive_vocab = self.vocab.primitive src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False) # Variable(1, src_sent_len, hidden_size * 2) src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)]) # (1, src_sent_len, hidden_size) src_encodings_att_linear = self.att_src_linear(src_encodings) h_tm1 = self.init_decoder_state(last_state, last_cell) zero_action_embed = Variable( self.new_tensor(args.action_embed_size).zero_()) hyp_scores = Variable(self.new_tensor([0.]), volatile=True) src_token_vocab_ids = [primitive_vocab[token] for token in src_sent] src_unk_pos_list = [ pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id ] # sometimes a word may appear multi-times in the source, in this case, # we just copy its first appearing position. Therefore we mask the words # appearing second and onwards to -1 token_set = set() for i, tid in enumerate(src_token_vocab_ids): if tid in token_set: src_token_vocab_ids[i] = -1 else: token_set.add(tid) completed_hypotheses = [] while len(completed_hypotheses) < sample_size: t = 0 hypotheses = [DecodeHypothesis()] hyp_states = [[]] while t < args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size * 2) exp_src_encodings = src_encodings.expand( hyp_num, src_encodings.size(1), src_encodings.size(2)) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings_att_linear = src_encodings_att_linear.expand( hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) if t == 0: x = Variable(self.new_tensor( 1, self.decoder_lstm.input_size).zero_(), volatile=True) offset = args.action_embed_size * 2 + args.field_embed_size x[0, offset:offset + args.type_embed_size] = self.type_embed.weight[ self.grammar.type2id[self.grammar.root_type]] else: actions_tm1 = [hyp.actions[-1] for hyp in hypotheses] a_tm1_embeds = [] for a_tm1 in actions_tm1: if a_tm1: if isinstance(a_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed.weight[ self.grammar.prod2id[a_tm1.production]] elif isinstance(a_tm1, ReduceAction): a_tm1_embed = self.production_embed.weight[len( self.grammar)] else: a_tm1_embed = self.primitive_embed.weight[ self.vocab.primitive[a_tm1.token]] a_tm1_embeds.append(a_tm1_embed) else: a_tm1_embeds.append(zero_action_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) # frontier production frontier_prods = [ hyp.frontier_node.production for hyp in hypotheses ] frontier_prod_embeds = self.production_embed( Variable( self.new_long_tensor([ self.grammar.prod2id[prod] for prod in frontier_prods ]))) # frontier field frontier_fields = [ hyp.frontier_field.field for hyp in hypotheses ] frontier_field_embeds = self.field_embed( Variable( self.new_long_tensor([ self.grammar.field2id[field] for field in frontier_fields ]))) # frontier field type frontier_field_types = [ hyp.frontier_field.type for hyp in hypotheses ] frontier_field_type_embeds = self.type_embed( Variable( self.new_long_tensor([ self.grammar.type2id[type] for type in frontier_field_types ]))) # parent states p_ts = [ hyp.frontier_node.created_time for hyp in hypotheses ] hist_states = torch.stack([ hyp_states[hyp_id][p_t] for hyp_id, p_t in enumerate(p_ts) ]) x = torch.cat([ a_tm1_embeds, att_tm1, frontier_prod_embeds, frontier_field_embeds, frontier_field_type_embeds, hist_states ], dim=-1) if args.lstm == 
'lstm_with_dropout': self.decoder_lstm.set_dropout_masks(hyp_num) (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None) # Variable(batch_size, grammar_size) apply_rule_log_prob = F.log_softmax( self.production_readout(att_t), dim=-1) # Variable(batch_size, src_sent_len) primitive_copy_prob = self.src_pointer_net( src_encodings, None, att_t.unsqueeze(0)).squeeze(0) # Variable(batch_size, primitive_vocab_size) gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) # Variable(batch_size, 2) primitive_predictor_prob = F.softmax( self.primitive_predictor(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) primitive_prob = primitive_predictor_prob[:, 0].unsqueeze( 1) * gen_from_vocab_prob if src_unk_pos_list: primitive_prob[:, primitive_vocab.unk_id] = 1.e-10 gentoken_prev_hyp_ids = [] gentoken_new_hyp_unks = [] gentoken_copy_infos = [] applyrule_new_hyp_scores = [] applyrule_new_hyp_prod_ids = [] applyrule_prev_hyp_ids = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types( hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions( hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[ hyp_id, prod_id].data[0] new_hyp_score = hyp.score + prod_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(prod_id) applyrule_prev_hyp_ids.append(hyp_id) elif action_type == ReduceAction: action_score = apply_rule_log_prob[ hyp_id, len(self.grammar)].data[0] new_hyp_score = hyp.score + action_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(len( self.grammar)) applyrule_prev_hyp_ids.append(hyp_id) else: # GenToken action gentoken_prev_hyp_ids.append(hyp_id) hyp_copy_info = dict() # of (token_pos, copy_prob) # first, we compute copy probabilities for tokens in the source sentence for token_pos, token_vocab_id in enumerate( src_token_vocab_ids): if token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id: p_copy = primitive_predictor_prob[ hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos] primitive_prob[ hyp_id, token_vocab_id] = primitive_prob[ hyp_id, token_vocab_id] + p_copy token = src_sent[token_pos] hyp_copy_info[token] = (token_pos, p_copy.data[0]) # second, add the probability of copying the most probable unk word if src_unk_pos_list: unk_pos = primitive_copy_prob[ hyp_id][src_unk_pos_list].data.cpu().numpy( ).argmax() unk_pos = src_unk_pos_list[unk_pos] token = src_sent[unk_pos] gentoken_new_hyp_unks.append(token) unk_copy_score = primitive_predictor_prob[ hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos] primitive_prob[ hyp_id, primitive_vocab.unk_id] = unk_copy_score hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0]) gentoken_copy_infos.append(hyp_copy_info) new_hyp_scores = None if applyrule_new_hyp_scores: new_hyp_scores = Variable( self.new_tensor(applyrule_new_hyp_scores)) if gentoken_prev_hyp_ids: primitive_log_prob = torch.log(primitive_prob) gen_token_new_hyp_scores = ( hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1) if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores else: new_hyp_scores = torch.cat( [new_hyp_scores, gen_token_new_hyp_scores]) # top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, # 
k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) pro = F.softmax(new_hyp_scores, 0) new_hyp_pos = torch.multinomial(pro, 1).data[0] new_hyp_score = new_hyp_scores[new_hyp_pos].data[0] live_hyp_ids = [] new_hypotheses = [] action_info = ActionInfo() if new_hyp_pos < len(applyrule_new_hyp_scores): # it's an ApplyRule or Reduce action prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos] prev_hyp = hypotheses[prev_hyp_id] prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos] # ApplyRule action if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() else: # it's a GenToken action token_id = (new_hyp_pos - len(applyrule_new_hyp_scores) ) % primitive_prob.size(1) k = (new_hyp_pos - len(applyrule_new_hyp_scores) ) // primitive_prob.size(1) # try: copy_info = gentoken_copy_infos[k] prev_hyp_id = gentoken_prev_hyp_ids[k] prev_hyp = hypotheses[prev_hyp_id] if token_id == primitive_vocab.unk_id: if gentoken_new_hyp_unks: token = gentoken_new_hyp_unks[k] else: token = primitive_vocab.id2word[ primitive_vocab.unk_id] else: token = primitive_vocab.id2word[token_id] action = GenTokenAction(token) if token in copy_info: action_info.copy_from_src = True action_info.src_token_position = copy_info[token][0] action_info.action = action action_info.t = t if t > 0: action_info.parent_t = prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = prev_hyp.frontier_field.field new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score if new_hyp.completed: completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: hyp_states = [ hyp_states[i] + [h_t[i]] for i in live_hyp_ids ] h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hypotheses = new_hypotheses hyp_scores = Variable( self.new_tensor([hyp.score for hyp in hypotheses])) t += 1 else: break return completed_hypotheses
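# --- Illustrative sketch (assumption, not original code): the difference between the
# beam-search selection that the commented-out torch.topk line performed and the
# ancestral-sampling selection used in parse_sample above. `scores` is a toy stand-in
# for new_hyp_scores (log-scores of all candidate continuations across hypotheses).
import torch
import torch.nn.functional as F

scores = torch.tensor([-0.1, -2.3, -0.7, -4.0])   # log-scores of 4 candidate continuations

# beam search: keep the k highest-scoring continuations deterministically
topk_scores, topk_pos = torch.topk(scores, k=2)

# sampling (as above): renormalize with a softmax and draw a single continuation
probs = F.softmax(scores, dim=0)
sampled_pos = torch.multinomial(probs, num_samples=1).item()

print(topk_pos.tolist(), sampled_pos)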
def sample(self, src_sent, sample_size, decode_max_time_step, cuda=False, mode='sample'): new_float_tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor new_long_tensor = torch.cuda.LongTensor if cuda else torch.LongTensor src_sent_var = nn_utils.to_input_variable([src_sent], self.src_vocab, cuda=cuda, training=False) # analyze which tokens can be copied from the source src_token_tgt_vocab_ids = [self.tgt_vocab[token] for token in src_sent] src_unk_pos_list = [ pos for pos, token_id in enumerate(src_token_tgt_vocab_ids) if token_id == self.tgt_vocab.unk_id ] # sometimes a word may appear multi-times in the source, in this case, # we just copy its first appearing position. Therefore we mask the words # appearing second and onwards to -1 token_set = set() for i, tid in enumerate(src_token_tgt_vocab_ids): if tid in token_set: src_token_tgt_vocab_ids[i] = -1 else: token_set.add(tid) src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)]) h_tm1 = self.init_decoder_state(last_state, last_cell) # (batch_size, 1, hidden_size) src_encodings_att_linear = self.att_src_linear(src_encodings) t = 0 eos_id = self.tgt_vocab['</s>'] completed_hypotheses = [] completed_hypothesis_scores = [] if mode == 'beam_search': hypotheses = [['<s>']] hypotheses_word_ids = [[self.tgt_vocab['<s>']]] else: hypotheses = [['<s>'] for _ in xrange(sample_size)] hypotheses_word_ids = [[self.tgt_vocab['<s>']] for _ in xrange(sample_size)] att_tm1 = Variable(new_float_tensor(len(hypotheses), self.hidden_size).zero_(), volatile=True) hyp_scores = Variable(new_float_tensor(len(hypotheses)).zero_(), volatile=True) while len(completed_hypotheses ) < sample_size and t < decode_max_time_step: t += 1 hyp_num = len(hypotheses) expanded_src_encodings = src_encodings.expand( hyp_num, src_encodings.size(1), src_encodings.size(2)) expanded_src_encodings_att_linear = src_encodings_att_linear.expand( hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) y_tm1 = Variable(new_long_tensor( [hyp[-1] for hyp in hypotheses_word_ids]), volatile=True) y_tm1_embed = self.tgt_embed(y_tm1) x = torch.cat([y_tm1_embed, att_tm1], 1) (h_t, cell_t), att_t = self.step(x, h_tm1, expanded_src_encodings, expanded_src_encodings_att_linear) # (batch_size, 2) tgt_token_predictor = F.softmax(self.tgt_token_predictor(att_t), dim=-1) # (batch_size, tgt_vocab_size) token_gen_prob = F.softmax(self.readout(att_t), dim=-1) # (batch_size, src_sent_len) token_copy_prob = self.src_pointer_net( src_encodings, src_token_mask=None, query_vec=att_t.unsqueeze(0)).squeeze(0) # (batch_size, tgt_vocab_size) token_gen_prob = tgt_token_predictor[:, 0].unsqueeze( 1) * token_gen_prob for token_pos, token_vocab_id in enumerate( src_token_tgt_vocab_ids): if token_vocab_id != -1 and token_vocab_id != self.tgt_vocab.unk_id: p_copy = tgt_token_predictor[:, 1] * token_copy_prob[:, token_pos] token_gen_prob[:, token_vocab_id] = token_gen_prob[:, token_vocab_id] + p_copy # second, add the probability of copying the most probable unk word gentoken_new_hyp_unks = [] if src_unk_pos_list: for hyp_id in xrange(hyp_num): unk_pos = token_copy_prob[hyp_id][ src_unk_pos_list].data.cpu().numpy().argmax() unk_pos = src_unk_pos_list[unk_pos] token = src_sent[unk_pos] gentoken_new_hyp_unks.append(token) unk_copy_score = tgt_token_predictor[ hyp_id, 1] * token_copy_prob[hyp_id, unk_pos] token_gen_prob[hyp_id, self.tgt_vocab.unk_id] = unk_copy_score live_hyp_num = sample_size - len(completed_hypotheses) if mode == 'beam_search': log_token_gen_prob 
= torch.log(token_gen_prob) new_hyp_scores = ( hyp_scores.unsqueeze(1).expand_as(token_gen_prob) + log_token_gen_prob).view(-1) top_new_hyp_scores, top_new_hyp_pos = torch.topk( new_hyp_scores, k=live_hyp_num) prev_hyp_ids = (top_new_hyp_pos / len(self.tgt_vocab)).cpu().data word_ids = (top_new_hyp_pos % len(self.tgt_vocab)).cpu().data top_new_hyp_scores = top_new_hyp_scores.cpu().data else: word_ids = torch.multinomial(token_gen_prob, num_samples=1) prev_hyp_ids = range(live_hyp_num) top_new_hyp_scores = hyp_scores + torch.log( torch.gather(token_gen_prob, dim=1, index=word_ids)).squeeze(1) top_new_hyp_scores = top_new_hyp_scores.cpu().data word_ids = word_ids.view(-1).cpu().data new_hypotheses = [] new_hypotheses_word_ids = [] live_hyp_ids = [] new_hyp_scores = [] for prev_hyp_id, word_id, new_hyp_score in zip( prev_hyp_ids, word_ids, top_new_hyp_scores): if word_id == eos_id: hyp_tgt_words = hypotheses[prev_hyp_id][1:] completed_hypotheses.append( hyp_tgt_words ) # remove <s> and </s> in completed hypothesis completed_hypothesis_scores.append(new_hyp_score) else: if word_id == self.tgt_vocab.unk_id: if gentoken_new_hyp_unks: word = gentoken_new_hyp_unks[prev_hyp_id] else: word = self.tgt_vocab.id2word[ self.tgt_vocab.unk_id] else: word = self.tgt_vocab.id2word[word_id] hyp_tgt_words = hypotheses[prev_hyp_id] + [word] new_hypotheses.append(hyp_tgt_words) new_hypotheses_word_ids.append( hypotheses_word_ids[prev_hyp_id] + [word_id]) live_hyp_ids.append(prev_hyp_id) new_hyp_scores.append(new_hyp_score) if len(completed_hypotheses) == sample_size: break live_hyp_ids = new_long_tensor(live_hyp_ids) h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hyp_scores = Variable( new_float_tensor(new_hyp_scores), volatile=True) # new_hyp_scores[live_hyp_ids] hypotheses = new_hypotheses hypotheses_word_ids = new_hypotheses_word_ids return completed_hypotheses
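# --- Illustrative sketch (assumption): how the generate/copy gate above mixes the
# vocabulary softmax with the pointer distribution for one hypothesis. The tensors
# below are toy stand-ins for tgt_token_predictor, token_gen_prob and token_copy_prob.
import torch

vocab = {'<unk>': 0, 'the': 1, 'file': 2, 'name': 3}
src_sent = ['file', 'name']
src_vocab_ids = [vocab['file'], vocab['name']]

gate = torch.tensor([0.6, 0.4])                 # [p(generate), p(copy)]
gen_prob = torch.tensor([0.1, 0.3, 0.4, 0.2])   # softmax over the target vocabulary
copy_prob = torch.tensor([0.7, 0.3])            # pointer softmax over source positions

token_prob = gate[0] * gen_prob                 # generation part
for pos, vid in enumerate(src_vocab_ids):       # add copy mass to in-vocab source tokens
    token_prob[vid] = token_prob[vid] + gate[1] * copy_prob[pos]

print(token_prob)   # still sums to 1: 0.6 * 1.0 + 0.4 * (0.7 + 0.3)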
def beam_search(self, src_sents, decode_max_time_step, beam_size=5, to_word=True): """ given a not-batched source, sentence perform beam search to find the n-best :param src_sent: List[word_id], encoded source sentence :return: list[list[word_id]] top-k predicted natural language sentence in the beam """ src_sents_var = nn_utils.to_input_variable(src_sents, self.src_vocab, cuda=self.cuda, training=False, append_boundary_sym=False) #TODO(junxian): check if src_sents_var(src_seq_length, embed_size) is ok src_encodings, (last_state, last_cell) = self.encode(src_sents_var, [len(src_sents[0])]) # (1, query_len, hidden_size * 2) src_encodings = src_encodings.permute(1, 0, 2) src_encodings_att_linear = self.att_src_linear(src_encodings) h_tm1 = self.init_decoder_state(last_state, last_cell) # tensor constructors new_float_tensor = src_encodings.data.new if self.cuda: new_long_tensor = torch.cuda.LongTensor else: new_long_tensor = torch.LongTensor att_tm1 = Variable(torch.zeros(1, self.hidden_size), volatile=True) hyp_scores = Variable(torch.zeros(1), volatile=True) if self.cuda: att_tm1 = att_tm1.cuda() hyp_scores = hyp_scores.cuda() eos_id = self.tgt_vocab['</s>'] bos_id = self.tgt_vocab['<s>'] tgt_vocab_size = len(self.tgt_vocab) hypotheses = [[bos_id]] completed_hypotheses = [] completed_hypothesis_scores = [] t = 0 while len(completed_hypotheses) < beam_size and t < decode_max_time_step: t += 1 hyp_num = len(hypotheses) expanded_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2)) expanded_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) y_tm1 = Variable(new_long_tensor([hyp[-1] for hyp in hypotheses]), volatile=True) y_tm1_embed = self.tgt_embed(y_tm1) x = torch.cat([y_tm1_embed, att_tm1], 1) # h_t: (hyp_num, hidden_size) (h_t, cell_t), att_t, score_t = self.step(x, h_tm1, expanded_src_encodings, expanded_src_encodings_att_linear, src_sent_masks=None) p_t = F.log_softmax(score_t) live_hyp_num = beam_size - len(completed_hypotheses) new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) + p_t).view(-1) top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=live_hyp_num) prev_hyp_ids = top_new_hyp_pos / tgt_vocab_size word_ids = top_new_hyp_pos % tgt_vocab_size new_hypotheses = [] live_hyp_ids = [] new_hyp_scores = [] for prev_hyp_id, word_id, new_hyp_score in zip(prev_hyp_ids.cpu().data, word_ids.cpu().data, top_new_hyp_scores.cpu().data): hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id] if word_id == eos_id: completed_hypotheses.append(hyp_tgt_words[1:-1]) # remove <s> and </s> in completed hypothesis completed_hypothesis_scores.append(new_hyp_score) else: new_hypotheses.append(hyp_tgt_words) live_hyp_ids.append(prev_hyp_id) new_hyp_scores.append(new_hyp_score) if len(completed_hypotheses) == beam_size: break live_hyp_ids = new_long_tensor(live_hyp_ids) h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hyp_scores = Variable(new_float_tensor(new_hyp_scores), volatile=True) # new_hyp_scores[live_hyp_ids] hypotheses = new_hypotheses if len(completed_hypotheses) == 0: completed_hypotheses = [hypotheses[0][1:-1]] # remove <s> and </s> in completed hypothesis completed_hypothesis_scores = [0.0] if to_word: for i, hyp in enumerate(completed_hypotheses): completed_hypotheses[i] = [self.tgt_vocab.id2word[w] for w in hyp] ranked_hypotheses = sorted(zip(completed_hypotheses, completed_hypothesis_scores), key=lambda x: x[1], 
reverse=True)
        return [hyp for hyp, score in ranked_hypotheses]
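# --- Illustrative sketch (assumption): how beam_search above unflattens the top-k
# positions of the (hyp_num x vocab_size) score matrix back into a previous-hypothesis
# index and a word index. Note that `top_new_hyp_pos / tgt_vocab_size` in the original
# relies on old-style integer tensor division; with a recent PyTorch, floor division
# (torch.div(..., rounding_mode='floor') or `//`) is the safe spelling.
import torch

hyp_num, vocab_size, beam = 2, 5, 3
scores = torch.randn(hyp_num, vocab_size)
flat = scores.view(-1)                          # flattened (hyp * vocab) scores

top_scores, top_pos = torch.topk(flat, k=beam)
prev_hyp_ids = torch.div(top_pos, vocab_size, rounding_mode='floor')
word_ids = top_pos % vocab_size

assert torch.equal(scores[prev_hyp_ids, word_ids], top_scores)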
def src_sents_var(self):
    return nn_utils.to_input_variable(self.src_sents, self.vocab.source, cuda=self.cuda)
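# --- Minimal sketch (assumption, not the actual nn_utils implementation): a helper
# like to_input_variable typically maps tokens to vocabulary ids, pads the batch to
# the longest sentence, and returns a (max_sent_len, batch_size) LongTensor. The
# function and vocabulary below are hypothetical stand-ins for illustration only.
import torch

def to_padded_id_tensor(sents, word2id, pad_id=0):
    max_len = max(len(s) for s in sents)
    ids = [[word2id.get(w, word2id['<unk>']) for w in s] + [pad_id] * (max_len - len(s))
           for s in sents]
    return torch.tensor(ids, dtype=torch.long).t()   # (max_sent_len, batch_size)

word2id = {'<pad>': 0, '<unk>': 1, 'sort': 2, 'the': 3, 'list': 4}
print(to_padded_id_tensor([['sort', 'the', 'list'], ['sort', 'list']], word2id))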
def train_lstm_lm(args): all_data = load_code_dir(args.code_dir) np.random.shuffle(all_data) train_data = all_data[:-1000] dev_data = all_data[-1000:] print('train data size: %d, dev data size: %d' % (len(train_data), len(dev_data)), file=sys.stderr) vocab = VocabEntry.from_corpus([e['tokens'] for e in train_data], size=args.vocab_size, freq_cutoff=args.freq_cutoff) print('vocab size: %d' % len(vocab), file=sys.stderr) model = LSTMPrior(args, vocab) model.train() if args.cuda: model.cuda() optimizer = torch.optim.Adam(model.parameters()) def evaluate_ppl(): model.eval() cum_loss = 0. cum_tgt_words = 0. for examples in nn_utils.batch_iter(dev_data, args.batch_size): batch_tokens = [e['tokens'] for e in examples] batch = nn_utils.to_input_variable(batch_tokens, vocab, cuda=args.cuda, append_boundary_sym=True) loss = model.forward(batch).sum() cum_loss += loss.data[0] cum_tgt_words += sum(len(tokens) + 1 for tokens in batch_tokens) # add ending </s> ppl = np.exp(cum_loss / cum_tgt_words) model.train() return ppl print('begin training decoder, %d training examples, %d dev examples' % (len(train_data), len(dev_data)), file=sys.stderr) epoch = num_trial = train_iter = patience = 0 report_loss = report_examples = 0. history_dev_scores = [] while True: epoch += 1 epoch_begin = time.time() for examples in nn_utils.batch_iter(train_data, batch_size=args.batch_size, shuffle=True): train_iter += 1 optimizer.zero_grad() batch_tokens = [e['tokens'] for e in examples] batch = nn_utils.to_input_variable(batch_tokens, vocab, cuda=args.cuda, append_boundary_sym=True) loss = model.forward(batch) # print(loss.data) loss_val = torch.sum(loss).data[0] report_loss += loss_val report_examples += len(examples) loss = torch.mean(loss) loss.backward() # clip gradient grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) optimizer.step() if train_iter % args.log_every == 0: print('[Iter %d] encoder loss=%.5f' % (train_iter, report_loss / report_examples), file=sys.stderr) report_loss = report_examples = 0. 
print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr) # model_file = args.save_to + '.iter%d.bin' % train_iter # print('save model to [%s]' % model_file, file=sys.stderr) # model.save(model_file) # perform validation print('[Epoch %d] begin validation' % epoch, file=sys.stderr) eval_start = time.time() # evaluate ppl ppl = evaluate_ppl() print('[Epoch %d] ppl=%.5f took %ds' % (epoch, ppl, time.time() - eval_start), file=sys.stderr) dev_acc = -ppl is_better = history_dev_scores == [] or dev_acc > max( history_dev_scores) history_dev_scores.append(dev_acc) if is_better: patience = 0 model_file = args.save_to + '.bin' print('save currently the best model ..', file=sys.stderr) print('save model to [%s]' % model_file, file=sys.stderr) model.save(model_file) # also save the optimizers' state torch.save(optimizer.state_dict(), args.save_to + '.optim.bin') elif patience < args.patience: patience += 1 print('hit patience %d' % patience, file=sys.stderr) if patience == args.patience: num_trial += 1 print('hit #%d trial' % num_trial, file=sys.stderr) if num_trial == args.max_num_trial: print('early stop!', file=sys.stderr) exit(0) # decay lr, and restore from previously best checkpoint lr = optimizer.param_groups[0]['lr'] * args.lr_decay print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr) # load model params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage) model.load_state_dict(params['state_dict']) if args.cuda: model = model.cuda() # load optimizers if args.reset_optimizer: print('reset optimizer', file=sys.stderr) optimizer = torch.optim.Adam( model.inference_model.parameters(), lr=lr) else: print('restore parameters of the optimizers', file=sys.stderr) optimizer.load_state_dict( torch.load(args.save_to + '.optim.bin')) # set new lr for param_group in optimizer.param_groups: param_group['lr'] = lr # reset patience patience = 0
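# --- Illustrative sketch (assumption): the corpus perplexity computed by evaluate_ppl
# inside train_lstm_lm above, written out for a toy batch of per-example losses. Each
# example contributes len(tokens) + 1 target words because of the appended </s>.
import numpy as np

example_nlls = [12.4, 7.9, 20.1]        # summed negative log-likelihood per dev example
example_lens = [5, 3, 9]                # token counts (without </s>)

cum_loss = sum(example_nlls)
cum_tgt_words = sum(n + 1 for n in example_lens)
ppl = np.exp(cum_loss / cum_tgt_words)
print(ppl)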
def parse(self, src_sent, context=None, beam_size=5, debug=False): """Perform beam search to infer the target AST given a source utterance Args: src_sent: list of source utterance tokens context: other context used for prediction beam_size: beam size Returns: A list of `DecodeHypothesis`, each representing an AST """ args = self.args primitive_vocab = self.vocab.primitive T = torch.cuda if args.cuda else torch src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False) # Variable(1, src_sent_len, hidden_size * 2) ##########src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)]) src_encodings, last_state = self.encode(src_sent_var, [len(src_sent)]) # (1, src_sent_len, hidden_size) src_encodings_att_linear = self.att_src_linear(src_encodings) #######dec_init_vec = self.init_decoder_state(last_state, last_cell) dec_init_vec = self.init_decoder_state(last_state) ##### if args.lstm == 'parent_feed': h_tm1 = dec_init_vec[0], dec_init_vec[1], \ Variable(self.new_tensor(args.hidden_size).zero_()), \ Variable(self.new_tensor(args.hidden_size).zero_()) else: h_tm1 = dec_init_vec zero_action_embed = Variable( self.new_tensor(args.action_embed_size).zero_()) with torch.no_grad(): hyp_scores = Variable(self.new_tensor([0.])) # For computing copy probabilities, we marginalize over tokens with the same surface form # `aggregated_primitive_tokens` stores the position of occurrence of each source token aggregated_primitive_tokens = OrderedDict() for token_pos, token in enumerate(src_sent): aggregated_primitive_tokens.setdefault(token, []).append(token_pos) t = 0 hypotheses = [DecodeHypothesis()] hyp_states = [[]] completed_hypotheses = [] while len(completed_hypotheses ) < beam_size and t < args.decode_max_time_step: hyp_num = len(hypotheses) # (hyp_num, src_sent_len, hidden_size * 2) exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2)) # (hyp_num, src_sent_len, hidden_size) exp_src_encodings_att_linear = src_encodings_att_linear.expand( hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2)) if t == 0: with torch.no_grad(): x = Variable( self.new_tensor(1, self.decoder_lstm.input_size).zero_()) if args.no_parent_field_type_embed is False: offset = args.action_embed_size # prev_action offset += args.att_vec_size * (not args.no_input_feed) offset += args.action_embed_size * ( not args.no_parent_production_embed) offset += args.field_embed_size * ( not args.no_parent_field_embed) x[0, offset: offset + args.type_embed_size] = \ self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]] else: actions_tm1 = [hyp.actions[-1] for hyp in hypotheses] a_tm1_embeds = [] for a_tm1 in actions_tm1: if a_tm1: if isinstance(a_tm1, ApplyRuleAction): a_tm1_embed = self.production_embed.weight[ self.grammar.prod2id[a_tm1.production]] elif isinstance(a_tm1, ReduceAction): a_tm1_embed = self.production_embed.weight[len( self.grammar)] else: a_tm1_embed = self.primitive_embed.weight[ self.vocab.primitive[a_tm1.token]] a_tm1_embeds.append(a_tm1_embed) else: a_tm1_embeds.append(zero_action_embed) a_tm1_embeds = torch.stack(a_tm1_embeds) inputs = [a_tm1_embeds] if args.no_input_feed is False: inputs.append(att_tm1) if args.no_parent_production_embed is False: # frontier production frontier_prods = [ hyp.frontier_node.production for hyp in hypotheses ] frontier_prod_embeds = self.production_embed( Variable( self.new_long_tensor([ self.grammar.prod2id[prod] for prod in frontier_prods ]))) 
inputs.append(frontier_prod_embeds) if args.no_parent_field_embed is False: # frontier field frontier_fields = [ hyp.frontier_field.field for hyp in hypotheses ] frontier_field_embeds = self.field_embed( Variable( self.new_long_tensor([ self.grammar.field2id[field] for field in frontier_fields ]))) inputs.append(frontier_field_embeds) if args.no_parent_field_type_embed is False: # frontier field type frontier_field_types = [ hyp.frontier_field.type for hyp in hypotheses ] frontier_field_type_embeds = self.type_embed( Variable( self.new_long_tensor([ self.grammar.type2id[type] for type in frontier_field_types ]))) inputs.append(frontier_field_type_embeds) # parent states if args.no_parent_state is False: p_ts = [ hyp.frontier_node.created_time for hyp in hypotheses ] parent_states = torch.stack([ hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts) ]) parent_cells = torch.stack([ hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts) ]) if args.lstm == 'parent_feed': h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells) else: inputs.append(parent_states) x = torch.cat(inputs, dim=-1) (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None) # Variable(batch_size, grammar_size) # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1)) apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1) if args.no_copy: primitive_prob = gen_from_vocab_prob else: # Variable(batch_size, src_sent_len) primitive_copy_prob = self.src_pointer_net( src_encodings, None, att_t.unsqueeze(0)).squeeze(0) # Variable(batch_size, 2) primitive_predictor_prob = F.softmax( self.primitive_predictor(att_t), dim=-1) # Variable(batch_size, primitive_vocab_size) primitive_prob = primitive_predictor_prob[:, 0].unsqueeze( 1) * gen_from_vocab_prob # if src_unk_pos_list: # primitive_prob[:, primitive_vocab.unk_id] = 1.e-10 gentoken_prev_hyp_ids = [] gentoken_new_hyp_unks = [] applyrule_new_hyp_scores = [] applyrule_new_hyp_prod_ids = [] applyrule_prev_hyp_ids = [] for hyp_id, hyp in enumerate(hypotheses): # generate new continuations action_types = self.transition_system.get_valid_continuation_types( hyp) for action_type in action_types: if action_type == ApplyRuleAction: productions = self.transition_system.get_valid_continuating_productions( hyp) for production in productions: prod_id = self.grammar.prod2id[production] prod_score = apply_rule_log_prob[ hyp_id, prod_id].data.item() new_hyp_score = hyp.score + prod_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(prod_id) applyrule_prev_hyp_ids.append(hyp_id) elif action_type == ReduceAction: action_score = apply_rule_log_prob[ hyp_id, len(self.grammar)].data.item() new_hyp_score = hyp.score + action_score applyrule_new_hyp_scores.append(new_hyp_score) applyrule_new_hyp_prod_ids.append(len(self.grammar)) applyrule_prev_hyp_ids.append(hyp_id) else: # GenToken action gentoken_prev_hyp_ids.append(hyp_id) hyp_copy_info = dict() # of (token_pos, copy_prob) hyp_unk_copy_info = [] if args.no_copy is False: for token, token_pos_list in aggregated_primitive_tokens.items( ): sum_copy_prob = torch.gather( primitive_copy_prob[hyp_id], 0, Variable( T.LongTensor(token_pos_list))).sum() gated_copy_prob = primitive_predictor_prob[ hyp_id, 1] * sum_copy_prob if token in primitive_vocab: token_id = primitive_vocab[token] 
primitive_prob[ hyp_id, token_id] = primitive_prob[ hyp_id, token_id] + gated_copy_prob hyp_copy_info[token] = ( token_pos_list, gated_copy_prob.data.item()) else: hyp_unk_copy_info.append({ 'token': token, 'token_pos_list': token_pos_list, 'copy_prob': gated_copy_prob.data.item() }) if args.no_copy is False and len( hyp_unk_copy_info) > 0: unk_i = np.array([ x['copy_prob'] for x in hyp_unk_copy_info ]).argmax() token = hyp_unk_copy_info[unk_i]['token'] primitive_prob[ hyp_id, primitive_vocab. unk_id] = hyp_unk_copy_info[unk_i]['copy_prob'] gentoken_new_hyp_unks.append(token) hyp_copy_info[token] = ( hyp_unk_copy_info[unk_i]['token_pos_list'], hyp_unk_copy_info[unk_i]['copy_prob']) new_hyp_scores = None if applyrule_new_hyp_scores: new_hyp_scores = Variable( self.new_tensor(applyrule_new_hyp_scores)) if gentoken_prev_hyp_ids: primitive_log_prob = torch.log(primitive_prob) gen_token_new_hyp_scores = ( hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1) if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores else: new_hyp_scores = torch.cat( [new_hyp_scores, gen_token_new_hyp_scores]) top_new_hyp_scores, top_new_hyp_pos = torch.topk( new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses))) live_hyp_ids = [] new_hypotheses = [] for new_hyp_score, new_hyp_pos in zip( top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()): action_info = ActionInfo() if new_hyp_pos < len(applyrule_new_hyp_scores): # it's an ApplyRule or Reduce action prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos] prev_hyp = hypotheses[prev_hyp_id] prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos] # ApplyRule action if prod_id < len(self.grammar): production = self.grammar.id2prod[prod_id] action = ApplyRuleAction(production) # Reduce action else: action = ReduceAction() else: # it's a GenToken action token_id = (new_hyp_pos - len(applyrule_new_hyp_scores) ) % primitive_prob.size(1) k = (new_hyp_pos - len(applyrule_new_hyp_scores) ) // primitive_prob.size(1) # try: # copy_info = gentoken_copy_infos[k] prev_hyp_id = gentoken_prev_hyp_ids[k] prev_hyp = hypotheses[prev_hyp_id] # except: # print('k=%d' % k, file=sys.stderr) # print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr) # print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr) # print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr) # print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr) # print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr) # print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr) # print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr) # print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr) # print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr) # # torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin') # # # exit(-1) # raise ValueError() if token_id == primitive_vocab.unk_id: if gentoken_new_hyp_unks: token = gentoken_new_hyp_unks[k] else: token = primitive_vocab.id2word( primitive_vocab.unk_id) ###### else: token = primitive_vocab.id2word(token_id.item()) action = GenTokenAction(token) if token in aggregated_primitive_tokens: action_info.copy_from_src = True action_info.src_token_position = aggregated_primitive_tokens[ token] if debug: action_info.gen_copy_switch = 'n/a' if args.no_copy else primitive_predictor_prob[ prev_hyp_id, 
:].log().cpu().data.numpy() action_info.in_vocab = token in primitive_vocab action_info.gen_token_prob = gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().data.item() \ if token in primitive_vocab else 'n/a' action_info.copy_token_prob = torch.gather(primitive_copy_prob[prev_hyp_id], 0, Variable(T.LongTensor(action_info.src_token_position))).sum().log().cpu().data.item() \ if args.no_copy is False and action_info.copy_from_src else 'n/a' action_info.action = action action_info.t = t if t > 0: action_info.parent_t = prev_hyp.frontier_node.created_time action_info.frontier_prod = prev_hyp.frontier_node.production action_info.frontier_field = prev_hyp.frontier_field.field if debug: action_info.action_prob = new_hyp_score - prev_hyp.score new_hyp = prev_hyp.clone_and_apply_action_info(action_info) new_hyp.score = new_hyp_score if new_hyp.completed: # add length normalization new_hyp.score /= (t + 1) completed_hypotheses.append(new_hyp) else: new_hypotheses.append(new_hyp) live_hyp_ids.append(prev_hyp_id) if live_hyp_ids: hyp_states = [ hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids ] h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) att_tm1 = att_t[live_hyp_ids] hypotheses = new_hypotheses hyp_scores = Variable( self.new_tensor([hyp.score for hyp in hypotheses])) t += 1 else: break completed_hypotheses.sort(key=lambda hyp: -hyp.score) return completed_hypotheses
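# --- Illustrative sketch (assumption): two details of the beam search in parse() above,
# (1) marginalizing the pointer distribution over every position where the same source
# token occurs (the role of aggregated_primitive_tokens), and (2) the length
# normalization applied to a hypothesis score when it completes.
import torch
from collections import OrderedDict

src_sent = ['my', 'file', ',', 'my', 'file']
copy_prob = torch.tensor([0.1, 0.4, 0.05, 0.15, 0.3])   # pointer softmax for one hypothesis

positions = OrderedDict()
for pos, tok in enumerate(src_sent):
    positions.setdefault(tok, []).append(pos)

# total probability of copying the surface form 'file', regardless of which occurrence
p_copy_file = torch.gather(copy_prob, 0, torch.tensor(positions['file'])).sum()
print(p_copy_file)            # 0.4 + 0.3

# length normalization: divide the completed hypothesis score by the number of steps
raw_score, t = -7.2, 11
print(raw_score / (t + 1))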