def get_action_infos(src_query, tgt_actions, force_copy=False):
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        if isinstance(action, GenTokenAction):
            try:
                tok_src_idx = src_query.index(str(action.token))
                action_info.copy_from_src = True
                action_info.src_token_position = tok_src_idx
            except ValueError:
                if force_copy:
                    raise ValueError('cannot copy primitive token %s from source' % action.token)

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
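# A minimal usage sketch: `get_action_infos` is typically run at preprocessing
# time, after a transition system turns a gold AST into an oracle action
# sequence. `transition_system` and `gold_ast` are assumed names for objects
# produced elsewhere in this codebase, not definitions from this file.
#
#   tgt_actions = transition_system.get_actions(gold_ast)
#   action_infos = get_action_infos(src_query, tgt_actions, force_copy=False)
#   # each ActionInfo now carries its timestep, parent pointer and, for
#   # GenToken actions found in the source, the copy position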
def parse(self, src_sent, context=None, beam_size=5, debug=False):
    """Perform beam search to infer the target AST given a source utterance

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction
        beam_size: beam size

    Returns:
        A list of `DecodeHypothesis`, each representing an AST
    """
    args = self.args
    primitive_vocab = self.vocab.primitive
    T = torch.cuda if args.cuda else torch

    src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

    # Variable(1, src_sent_len, hidden_size * 2)
    src_encodings, last_state = self.encode(src_sent_var, [len(src_sent)])
    # (1, src_sent_len, hidden_size)
    src_encodings_att_linear = self.att_src_linear(src_encodings)

    dec_init_vec = self.init_decoder_state(last_state)

    if args.lstm == 'parent_feed':
        h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                Variable(self.new_tensor(args.hidden_size).zero_()), \
                Variable(self.new_tensor(args.hidden_size).zero_())
    else:
        h_tm1 = dec_init_vec

    zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

    with torch.no_grad():
        hyp_scores = Variable(self.new_tensor([0.]))

    # For computing copy probabilities, we marginalize over tokens with the same surface form
    # `aggregated_primitive_tokens` stores the position of occurrence of each source token
    aggregated_primitive_tokens = OrderedDict()
    for token_pos, token in enumerate(src_sent):
        aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

    t = 0
    hypotheses = [DecodeHypothesis()]
    hyp_states = [[]]
    completed_hypotheses = []

    while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
        hyp_num = len(hypotheses)

        # (hyp_num, src_sent_len, hidden_size * 2)
        exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
        # (hyp_num, src_sent_len, hidden_size)
        exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                       src_encodings_att_linear.size(1),
                                                                       src_encodings_att_linear.size(2))

        if t == 0:
            with torch.no_grad():
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_())
            if args.no_parent_field_type_embed is False:
                offset = args.action_embed_size  # prev_action
                offset += args.att_vec_size * (not args.no_input_feed)
                offset += args.action_embed_size * (not args.no_parent_production_embed)
                offset += args.field_embed_size * (not args.no_parent_field_embed)

                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
        else:
            actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

            a_tm1_embeds = []
            for a_tm1 in actions_tm1:
                if a_tm1:
                    if isinstance(a_tm1, ApplyRuleAction):
                        a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                    elif isinstance(a_tm1, ReduceAction):
                        a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                    else:
                        a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                    a_tm1_embeds.append(a_tm1_embed)
                else:
                    a_tm1_embeds.append(zero_action_embed)
            a_tm1_embeds = torch.stack(a_tm1_embeds)

            inputs = [a_tm1_embeds]
            if args.no_input_feed is False:
                inputs.append(att_tm1)
            if args.no_parent_production_embed is False:
                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))
                inputs.append(frontier_prod_embeds)
            if args.no_parent_field_embed is False:
                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))
                inputs.append(frontier_field_embeds)
            if args.no_parent_field_type_embed is False:
                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))
                inputs.append(frontier_field_type_embeds)

            # parent states
            if args.no_parent_state is False:
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                if args.lstm == 'parent_feed':
                    h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                else:
                    inputs.append(parent_states)

            x = torch.cat(inputs, dim=-1)

        (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                         exp_src_encodings_att_linear,
                                         src_token_mask=None)

        # Variable(batch_size, grammar_size)
        apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

        # Variable(batch_size, primitive_vocab_size)
        gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

        if args.no_copy:
            primitive_prob = gen_from_vocab_prob
        else:
            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

        gentoken_prev_hyp_ids = []
        gentoken_new_hyp_unks = []
        applyrule_new_hyp_scores = []
        applyrule_new_hyp_prod_ids = []
        applyrule_prev_hyp_ids = []

        for hyp_id, hyp in enumerate(hypotheses):
            # generate new continuations
            action_types = self.transition_system.get_valid_continuation_types(hyp)

            for action_type in action_types:
                if action_type == ApplyRuleAction:
                    productions = self.transition_system.get_valid_continuating_productions(hyp)
                    for production in productions:
                        prod_id = self.grammar.prod2id[production]
                        prod_score = apply_rule_log_prob[hyp_id, prod_id].data.item()
                        new_hyp_score = hyp.score + prod_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(prod_id)
                        applyrule_prev_hyp_ids.append(hyp_id)
                elif action_type == ReduceAction:
                    action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data.item()
                    new_hyp_score = hyp.score + action_score

                    applyrule_new_hyp_scores.append(new_hyp_score)
                    applyrule_new_hyp_prod_ids.append(len(self.grammar))
                    applyrule_prev_hyp_ids.append(hyp_id)
                else:
                    # GenToken action
                    gentoken_prev_hyp_ids.append(hyp_id)
                    hyp_copy_info = dict()  # of (token_pos, copy_prob)
                    hyp_unk_copy_info = []

                    if args.no_copy is False:
                        for token, token_pos_list in aggregated_primitive_tokens.items():
                            sum_copy_prob = torch.gather(primitive_copy_prob[hyp_id], 0,
                                                         Variable(T.LongTensor(token_pos_list))).sum()
                            gated_copy_prob = primitive_predictor_prob[hyp_id, 1] * sum_copy_prob

                            if token in primitive_vocab:
                                token_id = primitive_vocab[token]
                                primitive_prob[hyp_id, token_id] = primitive_prob[hyp_id, token_id] + gated_copy_prob
                                hyp_copy_info[token] = (token_pos_list, gated_copy_prob.data.item())
                            else:
                                hyp_unk_copy_info.append({'token': token,
                                                          'token_pos_list': token_pos_list,
                                                          'copy_prob': gated_copy_prob.data.item()})

                    if args.no_copy is False and len(hyp_unk_copy_info) > 0:
                        unk_i = np.array([x['copy_prob'] for x in hyp_unk_copy_info]).argmax()
                        token = hyp_unk_copy_info[unk_i]['token']
                        primitive_prob[hyp_id, primitive_vocab.unk_id] = hyp_unk_copy_info[unk_i]['copy_prob']
                        gentoken_new_hyp_unks.append(token)

                        hyp_copy_info[token] = (hyp_unk_copy_info[unk_i]['token_pos_list'],
                                                hyp_unk_copy_info[unk_i]['copy_prob'])

        new_hyp_scores = None
        if applyrule_new_hyp_scores:
            new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
        if gentoken_prev_hyp_ids:
            primitive_log_prob = torch.log(primitive_prob)
            gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

            if new_hyp_scores is None:
                new_hyp_scores = gen_token_new_hyp_scores
            else:
                new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

        top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                         k=min(new_hyp_scores.size(0),
                                                               beam_size - len(completed_hypotheses)))

        live_hyp_ids = []
        new_hypotheses = []
        for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
            action_info = ActionInfo()
            if new_hyp_pos < len(applyrule_new_hyp_scores):
                # it's an ApplyRule or Reduce action
                prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                prev_hyp = hypotheses[prev_hyp_id]

                prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                # ApplyRule action
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            else:
                # it's a GenToken action
                token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)
                k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)

                prev_hyp_id = gentoken_prev_hyp_ids[k]
                prev_hyp = hypotheses[prev_hyp_id]

                if token_id == primitive_vocab.unk_id:
                    if gentoken_new_hyp_unks:
                        token = gentoken_new_hyp_unks[k]
                    else:
                        token = primitive_vocab.id2word(primitive_vocab.unk_id)
                else:
                    token = primitive_vocab.id2word(token_id.item())

                action = GenTokenAction(token)

                if token in aggregated_primitive_tokens:
                    action_info.copy_from_src = True
                    action_info.src_token_position = aggregated_primitive_tokens[token]

                if debug:
                    action_info.gen_copy_switch = 'n/a' if args.no_copy else \
                        primitive_predictor_prob[prev_hyp_id, :].log().cpu().data.numpy()
                    action_info.in_vocab = token in primitive_vocab
                    action_info.gen_token_prob = gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().data.item() \
                        if token in primitive_vocab else 'n/a'
                    action_info.copy_token_prob = torch.gather(
                        primitive_copy_prob[prev_hyp_id], 0,
                        Variable(T.LongTensor(action_info.src_token_position))).sum().log().cpu().data.item() \
                        if args.no_copy is False and action_info.copy_from_src else 'n/a'

            action_info.action = action
            action_info.t = t
            if t > 0:
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field

            if debug:
                action_info.action_prob = new_hyp_score - prev_hyp.score

            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score

            if new_hyp.completed:
                # add length normalization
                new_hyp.score /= (t + 1)
                completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)

        if live_hyp_ids:
            hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
            t += 1
        else:
            break

    completed_hypotheses.sort(key=lambda hyp: -hyp.score)

    return completed_hypotheses
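# Why `new_hyp.score /= (t + 1)` above: completed hypotheses finish at
# different time steps, and raw sums of log-probabilities systematically
# penalize longer derivations. A small self-contained illustration of the
# normalization (not repo code; names are illustrative):

def _length_normalized(score_sum, num_steps):
    """Average per-step log-probability used to rank completed hypotheses."""
    return score_sum / num_steps

# A 10-step hypothesis with total log-prob -8.0 outranks a 4-step one with
# -4.0 once both are normalized (-0.8 > -1.0), even though -4.0 > -8.0 raw.
assert _length_normalized(-8.0, 10) > _length_normalized(-4.0, 4)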
def parse(self, src_sent, context=None, beam_size=5):
    """Perform beam search to infer the target AST given a source utterance

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction
        beam_size: beam size

    Returns:
        A list of `DecodeHypothesis`, each representing an AST
    """
    args = self.args
    primitive_vocab = self.vocab.primitive

    src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

    # Variable(1, src_sent_len, hidden_size * 2)
    src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
    # (1, src_sent_len, hidden_size)
    src_encodings_att_linear = self.att_src_linear(src_encodings)

    dec_init_vec = self.init_decoder_state(last_state, last_cell)
    if args.lstm == 'parent_feed':
        h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                Variable(self.new_tensor(args.hidden_size).zero_()), \
                Variable(self.new_tensor(args.hidden_size).zero_())
    else:
        h_tm1 = dec_init_vec

    zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

    hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

    src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
    src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids)
                        if token_id == primitive_vocab.unk_id]
    # sometimes a word may appear multiple times in the source; in this case,
    # we just copy its first appearing position. Therefore we mask the words
    # appearing second and onwards to -1
    token_set = set()
    for i, tid in enumerate(src_token_vocab_ids):
        if tid in token_set:
            src_token_vocab_ids[i] = -1
        else:
            token_set.add(tid)

    t = 0
    hypotheses = [DecodeHypothesis()]
    hyp_states = [[]]
    completed_hypotheses = []

    while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
        hyp_num = len(hypotheses)

        # (hyp_num, src_sent_len, hidden_size * 2)
        exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
        # (hyp_num, src_sent_len, hidden_size)
        exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                       src_encodings_att_linear.size(1),
                                                                       src_encodings_att_linear.size(2))

        if t == 0:
            x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
            if args.no_parent_field_type_embed is False:
                offset = args.action_embed_size  # prev_action
                offset += args.att_vec_size * (not args.no_input_feed)
                offset += args.action_embed_size * (not args.no_parent_production_embed)
                offset += args.field_embed_size * (not args.no_parent_field_embed)

                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
        else:
            actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

            a_tm1_embeds = []
            for a_tm1 in actions_tm1:
                if a_tm1:
                    if isinstance(a_tm1, ApplyRuleAction):
                        a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                    elif isinstance(a_tm1, ReduceAction):
                        a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                    else:
                        a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                    a_tm1_embeds.append(a_tm1_embed)
                else:
                    a_tm1_embeds.append(zero_action_embed)
            a_tm1_embeds = torch.stack(a_tm1_embeds)

            inputs = [a_tm1_embeds]
            if args.no_input_feed is False:
                inputs.append(att_tm1)
            if args.no_parent_production_embed is False:
                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))
                inputs.append(frontier_prod_embeds)
            if args.no_parent_field_embed is False:
                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))
                inputs.append(frontier_field_embeds)
            if args.no_parent_field_type_embed is False:
                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))
                inputs.append(frontier_field_type_embeds)

            # parent states
            if args.no_parent_state is False:
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                if args.lstm == 'parent_feed':
                    h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                else:
                    inputs.append(parent_states)

            x = torch.cat(inputs, dim=-1)

        if args.lstm == 'lstm_with_dropout':
            self.decoder_lstm.set_dropout_masks(hyp_num)

        (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                         exp_src_encodings_att_linear,
                                         src_token_mask=None)

        # Variable(batch_size, grammar_size)
        apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

        # Variable(batch_size, src_sent_len)
        primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

        # Variable(batch_size, primitive_vocab_size)
        gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

        # Variable(batch_size, 2)
        primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

        # Variable(batch_size, primitive_vocab_size)
        primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

        if src_unk_pos_list:
            primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

        gentoken_prev_hyp_ids = []
        gentoken_new_hyp_unks = []
        gentoken_copy_infos = []
        applyrule_new_hyp_scores = []
        applyrule_new_hyp_prod_ids = []
        applyrule_prev_hyp_ids = []

        for hyp_id, hyp in enumerate(hypotheses):
            # generate new continuations
            action_types = self.transition_system.get_valid_continuation_types(hyp)

            for action_type in action_types:
                if action_type == ApplyRuleAction:
                    productions = self.transition_system.get_valid_continuating_productions(hyp)
                    for production in productions:
                        prod_id = self.grammar.prod2id[production]
                        prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                        new_hyp_score = hyp.score + prod_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(prod_id)
                        applyrule_prev_hyp_ids.append(hyp_id)
                elif action_type == ReduceAction:
                    action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                    new_hyp_score = hyp.score + action_score

                    applyrule_new_hyp_scores.append(new_hyp_score)
                    applyrule_new_hyp_prod_ids.append(len(self.grammar))
                    applyrule_prev_hyp_ids.append(hyp_id)
                else:
                    # GenToken action
                    gentoken_prev_hyp_ids.append(hyp_id)
                    hyp_copy_info = dict()  # of (token_pos, copy_prob)

                    # first, we compute copy probabilities for tokens in the source sentence
                    for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                        if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                            p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                            primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                            token = src_sent[token_pos]
                            hyp_copy_info[token] = (token_pos, p_copy.data[0])

                    # second, add the probability of copying the most probable unk word
                    if args.no_copy is False and src_unk_pos_list:
                        unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                        unk_pos = src_unk_pos_list[unk_pos]
                        token = src_sent[unk_pos]
                        gentoken_new_hyp_unks.append(token)

                        unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                        primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score
                        hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                    gentoken_copy_infos.append(hyp_copy_info)

        new_hyp_scores = None
        if applyrule_new_hyp_scores:
            new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
        if gentoken_prev_hyp_ids:
            primitive_log_prob = torch.log(primitive_prob)
            gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

            if new_hyp_scores is None:
                new_hyp_scores = gen_token_new_hyp_scores
            else:
                new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

        top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                         k=min(new_hyp_scores.size(0),
                                                               beam_size - len(completed_hypotheses)))

        live_hyp_ids = []
        new_hypotheses = []
        for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
            action_info = ActionInfo()
            if new_hyp_pos < len(applyrule_new_hyp_scores):
                # it's an ApplyRule or Reduce action
                prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                prev_hyp = hypotheses[prev_hyp_id]

                prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                # ApplyRule action
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            else:
                # it's a GenToken action
                token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)
                k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)

                copy_info = gentoken_copy_infos[k]
                prev_hyp_id = gentoken_prev_hyp_ids[k]
                prev_hyp = hypotheses[prev_hyp_id]

                if token_id == primitive_vocab.unk_id:
                    if gentoken_new_hyp_unks:
                        token = gentoken_new_hyp_unks[k]
                    else:
                        token = primitive_vocab.id2word[primitive_vocab.unk_id]
                else:
                    token = primitive_vocab.id2word[token_id]

                action = GenTokenAction(token)

                if token in copy_info:
                    action_info.copy_from_src = True
                    action_info.src_token_position = copy_info[token][0]

            action_info.action = action
            action_info.t = t
            if t > 0:
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field

            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score

            if new_hyp.completed:
                # avoid adding trivially short (degenerate) completed hypotheses
                if len(new_hyp.actions) != 2:
                    completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)

        if live_hyp_ids:
            hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
            t += 1
        else:
            break

    completed_hypotheses.sort(key=lambda hyp: -hyp.score)

    return completed_hypotheses
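# The beam step above concatenates the ApplyRule/Reduce scores with a
# flattened (num_gentoken_hyps x vocab_size) block of GenToken scores, then
# runs one torch.topk over everything. A self-contained sketch of the index
# arithmetic used to map a flat position back to (hypothesis k, token_id);
# all sizes below are illustrative, not repo values:

import torch

num_applyrule = 3                               # entries in applyrule_new_hyp_scores
vocab_size = 5                                  # primitive_prob.size(1)
gentoken_scores = torch.randn(2, vocab_size)    # two GenToken hypotheses

flat = torch.cat([torch.randn(num_applyrule), gentoken_scores.view(-1)])
pos = flat.argmax().item()
if pos >= num_applyrule:
    k = (pos - num_applyrule) // vocab_size          # which GenToken hypothesis
    token_id = (pos - num_applyrule) % vocab_size    # which vocabulary entry
    assert flat[pos] == gentoken_scores[k, token_id]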
def parse(self, question, context, beam_size=5):
    table = context
    args = self.args

    src_sent_var = nn_utils.to_input_variable([question], self.vocab.source, cuda=self.args.cuda, training=False)

    utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
    dec_init_vec = self.init_decoder_state(last_state, last_cell)

    column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

    h_tm1 = dec_init_vec

    # (batch_size, query_len, hidden_size)
    utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

    zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

    t = 0
    hypotheses = [DecodeHypothesis()]
    hyp_states = [[]]
    completed_hypotheses = []

    while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
        hyp_num = len(hypotheses)

        # (hyp_num, src_sent_len, hidden_size * 2)
        exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1),
                                                       utterance_encodings.size(2))
        # (hyp_num, src_sent_len, hidden_size)
        exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(
            hyp_num, utterance_encodings_att_linear.size(1), utterance_encodings_att_linear.size(2))

        # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
        if t == 0:
            x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
            if args.no_parent_field_type_embed is False:
                offset = args.action_embed_size  # prev_action
                offset += args.hidden_size * (not args.no_input_feed)
                offset += args.action_embed_size * (not args.no_parent_production_embed)
                offset += args.field_embed_size * (not args.no_parent_field_embed)

                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
        else:
            a_tm1_embeds = []
            for e_id, hyp in enumerate(hypotheses):
                action_tm1 = hyp.actions[-1]
                if action_tm1:
                    if isinstance(action_tm1, ApplyRuleAction):
                        a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                    elif isinstance(action_tm1, ReduceAction):
                        a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                    elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                        a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                    elif isinstance(action_tm1, GenTokenAction):
                        a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                    else:
                        raise ValueError('unknown action %s' % action_tm1)
                else:
                    a_tm1_embed = zero_action_embed

                a_tm1_embeds.append(a_tm1_embed)

            a_tm1_embeds = torch.stack(a_tm1_embeds)

            inputs = [a_tm1_embeds]
            if args.no_input_feed is False:
                inputs.append(att_tm1)
            if args.no_parent_production_embed is False:
                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))
                inputs.append(frontier_prod_embeds)
            if args.no_parent_field_embed is False:
                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))
                inputs.append(frontier_field_embeds)
            if args.no_parent_field_type_embed is False:
                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))
                inputs.append(frontier_field_type_embeds)

            # parent states
            if args.no_parent_state is False:
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                if args.lstm == 'parent_feed':
                    h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                else:
                    inputs.append(parent_states)

            x = torch.cat(inputs, dim=-1)

        (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                         exp_src_encodings_att_linear,
                                         src_token_mask=None)

        # ApplyRule action probability
        # (batch_size, grammar_size)
        apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

        # column attention
        # (batch_size, max_head_num)
        column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                           att_t.unsqueeze(0)).squeeze(0)
        column_selection_log_prob = torch.log(column_attention_weights)

        # (batch_size, 2)
        primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

        # primitive copy prob
        # (batch_size, src_token_num)
        primitive_copy_prob = self.src_pointer_net(utterance_encodings, None, att_t.unsqueeze(0)).squeeze(0)

        # (batch_size, primitive_vocab_size)
        primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

        new_hyp_meta = []
        for hyp_id, hyp in enumerate(hypotheses):
            # generate new continuations
            action_types = self.transition_system.get_valid_continuation_types(hyp)

            for action_type in action_types:
                if action_type == ApplyRuleAction:
                    productions = self.transition_system.get_valid_continuating_productions(hyp)
                    for production in productions:
                        prod_id = self.grammar.prod2id[production]
                        prod_score = apply_rule_log_prob[hyp_id, prod_id]
                        new_hyp_score = hyp.score + prod_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                      'score': prod_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                elif action_type == ReduceAction:
                    action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                    new_hyp_score = hyp.score + action_score

                    meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                  'score': action_score, 'new_hyp_score': new_hyp_score,
                                  'prev_hyp_id': hyp_id}
                    new_hyp_meta.append(meta_entry)
                elif action_type == WikiSqlSelectColumnAction:
                    for col_id, column in enumerate(table.header):
                        col_sel_score = column_selection_log_prob[hyp_id, col_id]
                        new_hyp_score = hyp.score + col_sel_score

                        meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                      'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                elif action_type == GenTokenAction:
                    # remember that we can only copy stuff from the input!
                    # we only copy tokens sequentially!!
                    prev_action = hyp.action_infos[-1].action

                    valid_token_pos_list = []
                    if type(prev_action) is GenTokenAction and not prev_action.is_stop_signal():
                        token_pos = hyp.action_infos[-1].src_token_position + 1
                        if token_pos < len(question):
                            valid_token_pos_list = [token_pos]
                    else:
                        valid_token_pos_list = list(range(len(question)))

                    col_id = hyp.frontier_node['col_idx'].value
                    if table.header[col_id].type == 'real':
                        valid_token_pos_list = [i for i in valid_token_pos_list
                                                if any(c.isdigit() for c in question[i]) or
                                                hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                    p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                    for token_pos in valid_token_pos_list:
                        token = question[token_pos]
                        p_copy = p_copies[token_pos]
                        score_copy = torch.log(p_copy)

                        meta_entry = {'action_type': 'gen_token',
                                      'token': token, 'token_pos': token_pos,
                                      'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)

                    # add generation probability for </primitive>
                    if hyp._value_buffer:
                        eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                   primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                        eos_score = torch.log(eos_prob)

                        meta_entry = {'action_type': 'gen_token',
                                      'token': '</primitive>', 'score': eos_score,
                                      'new_hyp_score': eos_score + hyp.score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)

        if not new_hyp_meta:
            break

        new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
        top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                  k=min(new_hyp_scores.size(0),
                                                        beam_size - len(completed_hypotheses)))

        live_hyp_ids = []
        new_hypotheses = []
        for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
            action_info = ActionInfo()
            hyp_meta_entry = new_hyp_meta[meta_id]
            prev_hyp_id = hyp_meta_entry['prev_hyp_id']
            prev_hyp = hypotheses[prev_hyp_id]

            action_type_str = hyp_meta_entry['action_type']
            if action_type_str == 'apply_rule':
                # ApplyRule action
                prod_id = hyp_meta_entry['prod_id']
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            elif action_type_str == 'sel_col':
                action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
            else:
                action = GenTokenAction(hyp_meta_entry['token'])
                if 'token_pos' in hyp_meta_entry:
                    action_info.copy_from_src = True
                    action_info.src_token_position = hyp_meta_entry['token_pos']

            action_info.action = action
            action_info.t = t
            if t > 0:
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field

            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score

            if new_hyp.completed:
                completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)

        if live_hyp_ids:
            hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            t += 1
        else:
            break

    completed_hypotheses.sort(key=lambda hyp: -hyp.score)

    return completed_hypotheses
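# The WikiSQL parser above tracks candidate continuations as `meta_entry`
# dicts rather than the parallel lists used by the other parsers, which keeps
# heterogeneous actions (apply_rule / sel_col / gen_token) in one ranked pool.
# A self-contained sketch of that bookkeeping pattern (values illustrative):

import torch

new_hyp_meta = [
    {'action_type': 'apply_rule', 'prod_id': 7, 'new_hyp_score': torch.tensor([-1.2]), 'prev_hyp_id': 0},
    {'action_type': 'sel_col', 'col_id': 2, 'new_hyp_score': torch.tensor([-0.4]), 'prev_hyp_id': 0},
    {'action_type': 'gen_token', 'token': 'paris', 'new_hyp_score': torch.tensor([-0.9]), 'prev_hyp_id': 1},
]
scores = torch.cat([e['new_hyp_score'] for e in new_hyp_meta])
top_scores, meta_ids = torch.topk(scores, k=2)
for meta_id in meta_ids.tolist():
    entry = new_hyp_meta[meta_id]  # recover the full action description
    assert entry['action_type'] in ('apply_rule', 'sel_col', 'gen_token')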
def parse(self, src_sent, context=None, beam_size=5):
    """Perform beam search to infer the target AST given a source utterance

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction
        beam_size: beam size

    Returns:
        A list of `DecodeHypothesis`, each representing an AST
    """
    args = self.args
    primitive_vocab = self.vocab.primitive

    src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

    # Variable(1, src_sent_len, hidden_size * 2)
    src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
    # (1, src_sent_len, hidden_size)
    src_encodings_att_linear = self.att_src_linear(src_encodings)

    dec_init_vec = self.init_decoder_state(last_state, last_cell)
    if args.lstm == 'parent_feed':
        h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                Variable(self.new_tensor(args.hidden_size).zero_()), \
                Variable(self.new_tensor(args.hidden_size).zero_())
    else:
        h_tm1 = dec_init_vec

    zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

    hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

    src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
    src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids)
                        if token_id == primitive_vocab.unk_id]
    # sometimes a word may appear multiple times in the source; in this case,
    # we just copy its first appearing position. Therefore we mask the words
    # appearing second and onwards to -1
    token_set = set()
    for i, tid in enumerate(src_token_vocab_ids):
        if tid in token_set:
            src_token_vocab_ids[i] = -1
        else:
            token_set.add(tid)

    t = 0
    hypotheses = [DecodeHypothesis()]
    hyp_states = [[]]
    completed_hypotheses = []

    while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
        hyp_num = len(hypotheses)

        # (hyp_num, src_sent_len, hidden_size * 2)
        exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
        # (hyp_num, src_sent_len, hidden_size)
        exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                       src_encodings_att_linear.size(1),
                                                                       src_encodings_att_linear.size(2))

        if t == 0:
            x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
            if args.no_parent_field_type_embed is False:
                offset = args.action_embed_size  # prev_action
                offset += args.att_vec_size * (not args.no_input_feed)
                offset += args.action_embed_size * (not args.no_parent_production_embed)
                offset += args.field_embed_size * (not args.no_parent_field_embed)

                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
        else:
            actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

            a_tm1_embeds = []
            for a_tm1 in actions_tm1:
                if a_tm1:
                    if isinstance(a_tm1, ApplyRuleAction):
                        a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                    elif isinstance(a_tm1, ReduceAction):
                        a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                    else:
                        a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                    a_tm1_embeds.append(a_tm1_embed)
                else:
                    a_tm1_embeds.append(zero_action_embed)
            a_tm1_embeds = torch.stack(a_tm1_embeds)

            inputs = [a_tm1_embeds]
            if args.no_input_feed is False:
                inputs.append(att_tm1)
            if args.no_parent_production_embed is False:
                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))
                inputs.append(frontier_prod_embeds)
            if args.no_parent_field_embed is False:
                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))
                inputs.append(frontier_field_embeds)
            if args.no_parent_field_type_embed is False:
                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))
                inputs.append(frontier_field_type_embeds)

            # parent states
            if args.no_parent_state is False:
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                if args.lstm == 'parent_feed':
                    h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                else:
                    inputs.append(parent_states)

            x = torch.cat(inputs, dim=-1)

        if args.lstm == 'lstm_with_dropout':
            self.decoder_lstm.set_dropout_masks(hyp_num)

        (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                         exp_src_encodings_att_linear,
                                         src_token_mask=None)

        # Variable(batch_size, grammar_size)
        apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

        # Variable(batch_size, src_sent_len)
        primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

        # Variable(batch_size, primitive_vocab_size)
        gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

        # Variable(batch_size, 2)
        primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

        # Variable(batch_size, primitive_vocab_size)
        primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

        if src_unk_pos_list:
            primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

        gentoken_prev_hyp_ids = []
        gentoken_new_hyp_unks = []
        gentoken_copy_infos = []
        applyrule_new_hyp_scores = []
        applyrule_new_hyp_prod_ids = []
        applyrule_prev_hyp_ids = []

        for hyp_id, hyp in enumerate(hypotheses):
            # generate new continuations
            action_types = self.transition_system.get_valid_continuation_types(hyp)

            for action_type in action_types:
                if action_type == ApplyRuleAction:
                    productions = self.transition_system.get_valid_continuating_productions(hyp)
                    for production in productions:
                        prod_id = self.grammar.prod2id[production]
                        prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                        new_hyp_score = hyp.score + prod_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(prod_id)
                        applyrule_prev_hyp_ids.append(hyp_id)
                elif action_type == ReduceAction:
                    action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                    new_hyp_score = hyp.score + action_score

                    applyrule_new_hyp_scores.append(new_hyp_score)
                    applyrule_new_hyp_prod_ids.append(len(self.grammar))
                    applyrule_prev_hyp_ids.append(hyp_id)
                else:
                    # GenToken action
                    gentoken_prev_hyp_ids.append(hyp_id)
                    hyp_copy_info = dict()  # of (token_pos, copy_prob)

                    # first, we compute copy probabilities for tokens in the source sentence
                    for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                        if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                            p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                            primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                            token = src_sent[token_pos]
                            hyp_copy_info[token] = (token_pos, p_copy.data[0])

                    # second, add the probability of copying the most probable unk word
                    if args.no_copy is False and src_unk_pos_list:
                        unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                        unk_pos = src_unk_pos_list[unk_pos]
                        token = src_sent[unk_pos]
                        gentoken_new_hyp_unks.append(token)

                        unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                        primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score
                        hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                    gentoken_copy_infos.append(hyp_copy_info)

        new_hyp_scores = None
        if applyrule_new_hyp_scores:
            new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
        if gentoken_prev_hyp_ids:
            primitive_log_prob = torch.log(primitive_prob)
            gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

            if new_hyp_scores is None:
                new_hyp_scores = gen_token_new_hyp_scores
            else:
                new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

        top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                         k=min(new_hyp_scores.size(0),
                                                               beam_size - len(completed_hypotheses)))

        live_hyp_ids = []
        new_hypotheses = []
        for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
            action_info = ActionInfo()
            if new_hyp_pos < len(applyrule_new_hyp_scores):
                # it's an ApplyRule or Reduce action
                prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                prev_hyp = hypotheses[prev_hyp_id]

                prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                # ApplyRule action
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            else:
                # it's a GenToken action
                token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)
                k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)

                copy_info = gentoken_copy_infos[k]
                prev_hyp_id = gentoken_prev_hyp_ids[k]
                prev_hyp = hypotheses[prev_hyp_id]

                if token_id == primitive_vocab.unk_id:
                    if gentoken_new_hyp_unks:
                        token = gentoken_new_hyp_unks[k]
                    else:
                        token = primitive_vocab.id2word[primitive_vocab.unk_id]
                else:
                    token = primitive_vocab.id2word[token_id]

                action = GenTokenAction(token)

                if token in copy_info:
                    action_info.copy_from_src = True
                    action_info.src_token_position = copy_info[token][0]

            action_info.action = action
            action_info.t = t
            if t > 0:
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field

            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score

            if new_hyp.completed:
                completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)

        if live_hyp_ids:
            hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
            t += 1
        else:
            break

    completed_hypotheses.sort(key=lambda hyp: -hyp.score)

    return completed_hypotheses
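# All of the LSTM parsers above score a primitive token as a gated mixture:
# p(token) = p(gen) * p_vocab(token) + p(copy) * p_pointer(token), where the
# two-way `primitive_predictor` supplies p(gen) and p(copy). A self-contained
# numerical sketch of that mixture (values illustrative, not repo outputs):

import torch
import torch.nn.functional as F

predictor_logits = torch.tensor([0.3, -0.1])    # [gen, copy] logits
gen_copy = F.softmax(predictor_logits, dim=-1)  # -> p(gen), p(copy)
p_vocab = torch.tensor([0.7, 0.2, 0.1])         # generation distribution
p_pointer = torch.tensor([0.1, 0.8, 0.1])       # copy (pointer) distribution

p_token = gen_copy[0] * p_vocab + gen_copy[1] * p_pointer
assert torch.allclose(p_token.sum(), torch.tensor(1.0))  # still a distribution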
def get_action_infos(src_query, tgt_actions, force_copy=False, copy_method='token'):
    action_infos = []
    hyp = Hypothesis()

    t = 0
    while t < len(tgt_actions):
        action = tgt_actions[t]
        if type(action) is GenTokenAction:
            begin_t = t
            t += 1
            while t < len(tgt_actions) and type(tgt_actions[t]) is GenTokenAction:
                t += 1
            end_t = t

            gen_token_actions = tgt_actions[begin_t:end_t]
            assert gen_token_actions[-1].is_stop_signal()

            tokens = [action.token for action in gen_token_actions[:-1]]
            try:
                tok_src_start_idx, tok_src_end_idx = find_sub_sequence(src_query, tokens)
                tok_src_idxs = list(range(tok_src_start_idx, tok_src_end_idx))
            except IndexError:
                print('\tCannot find [%s] in [%s]' % (' '.join(tokens), ' '.join(src_query)), file=sys.stderr)
                tok_src_idxs = [src_query.index(token) for token in tokens]

            tok_src_idxs.append(-1)  # for </primitive>
            for tok_src_idx, gen_token_action in zip(tok_src_idxs, gen_token_actions):
                action_info = ActionInfo(gen_token_action)
                if not gen_token_action.is_stop_signal():
                    action_info.copy_from_src = True
                    action_info.src_token_position = tok_src_idx
                    assert src_query[tok_src_idx] == gen_token_action.token

                if hyp.frontier_node:
                    action_info.parent_t = hyp.frontier_node.created_time
                    action_info.frontier_prod = hyp.frontier_node.production
                    action_info.frontier_field = hyp.frontier_field.field

                hyp.apply_action(gen_token_action)
                action_infos.append(action_info)
        else:
            action_info = ActionInfo(action)
            if hyp.frontier_node:
                action_info.parent_t = hyp.frontier_node.created_time
                action_info.frontier_prod = hyp.frontier_node.production
                action_info.frontier_field = hyp.frontier_field.field

            hyp.apply_action(action)
            action_infos.append(action_info)

            t += 1

    return action_infos
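# `find_sub_sequence` is referenced above but not defined in this section. A
# plausible implementation, assuming it returns the (start, end) span of the
# first occurrence of `query` inside `sequence` and raises IndexError when no
# span exists (matching the `except IndexError` handler above):

def find_sub_sequence(sequence, query):
    # scan every window of len(query) tokens for an exact match
    if query:
        for i in range(len(sequence) - len(query) + 1):
            if sequence[i:i + len(query)] == query:
                return i, i + len(query)
    raise IndexError('sub-sequence not found')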
def parse_sample(self, src_sent):
    # get samples of q_{\phi}(z|x) by sampling
    args = self.args
    sample_size = self.sample_size
    primitive_vocab = self.vocab.primitive

    src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

    # Variable(1, src_sent_len, hidden_size * 2)
    src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
    # (1, src_sent_len, hidden_size)
    src_encodings_att_linear = self.att_src_linear(src_encodings)

    h_tm1 = self.init_decoder_state(last_state, last_cell)

    zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

    hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

    src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
    src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids)
                        if token_id == primitive_vocab.unk_id]
    # sometimes a word may appear multiple times in the source; in this case,
    # we just copy its first appearing position. Therefore we mask the words
    # appearing second and onwards to -1
    token_set = set()
    for i, tid in enumerate(src_token_vocab_ids):
        if tid in token_set:
            src_token_vocab_ids[i] = -1
        else:
            token_set.add(tid)

    completed_hypotheses = []
    while len(completed_hypotheses) < sample_size:
        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]

        while t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                           src_encodings_att_linear.size(1),
                                                                           src_encodings_att_linear.size(2))

            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
                offset = args.action_embed_size * 2 + args.field_embed_size
                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))

                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))

                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))

                # parent states
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                hist_states = torch.stack([hyp_states[hyp_id][p_t] for hyp_id, p_t in enumerate(p_ts)])

                x = torch.cat([a_tm1_embeds, att_tm1, frontier_prod_embeds,
                               frontier_field_embeds, frontier_field_type_embeds, hist_states], dim=-1)

            if args.lstm == 'lstm_with_dropout':
                self.decoder_lstm.set_dropout_masks(hyp_num)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Variable(batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

            if src_unk_pos_list:
                primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            gentoken_copy_infos = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # of (token_pos, copy_prob)

                        # first, we compute copy probabilities for tokens in the source sentence
                        for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                            if token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                                primitive_prob[hyp_id, token_vocab_id] = \
                                    primitive_prob[hyp_id, token_vocab_id] + p_copy

                                token = src_sent[token_pos]
                                hyp_copy_info[token] = (token_pos, p_copy.data[0])

                        # second, add the probability of copying the most probable unk word
                        if src_unk_pos_list:
                            unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                            unk_pos = src_unk_pos_list[unk_pos]
                            token = src_sent[unk_pos]
                            gentoken_new_hyp_unks.append(token)

                            unk_copy_score = primitive_predictor_prob[hyp_id, 1] * \
                                             primitive_copy_prob[hyp_id, unk_pos]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score
                            hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                        gentoken_copy_infos.append(hyp_copy_info)

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                            primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None:
                    new_hyp_scores = gen_token_new_hyp_scores
                else:
                    new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            # unlike beam search, sample a single continuation from the
            # softmax over all candidate scores instead of taking the top-k
            pro = F.softmax(new_hyp_scores, 0)
            new_hyp_pos = torch.multinomial(pro, 1).data[0]
            new_hyp_score = new_hyp_scores[new_hyp_pos].data[0]

            live_hyp_ids = []
            new_hypotheses = []

            action_info = ActionInfo()
            if new_hyp_pos < len(applyrule_new_hyp_scores):
                # it's an ApplyRule or Reduce action
                prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                prev_hyp = hypotheses[prev_hyp_id]

                prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                # ApplyRule action
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            else:
                # it's a GenToken action
                token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)
                k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)

                copy_info = gentoken_copy_infos[k]
                prev_hyp_id = gentoken_prev_hyp_ids[k]
                prev_hyp = hypotheses[prev_hyp_id]

                if token_id == primitive_vocab.unk_id:
                    if gentoken_new_hyp_unks:
                        token = gentoken_new_hyp_unks[k]
                    else:
                        token = primitive_vocab.id2word[primitive_vocab.unk_id]
                else:
                    token = primitive_vocab.id2word[token_id]

                action = GenTokenAction(token)

                if token in copy_info:
                    action_info.copy_from_src = True
                    action_info.src_token_position = copy_info[token][0]

            action_info.action = action
            action_info.t = t
            if t > 0:
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field

            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score

            if new_hyp.completed:
                completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [h_t[i]] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]

                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                break

    return completed_hypotheses
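# `parse_sample` replaces beam search's deterministic torch.topk with
# ancestral sampling: the flat continuation scores are renormalized with a
# softmax and a single successor is drawn per step. A self-contained sketch
# of that choice (scores illustrative):

import torch
import torch.nn.functional as F

new_hyp_scores = torch.tensor([-0.5, -1.0, -2.5])   # log-scores of continuations
probs = F.softmax(new_hyp_scores, dim=0)            # renormalize over candidates
sampled_pos = torch.multinomial(probs, 1).item()    # stochastic choice (sampling)
greedy_pos = torch.topk(new_hyp_scores, k=1).indices.item()  # deterministic choice
assert 0 <= sampled_pos < len(new_hyp_scores) and greedy_pos == 0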
def parse(self, src_sent, context=None, beam_size=5, debug=False):
    """Perform beam search to infer the target AST given a source utterance

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction
        beam_size: beam size

    Returns:
        A list of `DecodeHypothesis`, each representing an AST
    """
    with torch.no_grad():
        args = self.args
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source,
                                                  cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size)
        src_encodings = self.encode(src_sent_var)

        zero_action_embed = torch.zeros(args.action_embed_size)
        hyp_scores = torch.tensor([0.0])

        # For computing copy probabilities, we marginalize over tokens with the same surface form
        # `aggregated_primitive_tokens` stores the positions of occurrence of each source token
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        t = 0
        hypotheses = [DecodeHypothesis()]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))

            if t == 0:
                x = torch.zeros(1, self.d_model)
                parent_ids = np.array([[0]])

                if args.no_parent_field_type_embed is False:
                    offset = self.args.action_embed_size  # prev_action
                    offset += self.args.action_embed_size * (not self.args.no_parent_production_embed)
                    offset += self.args.field_embed_size * (not self.args.no_parent_field_embed)

                    x[0, offset: offset + self.type_embed_size] = self.type_embed(
                        torch.tensor(self.grammar.type2id[self.grammar.root_type]))

                x = x.unsqueeze(-2)
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed(
                                torch.tensor(self.grammar.prod2id[a_tm1.production]))
                        elif isinstance(a_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed(
                                torch.tensor(len(self.grammar)))
                        else:
                            a_tm1_embed = self.primitive_embed(
                                torch.tensor(self.vocab.primitive[a_tm1.token]))

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(torch.tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods],
                        dtype=torch.long))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(torch.tensor(
                        [self.grammar.field2id[field] for field in frontier_fields],
                        dtype=torch.long))
                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(torch.tensor(
                        [self.grammar.type2id[type] for type in frontier_field_types],
                        dtype=torch.long))
                    inputs.append(frontier_field_type_embeds)

                x = torch.cat([x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1)

                # use [0] (the root position) as a fallback parent for hypotheses
                # without a frontier node, so every row has the same shape
                recent_parents = np.array(
                    [[hyp.frontier_node.created_time] if hyp.frontier_node else [0]
                     for hyp in hypotheses])
                parent_ids = np.hstack([parent_ids, recent_parents])
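            # The Transformer decoder carries no recurrent state across steps:
            # at every step it re-decodes the whole prefix `x` of action
            # embeddings under a causal (subsequent) mask, and `parent_ids`
            # tells it which earlier position holds each action's parent
            # frontier node.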
            src_mask = torch.ones(exp_src_encodings.shape[:-1],
                                  dtype=torch.uint8,
                                  device=exp_src_encodings.device).unsqueeze(-2)
            tgt_mask = subsequent_mask(x.shape[-2])
            att_t = self.decoder(x, exp_src_encodings, [parent_ids],
                                 src_mask, tgt_mask)[:, -1]

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            if args.no_copy:
                primitive_prob = gen_from_vocab_prob
            else:
                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(src_encodings, None,
                                                           att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, 2)
                primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].item()
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].item()
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # token -> (token_pos_list, copy_prob)
                        hyp_unk_copy_info = []

                        if args.no_copy is False:
                            for token, token_pos_list in aggregated_primitive_tokens.items():
                                sum_copy_prob = torch.gather(
                                    primitive_copy_prob[hyp_id], 0,
                                    torch.tensor(token_pos_list, dtype=torch.long)).sum()
                                gated_copy_prob = primitive_predictor_prob[hyp_id, 1] * sum_copy_prob

                                if token in primitive_vocab:
                                    token_id = primitive_vocab[token]
                                    primitive_prob[hyp_id, token_id] = (
                                        primitive_prob[hyp_id, token_id] + gated_copy_prob)
                                    hyp_copy_info[token] = (token_pos_list, gated_copy_prob.item())
                                else:
                                    hyp_unk_copy_info.append({
                                        "token": token,
                                        "token_pos_list": token_pos_list,
                                        "copy_prob": gated_copy_prob.item(),
                                    })
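                        # Among source tokens outside the primitive vocabulary,
                        # only the one with the highest aggregated copy
                        # probability is kept as this hypothesis's unk
                        # candidate (handled below).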
                        if args.no_copy is False and len(hyp_unk_copy_info) > 0:
                            unk_i = np.array([x["copy_prob"] for x in hyp_unk_copy_info]).argmax()
                            token = hyp_unk_copy_info[unk_i]["token"]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = hyp_unk_copy_info[unk_i]["copy_prob"]
                            gentoken_new_hyp_unks.append(token)

                            hyp_copy_info[token] = (hyp_unk_copy_info[unk_i]["token_pos_list"],
                                                    hyp_unk_copy_info[unk_i]["copy_prob"])

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = torch.tensor(applyrule_new_hyp_scores)
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                            primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)
                if new_hyp_scores is None:
                    new_hyp_scores = gen_token_new_hyp_scores
                else:
                    new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                new_hyp_scores,
                k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(),
                                                  top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]
                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]

                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    token_id = int((new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1))
                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)

                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == int(primitive_vocab.unk_id):
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in aggregated_primitive_tokens:
                        action_info.copy_from_src = True
                        action_info.src_token_position = aggregated_primitive_tokens[token]

                    if debug:
                        action_info.gen_copy_switch = (
                            "n/a" if args.no_copy
                            else primitive_predictor_prob[prev_hyp_id, :].log().cpu().data.numpy())
                        action_info.in_vocab = token in primitive_vocab
                        action_info.gen_token_prob = (
                            gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().item()
                            if token in primitive_vocab else "n/a")
                        action_info.copy_token_prob = (
                            torch.gather(primitive_copy_prob[prev_hyp_id], 0,
                                         torch.tensor(action_info.src_token_position,
                                                      dtype=torch.long,
                                                      device=self.device)).sum().log().cpu().item()
                            if args.no_copy is False and action_info.copy_from_src else "n/a")

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                if debug:
                    action_info.action_prob = new_hyp_score - prev_hyp.score

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                x = x[live_hyp_ids]
                parent_ids = parent_ids[live_hyp_ids]

                hypotheses = new_hypotheses
                hyp_scores = torch.tensor([hyp.score for hyp in hypotheses])
                t += 1
            else:
                break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

    return completed_hypotheses
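# --- Illustrative sketch (not part of the original source) -------------------
# A minimal example of the copy-probability marginalization performed above:
# when a token occurs several times in the source, its pointer-network
# probabilities are summed (via `torch.gather`), gated by the copy switch, and
# added to the token's generation probability. All numbers below are made up.
def _example_copy_marginalization():
    import torch

    # pointer distribution over a 5-token source utterance
    src_sent = ['sort', 'my', 'list', 'my', 'way']
    primitive_copy_prob = torch.tensor([0.10, 0.25, 0.30, 0.15, 0.20])
    p_copy_switch = 0.8  # P(copy) from the gen/copy predictor

    token_pos_list = [i for i, tok in enumerate(src_sent) if tok == 'my']
    sum_copy_prob = torch.gather(primitive_copy_prob, 0,
                                 torch.tensor(token_pos_list, dtype=torch.long)).sum()
    gated_copy_prob = p_copy_switch * sum_copy_prob
    # prints 0.8 * (0.25 + 0.15) = 0.320
    print('aggregated copy prob for "my": %.3f' % gated_copy_prob.item())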