def get_primitive_field_actions(self, realized_field):
    """Build the GenTokenAction sequence that emits a primitive field's value.

    Args:
        realized_field: the field being linearised. Its ``value`` (a single
            value, or an iterable of values when ``cardinality`` is
            ``'multiple'``), ``cardinality`` and ``type.name`` are read.

    Returns:
        list of ``GenTokenAction``:
          * ``type.name == 'string'``: each value is split on single spaces
            and followed by a ``'</primitive>'`` terminator token;
          * any other type: one action per value, using the value as-is;
          * ``value is None``: ``[GenTokenAction('None')]`` for the
            ``'singleton'`` type, otherwise an empty list.
    """
    value = realized_field.value

    if value is None:
        # The singleton type may hold None; it is emitted as the literal
        # token 'None'. Any other None-valued field emits nothing.
        if realized_field.type.name == 'singleton':
            return [GenTokenAction('None')]
        return []

    # e.g. expr -> Global(identifier* names) carries a list of values.
    values = value if realized_field.cardinality == 'multiple' else [value]

    if realized_field.type.name == 'string':
        tokens = []
        for text in values:
            # split(' ') rather than split(): consecutive spaces produce empty
            # tokens, so ' '.join over the tokens reproduces the original text.
            tokens += text.split(' ')
            tokens.append('</primitive>')
    else:
        tokens = list(values)

    return [GenTokenAction(tok) for tok in tokens]
Example #2
0
    def parse(self, question, context, beam_size=5):
        """Beam-search decode a query for ``question`` over a table.

        Args:
            question: the input utterance as a sequence of string tokens
                (``len(question)`` is the source length, and each
                ``question[i]`` is scanned character by character).
            context: the table being queried. It is passed to
                ``self.encode_table_header`` and its ``header`` entries are
                enumerated as selectable columns (each entry must expose a
                ``.type`` attribute).
            beam_size: number of completed hypotheses to collect. At each
                step at most ``beam_size - len(completed_hypotheses)``
                candidates are kept.

        Returns:
            list of completed hypotheses (presumably ``DecodeHypothesis``
            instances), sorted by descending ``score``. The list may be
            shorter than ``beam_size``, or empty, if decoding stops early:
            no valid continuation, no live hypothesis left, or
            ``self.args.decode_max_time_step`` reached.

        Notes:
            Primitive tokens are only ever *copied* from ``question``. The one
            exception is the ``</primitive>`` terminator, which is generated
            from the primitive vocabulary. Copies are contiguous: right after
            a copied (non-stop) token, only the next source position may be
            copied.
            ``Variable(..., volatile=True)`` is the legacy (pre-0.4) PyTorch
            inference flag, so this presumably targets an old PyTorch.
            TODO confirm.
        """
        table = context
        args = self.args
        # Encode the question as a batch of one.
        src_sent_var = nn_utils.to_input_variable([question], self.vocab.source,
                                                  cuda=self.args.cuda, training=False)

        utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        # Encode the table header, also as a batch of one. table_header_encoding[0, col_id]
        # is used later as the input for a previously selected column.
        column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

        h_tm1 = dec_init_vec
        # (1, query_len, hidden_size): a single question per call
        utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

        # Previous-action embedding used when a hypothesis' last action is falsy.
        zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] gets one (h, cell) pair appended per decoding step. Entries are
        # looked up by frontier_node.created_time, which serves as the step index,
        # to feed parent states.
        hyp_states = [[]]
        completed_hypotheses = []

        # Stop once beam_size hypotheses are complete or the step budget is used up.
        while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # Broadcast the single question encoding across all live hypotheses.
            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num,
                                                                                 utterance_encodings_att_linear.size(1),
                                                                                 utterance_encodings_att_linear.size(2))

            # x: [prev_action, att_tm1 (input feed), parent_production_embed, parent_field_embed,
            #     parent_field_type_embed, parent_action_state]; each part after prev_action is
            #     included only when the corresponding args flag enables it
            if t == 0:
                # First step: there is no previous action, so the input is all zeros except
                # (optionally) the parent-field-type slot, which is set to the root type.
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)

                if args.no_parent_field_type_embed is False:
                    # Offset of the field-type slot. It must match the concatenation
                    # order of `inputs` in the t > 0 branch below.
                    offset = args.action_embed_size  # prev_action
                    offset += args.hidden_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                # Embed each hypothesis' previous action according to its kind.
                a_tm1_embeds = []
                for e_id, hyp in enumerate(hypotheses):
                    action_tm1 = hyp.actions[-1]
                    if action_tm1:
                        if isinstance(action_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                        elif isinstance(action_tm1, ReduceAction):
                            # Reduce uses the extra embedding row after the last production.
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                            # Selected column: its header encoding, passed through column_rnn_input
                            # (presumably a projection to the action-embedding size).
                            a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                        elif isinstance(action_tm1, GenTokenAction):
                            # Copied tokens are embedded through the *source* vocabulary.
                            a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                        else:
                            raise ValueError('unknown action %s' % action_tm1)
                    else:
                        a_tm1_embed = zero_action_embed

                    a_tm1_embeds.append(a_tm1_embed)

                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    # Decoder state recorded at the step that created each hypothesis'
                    # frontier node.
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        # Parent state goes to the recurrent step through h_tm1, not through x.
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # ApplyRule action probability
            # (hyp_num, ...): one row per live hypothesis. Column len(self.grammar) scores ReduceAction.
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # column attention
            # (hyp_num, max_head_num). Note: a zero attention weight gives a -inf log score.
            column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                               att_t.unsqueeze(0)).squeeze(0)
            column_selection_log_prob = torch.log(column_attention_weights)

            # (hyp_num, 2): column 0 weights generation from the primitive vocabulary,
            # column 1 weights copying from the question (see their uses below).
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # primitive copy prob
            # (hyp_num, src_token_num)
            primitive_copy_prob = self.src_pointer_net(utterance_encodings, None,
                                                       att_t.unsqueeze(0)).squeeze(0)

            # (hyp_num, primitive_vocab_size)
            primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Candidate expansions. Each entry stores its log-prob `score` and the
            # cumulative `new_hyp_score` = parent hyp.score + score.
            new_hyp_meta = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score

                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        # Reduce is scored by the extra ApplyRule column and recorded as an
                        # 'apply_rule' entry with prod_id == len(self.grammar).
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == WikiSqlSelectColumnAction:
                        # One candidate per table column.
                        for col_id, column in enumerate(table.header):
                            col_sel_score = column_selection_log_prob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score

                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == GenTokenAction:
                        # remember that we can only copy stuff from the input!
                        # we only copy tokens sequentially!!
                        prev_action = hyp.action_infos[-1].action

                        # After a copied (non-stop) token, only the next source position is
                        # allowed, and none if that would run past the end of the question.
                        # Otherwise any position may start a new value.
                        valid_token_pos_list = []
                        if type(prev_action) is GenTokenAction and \
                                not prev_action.is_stop_signal():
                            token_pos = hyp.action_infos[-1].src_token_position + 1
                            if token_pos < len(question):
                                valid_token_pos_list = [token_pos]
                        else:
                            valid_token_pos_list = list(range(len(question)))

                        # For 'real'-typed columns, keep a position if its token contains a digit,
                        # or if hyp._value_buffer is non-empty and the token is one of , . - %
                        # (`and` binds tighter than `or`).
                        col_id = hyp.frontier_node['col_idx'].value
                        if table.header[col_id].type == 'real':
                            valid_token_pos_list = [i for i in valid_token_pos_list
                                                    if any(c.isdigit() for c in question[i]) or
                                                    hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                        # Copy probability = P(copy) * pointer attention over question tokens.
                        p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                        for token_pos in valid_token_pos_list:
                            token = question[token_pos]
                            p_copy = p_copies[token_pos]
                            score_copy = torch.log(p_copy)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': token, 'token_pos': token_pos,
                                          'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

                        # add generation probability for </primitive>
                        # (offered only when hyp._value_buffer is non-empty, presumably meaning
                        # a value has already been started)
                        if hyp._value_buffer:
                            eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                       primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                            eos_score = torch.log(eos_prob)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': '</primitive>',
                                          'score': eos_score, 'new_hyp_score': eos_score + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

            # No live hypothesis can be extended: stop decoding.
            if not new_hyp_meta: break

            # NOTE(review): torch.cat assumes each score is a 1-element (not 0-dim) tensor, as
            # with pre-0.4 PyTorch indexing. Revisit (e.g. torch.stack) if upgrading PyTorch.
            new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
            # Keep no more candidates than there are unfilled beam slots.
            top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                      k=min(new_hyp_scores.size(0),
                                                            beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]

                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    # ApplyRule action
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
                else:
                    action = GenTokenAction(hyp_meta_entry['token'])
                    # Copied tokens record their source position. The '</primitive>' entry has none.
                    if 'token_pos' in hyp_meta_entry:
                        action_info.copy_from_src = True
                        action_info.src_token_position = hyp_meta_entry['token_pos']

                action_info.action = action
                action_info.t = t

                # Frontier bookkeeping is recorded only after the first step
                # (presumably there is no frontier node before the root is expanded).
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    # Row of h_t / cell_t / att_t that this hypothesis continues from.
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # Re-index state histories and decoder states to the surviving hypotheses.
                # A parent row may appear several times.
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else: break

        # Highest-scoring hypothesis first.
        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
 def get_primitive_field_actions(self, realized_field):
     """Return the token action(s) for a single-cardinality primitive field.

     Args:
         realized_field: field whose ``cardinality`` must be ``'single'``
             (checked with ``assert``) and whose ``value`` is emitted.

     Returns:
         ``[GenTokenAction(value)]`` when the value is set, otherwise ``[]``
         (a ``None`` value produces no actions).
     """
     # Only single-valued primitive fields are expected here.
     assert realized_field.cardinality == 'single'

     value = realized_field.value
     return [] if value is None else [GenTokenAction(value)]