def get_primitive_field_actions(self, realized_field):
    """Translate a realized primitive field into the GenTokenAction sequence emitting it.

    Behaviour by case:
      * value is None -- a single ``GenTokenAction('None')`` when the field's
        type name is ``'singleton'``, otherwise no actions at all.
      * cardinality ``'multiple'`` -- ``realized_field.value`` is iterated as a
        collection of values (e.g. the ``names`` field of
        ``Global(identifier* names)``); any other cardinality treats it as one value.
      * type name ``'string'`` -- every value is split on single spaces and
        followed by an ``'</primitive>'`` terminator token.
      * any other type -- each value becomes exactly one token, unchanged.

    Returns:
        A list of GenTokenAction objects in emission order.
    """
    value = realized_field.value
    if value is None:
        # A singleton-typed field may legitimately hold None; it is emitted as
        # the literal token 'None'. Any other empty field produces nothing.
        if realized_field.type.name == 'singleton':
            return [GenTokenAction('None')]
        return []

    values = value if realized_field.cardinality == 'multiple' else [value]

    # Build the complete token list first, then wrap each token in an action.
    if realized_field.type.name == 'string':
        tokens = []
        for val in values:
            tokens += val.split(' ')
            tokens.append('</primitive>')
    else:
        tokens = list(values)

    return [GenTokenAction(tok) for tok in tokens]
def parse(self, question, context, beam_size=5):
    """Beam-search decode ``question`` against a table into completed hypotheses.

    At every time step each live hypothesis is expanded with every action the
    transition system reports as valid for it (ApplyRule / Reduce,
    WikiSqlSelectColumn, or GenToken). All candidates are scored and the
    top-k are kept. Hypotheses whose ``completed`` flag is set move to the
    result list. Decoding stops when ``beam_size`` hypotheses are complete,
    when ``self.args.decode_max_time_step`` is reached, when no candidate
    continuation exists, or when no live hypothesis remains.

    Args:
        question: the source utterance. It is indexed by position
            (``question[i]``) and measured with ``len``, and each element is
            iterated character by character. Presumably a list of str tokens;
            confirm with callers.
        context: the table to decode against. Only ``.header`` (whose entries
            expose ``.type``) is read directly here. It is also passed to
            ``self.encode_table_header``.
        beam_size: maximum number of completed hypotheses to collect. It also
            caps how many candidates survive each step.

    Returns:
        The completed hypotheses, sorted by descending ``score``. The list may
        hold fewer than ``beam_size`` entries (possibly none) if decoding
        stops early.
    """
    table = context
    args = self.args
    # Encode the question as a batch of one and derive the initial decoder state.
    src_sent_var = nn_utils.to_input_variable([question], self.vocab.source, cuda=self.args.cuda, training=False)
    utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
    dec_init_vec = self.init_decoder_state(last_state, last_cell)
    # Encode the table header. column_word_encodings is not used in this method.
    column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])
    h_tm1 = dec_init_vec
    # (batch_size, query_len, hidden_size)
    utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)
    # Used as the previous-action embedding when a hypothesis' last action is falsy.
    zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())
    t = 0
    hypotheses = [DecodeHypothesis()]
    # hyp_states[i] holds one (h_t, cell_t) pair per decoding step of live
    # hypothesis i. Parent states are looked up in it by the frontier node's
    # created_time.
    hyp_states = [[]]
    completed_hypotheses = []
    while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
        hyp_num = len(hypotheses)
        # Broadcast the single utterance encoding across all live hypotheses.
        # (hyp_num, src_sent_len, hidden_size * 2)
        exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
        # (hyp_num, src_sent_len, hidden_size)
        exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num, utterance_encodings_att_linear.size(1), utterance_encodings_att_linear.size(2))
        # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
        if t == 0:
            # First step: the input is all zeros, except (when enabled) the
            # root-type embedding written into the parent-field-type slot.
            # The offset skips the slots that precede it in the concatenation
            # built for t > 0: prev action, input feed, parent production and
            # parent field. Each slot counts only when its feature is enabled.
            # volatile=True is the legacy (pre-0.4) PyTorch inference flag.
            x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
            if args.no_parent_field_type_embed is False:
                offset = args.action_embed_size  # prev_action
                offset += args.hidden_size * (not args.no_input_feed)
                offset += args.action_embed_size * (not args.no_parent_production_embed)
                offset += args.field_embed_size * (not args.no_parent_field_embed)
                x[0, offset: offset + args.type_embed_size] = \
                    self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
        else:
            # Embed each hypothesis' previous action according to its kind.
            # (e_id is unused.)
            a_tm1_embeds = []
            for e_id, hyp in enumerate(hypotheses):
                action_tm1 = hyp.actions[-1]
                if action_tm1:
                    if isinstance(action_tm1, ApplyRuleAction):
                        a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                    elif isinstance(action_tm1, ReduceAction):
                        # Reduce shares the production embedding table, at row len(self.grammar).
                        a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                    elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                        # A selected column is represented through its header encoding.
                        a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                    elif isinstance(action_tm1, GenTokenAction):
                        # GenToken actions are embedded with the source-vocabulary embedding table.
                        a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                    else:
                        raise ValueError('unknown action %s' % action_tm1)
                else:
                    a_tm1_embed = zero_action_embed
                a_tm1_embeds.append(a_tm1_embed)
            a_tm1_embeds = torch.stack(a_tm1_embeds)
            inputs = [a_tm1_embeds]
            if args.no_input_feed is False:
                # att_tm1 is the attentional vector kept at the end of the previous step.
                inputs.append(att_tm1)
            if args.no_parent_production_embed is False:
                # frontier production
                frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                    [self.grammar.prod2id[prod] for prod in frontier_prods])))
                inputs.append(frontier_prod_embeds)
            if args.no_parent_field_embed is False:
                # frontier field
                frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                    self.grammar.field2id[field] for field in frontier_fields])))
                inputs.append(frontier_field_embeds)
            if args.no_parent_field_type_embed is False:
                # frontier field type
                frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                    self.grammar.type2id[type] for type in frontier_field_types])))
                inputs.append(frontier_field_type_embeds)
            # parent states
            if args.no_parent_state is False:
                # Gather each hypothesis' decoder state from the step that
                # created its frontier node.
                p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])
                if args.lstm == 'parent_feed':
                    # The parent (h, c) travels inside the recurrent state tuple.
                    h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                else:
                    # Otherwise the parent hidden state is appended to the input vector.
                    inputs.append(parent_states)
            x = torch.cat(inputs, dim=-1)
        (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, exp_src_encodings_att_linear, src_token_mask=None)
        # In the shape comments below, the batch dimension indexes live
        # hypotheses (rows are read with hyp_id).
        # ApplyRule action probability
        # (batch_size, grammar_size)
        apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)
        # column attention
        # (batch_size, max_head_num)
        column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask, att_t.unsqueeze(0)).squeeze(0)
        column_selection_log_prob = torch.log(column_attention_weights)
        # (batch_size, 2)
        # Column 0 weights generation from the vocabulary; column 1 weights
        # copying from the question (see how each is used below).
        primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)
        # primitive copy prob
        # (batch_size, src_token_num)
        primitive_copy_prob = self.src_pointer_net(utterance_encodings, None, att_t.unsqueeze(0)).squeeze(0)
        # (batch_size, primitive_vocab_size)
        primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)
        # Enumerate every valid continuation of every live hypothesis. Each
        # candidate dict records the action kind, its own score, the
        # cumulative new_hyp_score and the index of the hypothesis it extends.
        new_hyp_meta = []
        for hyp_id, hyp in enumerate(hypotheses):
            # generate new continuations
            action_types = self.transition_system.get_valid_continuation_types(hyp)
            for action_type in action_types:
                if action_type == ApplyRuleAction:
                    productions = self.transition_system.get_valid_continuating_productions(hyp)
                    for production in productions:
                        prod_id = self.grammar.prod2id[production]
                        prod_score = apply_rule_log_prob[hyp_id, prod_id]
                        new_hyp_score = hyp.score + prod_score
                        meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id, 'score': prod_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                elif action_type == ReduceAction:
                    # Reduce is encoded as an 'apply_rule' entry with prod_id == len(self.grammar).
                    action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                    new_hyp_score = hyp.score + action_score
                    meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar), 'score': action_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id}
                    new_hyp_meta.append(meta_entry)
                elif action_type == WikiSqlSelectColumnAction:
                    for col_id, column in enumerate(table.header):
                        col_sel_score = column_selection_log_prob[hyp_id, col_id]
                        new_hyp_score = hyp.score + col_sel_score
                        meta_entry = {'action_type': 'sel_col', 'col_id': col_id, 'score': col_sel_score, 'new_hyp_score': new_hyp_score, 'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                elif action_type == GenTokenAction:
                    # remember that we can only copy stuff from the input!
                    # we only copy tokens sequentially!!
                    # After a GenToken that is not a stop signal, only the next
                    # question position may be copied (none once the end of the
                    # question is reached). Otherwise any position may be copied.
                    prev_action = hyp.action_infos[-1].action
                    valid_token_pos_list = []
                    if type(prev_action) is GenTokenAction and \
                            not prev_action.is_stop_signal():
                        token_pos = hyp.action_infos[-1].src_token_position + 1
                        if token_pos < len(question):
                            valid_token_pos_list = [token_pos]
                    else:
                        valid_token_pos_list = list(range(len(question)))
                    # For columns typed 'real', keep only tokens containing a
                    # digit, or one of , . - % when hyp._value_buffer is
                    # non-empty (presumably a value is already being built).
                    # Note that `and` binds tighter than `or` here.
                    col_id = hyp.frontier_node['col_idx'].value
                    if table.header[col_id].type == 'real':
                        valid_token_pos_list = [i for i in valid_token_pos_list if any(c.isdigit() for c in question[i]) or hyp._value_buffer and question[i] in (',', '.', '-', '%')]
                    # Copy score for each position: P(copy) * pointer probability of that position.
                    p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                    for token_pos in valid_token_pos_list:
                        token = question[token_pos]
                        p_copy = p_copies[token_pos]
                        score_copy = torch.log(p_copy)
                        meta_entry = {'action_type': 'gen_token', 'token': token, 'token_pos': token_pos, 'score': score_copy, 'new_hyp_score': score_copy + hyp.score, 'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    # add generation probability for </primitive>
                    # In this method '</primitive>' is the only token scored from
                    # the vocabulary, and only once the value buffer is non-empty.
                    if hyp._value_buffer:
                        eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                   primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                        eos_score = torch.log(eos_prob)
                        meta_entry = {'action_type': 'gen_token', 'token': '</primitive>', 'score': eos_score, 'new_hyp_score': eos_score + hyp.score, 'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
        if not new_hyp_meta:
            break
        # NOTE(review): torch.cat assumes each new_hyp_score is a 1-element
        # tensor (legacy PyTorch indexing). With 0-dim tensors on newer
        # PyTorch this would need torch.stack; confirm the target version.
        # k never exceeds the number of beam slots still free.
        new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
        top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores, k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))
        # Indices into `hypotheses` of the parents of candidates that stay live,
        # in the same order as new_hypotheses.
        live_hyp_ids = []
        new_hypotheses = []
        for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
            # Rebuild the selected action from its metadata entry.
            action_info = ActionInfo()
            hyp_meta_entry = new_hyp_meta[meta_id]
            prev_hyp_id = hyp_meta_entry['prev_hyp_id']
            prev_hyp = hypotheses[prev_hyp_id]
            action_type_str = hyp_meta_entry['action_type']
            if action_type_str == 'apply_rule':
                # ApplyRule action
                prod_id = hyp_meta_entry['prod_id']
                if prod_id < len(self.grammar):
                    production = self.grammar.id2prod[prod_id]
                    action = ApplyRuleAction(production)
                # Reduce action
                else:
                    action = ReduceAction()
            elif action_type_str == 'sel_col':
                action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
            else:
                action = GenTokenAction(hyp_meta_entry['token'])
                # Only copy candidates carry 'token_pos'; '</primitive>' does not.
                if 'token_pos' in hyp_meta_entry:
                    action_info.copy_from_src = True
                    action_info.src_token_position = hyp_meta_entry['token_pos']
            action_info.action = action
            action_info.t = t
            if t > 0:
                # Record the frontier context the action was taken under
                # (not recorded for the first step).
                action_info.parent_t = prev_hyp.frontier_node.created_time
                action_info.frontier_prod = prev_hyp.frontier_node.production
                action_info.frontier_field = prev_hyp.frontier_field.field
            new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
            new_hyp.score = new_hyp_score
            if new_hyp.completed:
                completed_hypotheses.append(new_hyp)
            else:
                new_hypotheses.append(new_hyp)
                live_hyp_ids.append(prev_hyp_id)
        if live_hyp_ids:
            # Re-index state history, decoder state and attention vector so
            # row i corresponds to new_hypotheses[i].
            hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]
            hypotheses = new_hypotheses
            t += 1
        else:
            break
    completed_hypotheses.sort(key=lambda hyp: -hyp.score)
    return completed_hypotheses
def get_primitive_field_actions(self, realized_field):
    """Return the token-generation actions for a single-valued primitive field.

    Only fields with cardinality ``'single'`` are supported; anything else
    trips an assertion. A field whose value is None yields an empty list.
    Otherwise the whole value is emitted as one GenTokenAction.
    """
    assert realized_field.cardinality == 'single'
    field_value = realized_field.value
    return [] if field_value is None else [GenTokenAction(field_value)]