예제 #1
0
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """
    Index sketch productions and locate the slots of every sketch.

    Examples are scanned in order until one is found whose table context has
    at least three column types and a non-empty number, entity and date
    vocabulary (the grammar is then considered complete). The result is
    built from that first qualifying example only.

    Args:
        examples: dicts with the keys "context" (table id), "tokens" and
            "entities".
        table_dict: maps a table id to a dict holding its "raw_lines".
        sketch_list: sketch logical forms to locate slots for.
        sketch_action_list: action sequences; each one is converted to its
            logical form so entries of ``sketch_list`` can be resolved.

    Returns:
        ``(id2prod, prod2id, sketch_action_seqs, slot_dict)`` where
        ``id2prod`` is the sorted list of sketch productions, ``prod2id`` is
        its inverse, ``sketch_action_seqs`` holds one action tuple per entry
        of ``sketch_list`` and ``slot_dict`` maps an action tuple to
        ``{action index: left-hand side}`` for every "#PH#" slot.
        ``None`` if no example has a complete grammar.

    Raises:
        KeyError: if a sketch in ``sketch_list`` matches no sequence in
            ``sketch_action_list``.
    """
    # Left-hand sides whose "#PH#" expansion marks a slot.
    slot_lhs = {
        "Column", "StringColumn", "NumberColumn", "ComparableColumn",
        "DateColumn", "str", "Number", "Date", "List[Row]"
    }

    for example in examples:
        table_id = example["context"]
        table_lines = table_dict[table_id]["raw_lines"]

        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       tokenized_question)
        context.take_corenlp_entities(example["entities"])

        world = WikiTableAbstractLanguage(context)
        grammar_complete = (len(context.column_types) >= 3
                            and len(context._num2id) > 0
                            and len(context._entity2id) > 0
                            and len(context._date2id) > 0)
        if not grammar_complete:
            continue

        productions = world._get_sketch_productions(
            world.get_nonterminal_productions())

        # Index all possible actions. Sorting makes the ids reproducible:
        # the iteration order of a set of strings can differ between
        # interpreter runs.
        action_set = set()
        for rhs_actions in productions.values():
            action_set.update(rhs_actions)
        id2prod = sorted(action_set)
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        # Logical form -> action sequence.
        sketch_lf2actions = dict()
        for action_seq in sketch_action_list:
            lf = world.action_sequence_to_logical_form(action_seq)
            sketch_lf2actions[lf] = action_seq

        slot_dict = defaultdict(dict)
        sketch_action_seqs = []
        for sketch in sketch_list:
            action_seq = tuple(sketch_lf2actions[sketch])
            sketch_action_seqs.append(action_seq)

            for action_ind, action in enumerate(action_seq):
                assert action in prod2id
                lhs, rhs = action.split(" -> ")
                if rhs == "#PH#" and lhs in slot_lhs:
                    slot_dict[action_seq][action_ind] = lhs

        return id2prod, prod2id, sketch_action_seqs, slot_dict

    # No example exposed a complete grammar.
    return None
예제 #2
0
    def compute_entropy(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> Dict:
        """
        Accumulate ``-s * exp(s)`` over the consistent programs of a question.

        ``s`` is the score of a program: the sketch score alone when slot
        filling yields exactly one path, otherwise the sketch score plus the
        path score. The scores are treated as log-probabilities. Programs
        rejected by ``filter_program_by_execution`` or whose logical form is
        not listed in ``sketch2program`` for their sketch are ignored.

        Args:
            context: table/question context to build the world from.
            sketch2program: maps a sketch logical form to the collection of
                consistent program logical forms.

        Returns:
            A ``defaultdict(int)``; the key "entropy" is set (as a Python
            number) only when at least one program contributed, so a lookup
            yields 0 otherwise.
        """
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, token_in_table_feat)

        entropy = []
        sketch_lf2actions = self.sketch_lf2actions(world)
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf in sketch2program:
            consistent_lfs = sketch2program[sketch_lf]
            sketch_actions = sketch_lf2actions[sketch_lf]
            sketch_log_score = self.seq2seq(world, token_reps, token_encodes,
                                            sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes,
                last_state, candidate_rep_dic, sketch_actions)

            # With a single path the sketch score is the program score and
            # `_log_scores` is deliberately not consulted.
            single_path = len(_paths) == 1
            if single_path:
                scored_paths = [(_paths[0], None)]
            else:
                scored_paths = zip(_paths, _log_scores)

            for _path, _score in scored_paths:
                if not self.filter_program_by_execution(world, _path):
                    continue
                if not single_path:
                    assert _score is not None
                _path_lf = world.action_sequence_to_logical_form(_path)
                if single_path:
                    _seq_score = sketch_log_score
                else:
                    _seq_score = _score + sketch_log_score
                if _path_lf in consistent_lfs:
                    entropy.append(-1 * _seq_score * torch.exp(_seq_score))

        if entropy:
            ret_dic["entropy"] = sum(entropy).cpu().item()
        return ret_dic
예제 #3
0
    def forward(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> torch.Tensor:
        """
        Compute the training loss for one question.

        For every sketch in ``sketch2program`` with at most
        ``CONSISTENT_INST_NUM_BOUND`` consistent programs, the sketch score is
        combined with the log-sum-exp of the scores of the slot-filling paths
        whose logical form is listed for that sketch. A sketch with exactly
        one slot-filling path contributes its sketch score directly.

        Returns:
            ``-log_sum_exp`` of the per-sketch scores, or ``None`` when no
            sketch contributed a score.
        """
        world = WikiTableAbstractLanguage(context)

        # encode the question
        in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, in_table_feat)

        lf_to_actions = self.sketch_lf2actions(world)
        per_sketch_scores = []
        candidate_reps = self.construct_candidates(world, token_encodes)

        for sketch_lf in sketch2program:
            gold_lfs = sketch2program[sketch_lf]
            if len(gold_lfs) > self.CONSISTENT_INST_NUM_BOUND:
                continue

            actions = lf_to_actions[sketch_lf]
            sketch_ll = self.seq2seq(world, token_reps, token_encodes, actions)
            paths, path_scores = self.slot_filling(
                world, token_encodes, last_state, candidate_reps, actions)

            # a lone path: the sketch score stands for the program
            if len(paths) == 1:
                per_sketch_scores.append(sketch_ll)
                continue

            gold_scores = []
            for path, path_score in zip(paths, path_scores):
                assert path_score is not None
                path_lf = world.action_sequence_to_logical_form(path)
                if path_lf in gold_lfs:
                    gold_scores.append(path_score)

            # nothing consistent among the instantiations of this sketch
            if not gold_scores:
                continue

            total = sketch_ll + log_sum_exp(gold_scores)
            if torch.isnan(total):
                logger.warning("Nan loss founded!")
            else:
                per_sketch_scores.append(total)

        if not per_sketch_scores:
            return None
        return -1 * log_sum_exp(per_sketch_scores)
예제 #4
0
    def evaluate(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> Dict:
        """
        Decode the best-scoring executable program for one question.

        Sketches come from ``self.seq2seq.beam_decode``; each sketch is
        instantiated by ``self.slot_filling`` and every resulting path that
        passes ``filter_program_by_execution`` competes on its score (the
        sketch score alone when there is exactly one path, otherwise path
        score plus sketch score).

        Args:
            context: table/question context to build the world from.
            sketch2program: maps a sketch logical form to the collection of
                consistent program logical forms; used only to fill the
                "sketch_triggered" / "lf_triggered" flags.

        Returns:
            A ``defaultdict(int)`` with the keys "best_program_lf",
            "best_program_actions", "best_sketch_lf", "best_sketch_actions",
            "best_score", "is_multi_col", "sketch_triggered" and
            "lf_triggered".

        Raises:
            AssertionError: if no candidate program passes the execution
                filter.
        """
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, token_in_table_feat)

        sketch_actions_and_scores = self.seq2seq.beam_decode(world,
            token_reps, token_encodes, self.EVAL_NUM_SKETCH_BOUND)

        max_score = None
        best_sketch_actions = None
        best_sketch_lf = None
        best_program_actions = None
        best_program_lf = None
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_actions, sketch_log_score in sketch_actions_and_scores:
            sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes,
                last_state, candidate_rep_dic, sketch_actions)

            # For a ConcreteProgrammer the cached alignment log-probability
            # is added to the sketch score and the cache is then cleared.
            if self.__class__.__name__ == "ConcreteProgrammer":
                assert self._ConcreteProgrammer__cur_align_prob_log is not None
                align_prob_log = self._ConcreteProgrammer__cur_align_prob_log.squeeze()
                sketch_log_score = sketch_log_score + align_prob_log
                self._ConcreteProgrammer__cur_align_prob_log = None

            # With a single path the sketch score is the program score and
            # `_log_scores` is deliberately not consulted.
            single_path = len(_paths) == 1
            if single_path:
                scored_paths = [(_paths[0], None)]
            else:
                scored_paths = zip(_paths, _log_scores)

            for _path, _score in scored_paths:
                if not self.filter_program_by_execution(world, _path):
                    continue
                if not single_path:
                    assert _score is not None
                _path_lf = world.action_sequence_to_logical_form(_path)
                if single_path:
                    _seq_score = sketch_log_score
                else:
                    _seq_score = _score + sketch_log_score
                if max_score is None or _seq_score > max_score:
                    max_score = _seq_score
                    best_sketch_lf = sketch_lf
                    best_sketch_actions = sketch_actions
                    best_program_lf = _path_lf
                    best_program_actions = _path

        assert max_score is not None, \
            "no candidate program passed the execution filter"
        ret_dic["best_program_lf"] = best_program_lf
        ret_dic["best_program_actions"] = best_program_actions
        ret_dic["best_sketch_lf"] = best_sketch_lf
        ret_dic["best_sketch_actions"] = best_sketch_actions
        ret_dic["best_score"] = torch.exp(max_score)
        ret_dic["is_multi_col"] = check_multi_col(world,
                        best_sketch_actions, best_program_actions)

        # The program can only count as triggered if its sketch is known.
        sketch_triggered = best_sketch_lf in sketch2program
        ret_dic["sketch_triggered"] = sketch_triggered
        ret_dic["lf_triggered"] = (
            sketch_triggered
            and best_program_lf in sketch2program[best_sketch_lf])

        return ret_dic
예제 #5
0
 def sketch_lf2actions(self, world: WikiTableAbstractLanguage):
     """
     Map the logical form of every cached sketch to its action sequence.

     The logical forms are produced by ``world``; if two cached sequences
     yield the same logical form, the later one is kept.
     """
     return {
         world.action_sequence_to_logical_form(cached_actions): cached_actions
         for cached_actions in self.sketch_actions_cache
     }
예제 #6
0
    def beam_decode(self, world: WikiTableAbstractLanguage,
                    token_rep: torch.Tensor, token_encodes: torch.Tensor,
                    beam_size: int):
        """
        Beam-search over sketch productions and collect finished sketches.

        Args:
            world: supplies the sketch productions and converts an action
                history to its logical form.
            token_rep: question representation used to initialise the
                decoder state.
            token_encodes: per-token encodings attended over at every step.
            beam_size: number of expansions kept per hypothesis and number
                of hypotheses kept after each step.

        Returns:
            A list of ``(history, seq_score)`` pairs, where ``history`` is
            the list of productions and ``seq_score`` the summed log-softmax
            scores. The list is sorted (best first) and cut to ``beam_size``
            only when more than ``beam_size`` sketches finished; otherwise it
            is returned in completion order. Sketches whose logical form is
            "#PH#" are dropped.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        initial_rnn_state = self._get_initial_state(token_rep)

        # Each hypothesis is (stack, history, rnn_state, seq_score); the
        # score stays None until the first action has been taken.
        incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
        completed = []

        for _ in range(self._max_decoding_steps):
            next_paths = []
            for stack, history, rnn_state, seq_score in incomplete:
                cur_non_terminal = stack.pop()
                # Hypotheses whose next symbol has no productions are dropped.
                if cur_non_terminal not in action_dict:
                    continue
                candidates = action_dict[cur_non_terminal]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

                cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
                next_hidden, next_memory = self.decoder_lstm(
                    rnn_state.previous_action_embedding,
                    (cur_hidden, cur_memory))

                # Dot-product attention of the new hidden state over tokens.
                hidden_tran = next_hidden.transpose(0, 1)
                att_feat_v = torch.mm(token_encodes,
                                      hidden_tran)  # sent_len * 1
                att_v = F.softmax(att_feat_v, dim=0)
                att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

                # Score all productions, then normalise over the ones that
                # are valid for the current non-terminal only.
                score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
                score_v = self.score_action(score_feat_v).squeeze()
                filter_score_v_list = [score_v[_id] for _id in candidate_ids]
                filter_score_v = torch.stack(filter_score_v_list, 0)
                prob_v = F.log_softmax(filter_score_v, dim=0)

                pred_logits, pred_ids = torch.topk(prob_v,
                                                   min(beam_size,
                                                       prob_v.size(0)),
                                                   dim=0)

                for _logits, _idx in zip(pred_logits, pred_ids):
                    next_action_embed = self.sketch_embed.weight[
                        candidate_ids[_idx]].unsqueeze(0)
                    next_rnn_state = RnnStatelet(next_hidden, next_memory,
                                                 next_action_embed, None, None,
                                                 None)

                    prod = candidates[_idx]
                    _history = history[:]
                    _history.append(prod)
                    non_terminals = self._get_right_side_parts(prod)
                    _stack = stack[:]
                    # Pushed in reverse so the first part is popped first.
                    for ac in reversed(non_terminals):
                        if ac in action_dict:
                            _stack.append(ac)
                    if seq_score is None:
                        _score = _logits
                    else:
                        _score = _logits + seq_score

                    next_paths.append(
                        (_stack, _history, next_rnn_state, _score))

            # An empty stack means the sketch is fully expanded.
            incomplete = []
            for stack, history, rnn_state, seq_score in next_paths:
                if len(stack) == 0:
                    if world.action_sequence_to_logical_form(
                            history) != "#PH#":
                        completed.append((history, seq_score))
                else:
                    incomplete.append((stack, history, rnn_state, seq_score))

            if len(completed) > beam_size:
                completed = sorted(completed, key=lambda x: -x[1])
                completed = completed[:beam_size]
                break

            # Nothing left to expand: further steps could not change the
            # result, so stop instead of idling until the step limit.
            if not incomplete:
                break

            if len(incomplete) > beam_size:
                incomplete = sorted(incomplete, key=lambda x: -x[3])
                incomplete = incomplete[:beam_size]

        return completed