def parse(self, src_sent, context=None, beam_size=5, debug=False):
    """Perform beam search to infer the target AST given a source utterance.

    The LSTM decoder is run one action at a time. At every step each live
    hypothesis is expanded with all valid ApplyRule / Reduce / GenToken
    continuations, and the best ``beam_size - len(completed)`` candidates
    are kept.

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction (not read by this method)
        beam_size: beam size
        debug: if True, record extra diagnostic scores on each ``ActionInfo``

    Returns:
        A list of `DecodeHypothesis`, each representing an AST, sorted by
        descending score. A completed hypothesis has its score divided by the
        number of decoding steps it took (length normalization).
    """
    # Beam search is inference only: disable autograd for the whole search
    # (encoder, decoder init and every decoding step) so no graph is kept.
    with torch.no_grad():
        args = self.args
        primitive_vocab = self.vocab.primitive
        T = torch.cuda if args.cuda else torch

        # `id2word` may be a mapping or a callable depending on the vocabulary
        # class (NOTE(review): confirm which); accept either form so the
        # lookup cannot raise TypeError.
        id2word = primitive_vocab.id2word

        def _id_to_word(word_id):
            return id2word(word_id) if callable(id2word) else id2word[word_id]

        src_sent_var = nn_utils.to_input_variable(
            [src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, last_state = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        dec_init_vec = self.init_decoder_state(last_state)
        if args.lstm == 'parent_feed':
            # the parent-feeding state additionally carries (parent_state, parent_cell)
            h_tm1 = (dec_init_vec[0], dec_init_vec[1],
                     Variable(self.new_tensor(args.hidden_size).zero_()),
                     Variable(self.new_tensor(args.hidden_size).zero_()))
        else:
            h_tm1 = dec_init_vec

        zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

        hyp_scores = Variable(self.new_tensor([0.]))

        # For computing copy probabilities, we marginalize over tokens with the same surface form
        # `aggregated_primitive_tokens` stores the position of occurrence of each source token
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i][s] is the (h, c) of live hypothesis i at step s; a
        # parent's state is looked up through its node's `created_time`.
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(
                hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            if t == 0:
                # first step: all-zero input except the root-type embedding slot
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_())
                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[
                                self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            # Reduce uses the row right after the last production
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            a_tm1_embed = self.primitive_embed.weight[
                                self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor(
                        [self.grammar.field2id[field] for field in frontier_fields])))
                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor(
                        [self.grammar.type2id[field_type] for field_type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack(
                        [hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack(
                        [hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            if args.no_copy:
                primitive_prob = gen_from_vocab_prob
            else:
                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(
                    src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, 2): column 0 gates generation, column 1 gates copying
                primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data.item()
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data.item()
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_unk_copy_info = []

                        if args.no_copy is False:
                            for token, token_pos_list in aggregated_primitive_tokens.items():
                                sum_copy_prob = torch.gather(
                                    primitive_copy_prob[hyp_id], 0,
                                    Variable(T.LongTensor(token_pos_list))).sum()
                                gated_copy_prob = primitive_predictor_prob[hyp_id, 1] * sum_copy_prob

                                if token in primitive_vocab:
                                    # in-vocabulary: add copy mass onto the generation prob
                                    token_id = primitive_vocab[token]
                                    primitive_prob[hyp_id, token_id] = \
                                        primitive_prob[hyp_id, token_id] + gated_copy_prob
                                else:
                                    hyp_unk_copy_info.append({
                                        'token': token,
                                        'token_pos_list': token_pos_list,
                                        'copy_prob': gated_copy_prob.data.item(),
                                    })

                        if args.no_copy is False and len(hyp_unk_copy_info) > 0:
                            # the UNK slot takes the copy prob of the most likely OOV source token
                            unk_i = np.array([info['copy_prob'] for info in hyp_unk_copy_info]).argmax()
                            token = hyp_unk_copy_info[unk_i]['token']
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = \
                                hyp_unk_copy_info[unk_i]['copy_prob']
                            gentoken_new_hyp_unks.append(token)

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                # flattened row-major over (gen-token hypothesis, vocabulary id)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                            primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None:
                    new_hyp_scores = gen_token_new_hyp_scores
                else:
                    new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                new_hyp_scores,
                k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(),
                                                  top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                new_hyp_pos = int(new_hyp_pos)
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    if prod_id < len(self.grammar):
                        # ApplyRule action
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    else:
                        # Reduce action
                        action = ReduceAction()
                else:
                    # it's a GenToken action: undo the row-major flattening above
                    k, token_id = divmod(new_hyp_pos - len(applyrule_new_hyp_scores),
                                         primitive_prob.size(1))
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = _id_to_word(primitive_vocab.unk_id)
                    else:
                        token = _id_to_word(token_id)

                    action = GenTokenAction(token)

                    if token in aggregated_primitive_tokens:
                        action_info.copy_from_src = True
                        action_info.src_token_position = aggregated_primitive_tokens[token]

                    if debug:
                        action_info.gen_copy_switch = (
                            'n/a' if args.no_copy
                            else primitive_predictor_prob[prev_hyp_id, :].log().cpu().data.numpy())
                        action_info.in_vocab = token in primitive_vocab
                        action_info.gen_token_prob = (
                            gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().data.item()
                            if token in primitive_vocab else 'n/a')
                        action_info.copy_token_prob = (
                            torch.gather(primitive_copy_prob[prev_hyp_id], 0,
                                         Variable(T.LongTensor(action_info.src_token_position)))
                            .sum().log().cpu().data.item()
                            if args.no_copy is False and action_info.copy_from_src else 'n/a')

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                if debug:
                    action_info.action_prob = new_hyp_score - prev_hyp.score

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    # add length normalization
                    new_hyp.score /= (t + 1)
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # keep decoder state only for the surviving hypotheses
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
def parse(self, src_sent, context=None, beam_size=5, debug=False):
    """Perform beam search to infer the target AST given a source utterance.

    This variant re-runs ``self.decoder`` over the whole action prefix ``x``
    at every step and reads the last position's output. Each live hypothesis
    is expanded with all valid ApplyRule / Reduce / GenToken continuations,
    and the best ``beam_size - len(completed)`` candidates are kept.

    Args:
        src_sent: list of source utterance tokens
        context: other context used for prediction (not read by this method)
        beam_size: beam size
        debug: if True, record extra diagnostic scores on each ``ActionInfo``

    Returns:
        A list of `DecodeHypothesis`, each representing an AST, sorted by
        descending (un-normalized) score.
    """
    with torch.no_grad():
        args = self.args
        primitive_vocab = self.vocab.primitive

        # `id2word` may be a mapping or a callable depending on the vocabulary
        # class (NOTE(review): confirm which); accept either form so the
        # lookup cannot raise TypeError.
        id2word = primitive_vocab.id2word

        def _id_to_word(word_id):
            return id2word(word_id) if callable(id2word) else id2word[word_id]

        src_sent_var = nn_utils.to_input_variable(
            [src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size)
        src_encodings = self.encode(src_sent_var)
        # Every tensor created below is placed on the encoder's device so the
        # decoder inputs, masks and score buffers never mix CPU and GPU.
        device = src_encodings.device

        zero_action_embed = torch.zeros(args.action_embed_size, device=device)

        hyp_scores = torch.tensor([0.0], device=device)

        # For computing copy probabilities, we marginalize over tokens with the same surface form
        # `aggregated_primitive_tokens` stores the position of occurrence of each source token
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        t = 0
        hypotheses = [DecodeHypothesis()]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings = src_encodings.expand(
                hyp_num, src_encodings.size(1), src_encodings.size(2))

            if t == 0:
                # first step: all-zero input except the root-type embedding slot
                x = torch.zeros(1, self.d_model, device=device)
                # parent_ids[i] holds, per decoding position, the step at which
                # the parent node of hypothesis i was created
                parent_ids = np.array([[0]])
                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset:offset + self.type_embed_size] = self.type_embed(
                        torch.tensor(self.grammar.type2id[self.grammar.root_type], device=device))
                # (1, 1, d_model): a batch of one prefix of length one
                x = x.unsqueeze(-2)
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed(torch.tensor(
                                self.grammar.prod2id[a_tm1.production], device=device))
                        elif isinstance(a_tm1, ReduceAction):
                            # Reduce uses the id right after the last production
                            a_tm1_embed = self.production_embed(
                                torch.tensor(len(self.grammar), device=device))
                        else:
                            a_tm1_embed = self.primitive_embed(torch.tensor(
                                self.vocab.primitive[a_tm1.token], device=device))

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(torch.tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods],
                        dtype=torch.long, device=device))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(torch.tensor(
                        [self.grammar.field2id[field] for field in frontier_fields],
                        dtype=torch.long, device=device))
                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(torch.tensor(
                        [self.grammar.type2id[field_type] for field_type in frontier_field_types],
                        dtype=torch.long, device=device))
                    inputs.append(frontier_field_type_embeds)

                # append this step's input to every hypothesis' prefix
                x = torch.cat([x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1)

                # Each row must be a one-element list: mixing a bare scalar in
                # would build a ragged array and break np.hstack below.
                recent_parents = np.array(
                    [[hyp.frontier_node.created_time] if hyp.frontier_node else [0]
                     for hyp in hypotheses])
                parent_ids = np.hstack([parent_ids, recent_parents])

            src_mask = torch.ones(exp_src_encodings.shape[:-1], dtype=torch.uint8,
                                  device=exp_src_encodings.device).unsqueeze(-2)
            tgt_mask = subsequent_mask(x.shape[-2]).to(device)
            # decoder output at the last prefix position
            att_t = self.decoder(x, exp_src_encodings, [parent_ids], src_mask, tgt_mask)[:, -1]

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            if args.no_copy:
                primitive_prob = gen_from_vocab_prob
            else:
                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(
                    src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, 2): column 0 gates generation, column 1 gates copying
                primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].item()
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].item()
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_unk_copy_info = []

                        if args.no_copy is False:
                            for token, token_pos_list in aggregated_primitive_tokens.items():
                                sum_copy_prob = torch.gather(
                                    primitive_copy_prob[hyp_id], 0,
                                    torch.tensor(token_pos_list, dtype=torch.long, device=device)).sum()
                                gated_copy_prob = primitive_predictor_prob[hyp_id, 1] * sum_copy_prob

                                if token in primitive_vocab:
                                    # in-vocabulary: add copy mass onto the generation prob
                                    token_id = primitive_vocab[token]
                                    primitive_prob[hyp_id, token_id] = (
                                        primitive_prob[hyp_id, token_id] + gated_copy_prob)
                                else:
                                    hyp_unk_copy_info.append({
                                        "token": token,
                                        "token_pos_list": token_pos_list,
                                        "copy_prob": gated_copy_prob.item(),
                                    })

                        if args.no_copy is False and len(hyp_unk_copy_info) > 0:
                            # the UNK slot takes the copy prob of the most likely OOV source token
                            unk_i = np.array([info["copy_prob"] for info in hyp_unk_copy_info]).argmax()
                            token = hyp_unk_copy_info[unk_i]["token"]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = \
                                hyp_unk_copy_info[unk_i]["copy_prob"]
                            gentoken_new_hyp_unks.append(token)

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = torch.tensor(applyrule_new_hyp_scores, device=device)
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                # flattened row-major over (gen-token hypothesis, vocabulary id)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                                            primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None:
                    new_hyp_scores = gen_token_new_hyp_scores
                else:
                    new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                new_hyp_scores,
                k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(),
                                                  top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                new_hyp_pos = int(new_hyp_pos)
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    if prod_id < len(self.grammar):
                        # ApplyRule action
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    else:
                        # Reduce action
                        action = ReduceAction()
                else:
                    # it's a GenToken action: undo the row-major flattening above
                    k, token_id = divmod(new_hyp_pos - len(applyrule_new_hyp_scores),
                                         primitive_prob.size(1))
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == int(primitive_vocab.unk_id):
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = _id_to_word(primitive_vocab.unk_id)
                    else:
                        token = _id_to_word(token_id)

                    action = GenTokenAction(token)

                    if token in aggregated_primitive_tokens:
                        action_info.copy_from_src = True
                        action_info.src_token_position = aggregated_primitive_tokens[token]

                    if debug:
                        action_info.gen_copy_switch = (
                            "n/a" if args.no_copy
                            else primitive_predictor_prob[prev_hyp_id, :].log().cpu().data.numpy())
                        action_info.in_vocab = token in primitive_vocab
                        action_info.gen_token_prob = (
                            gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().item()
                            if token in primitive_vocab else "n/a")
                        action_info.copy_token_prob = (
                            torch.gather(
                                primitive_copy_prob[prev_hyp_id], 0,
                                torch.tensor(action_info.src_token_position,
                                             dtype=torch.long, device=device),
                            ).sum().log().cpu().item()
                            if args.no_copy is False and action_info.copy_from_src else "n/a")

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                if debug:
                    action_info.action_prob = new_hyp_score - prev_hyp.score

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # keep prefixes / parent ids only for the surviving hypotheses
                x = x[live_hyp_ids]
                parent_ids = parent_ids[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = torch.tensor([hyp.score for hyp in hypotheses], device=device)
                t += 1
            else:
                break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses