Example 1
from typing import Dict, List, Tuple


def get_sketch_prod(examples: List, table_dict: Dict) -> Tuple[List, Dict]:
    """
    If it contains all three types of columns, then the grammar is complete
    Also return sketch action list and their slots 
    """
    for example in examples:
        table_id = example["context"]
        table_lines = table_dict[table_id]["raw_lines"]

        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       tokenized_question)
        context.take_corenlp_entities(example["entities"])

        world = WikiTableAbstractLanguage(context)
        if len(context.column_types) >= 3 and len(context._num2id) > 0 and \
            len(context._entity2id) > 0 and len(context._date2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)

            # index all the possible actions
            action_set = set()
            for prods in sketch_actions.values():
                action_set.update(prods)
            id2prod = list(action_set)
            prod2id = {v: k for k, v in enumerate(id2prod)}

            return id2prod, prod2id
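
A minimal driver sketch for the function above. The loader helpers and file
names are hypothetical; only the call signature and the returned pair
(id2prod, prod2id) come from the snippet itself.

examples = load_examples("train.jsonl")     # hypothetical helper
table_dict = load_tables("tables.jsonl")    # hypothetical helper

id2prod, prod2id = get_sketch_prod(examples, table_dict)
# id2prod is a list of production strings; prod2id is its inverse index.
assert all(prod2id[prod] == ind for ind, prod in enumerate(id2prod))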
Example 2
from collections import defaultdict
from typing import Dict, List, Tuple


def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List
                             ) -> Tuple[List, Dict, List, Dict]:
    """
    If it contains all three types of columns, then the grammar is complete
    Also return sketch action list and their slots 
    """
    for example in examples:
        table_id = example["context"]
        table_lines = table_dict[table_id]["raw_lines"]

        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       tokenized_question)
        context.take_corenlp_entities(example["entities"])

        # anonymize numbers and dates
        # context.annoymized_tokens = example["tmp_tokens"]

        world = WikiTableAbstractLanguage(context)
        if len(context.column_types) >= 3 and len(context._num2id) > 0 and \
            len(context._entity2id) > 0 and len(context._date2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)

            # index all the possible actions
            action_set = set()
            for prods in sketch_actions.values():
                action_set.update(prods)
            id2prod = list(action_set)
            prod2id = {v: k for k, v in enumerate(id2prod)}
            # return id2prod, prod2id

            # lf to actions
            sketch_lf2actions = dict()
            for actions in sketch_action_list:
                lf = world.action_sequence_to_logical_form(actions)
                sketch_lf2actions[lf] = actions

            # sketch_list is expected to be sorted by length in decreasing order
            slot_dict = defaultdict(dict)
            sketch_action_seqs = []
            for sketch in sketch_list:
                sketch_actions = sketch_lf2actions[sketch]
                sketch_actions = tuple(sketch_actions)
                sketch_action_seqs.append(sketch_actions)

                for action_ind, action in enumerate(sketch_actions):
                    assert action in prod2id
                    lhs, rhs = action.split(" -> ")
                    if lhs in [
                            "Column", "StringColumn", "NumberColumn",
                            "ComparableColumn", "DateColumn", "str", "Number",
                            "Date"
                    ] and rhs == "#PH#":
                        slot_dict[sketch_actions][action_ind] = lhs
                    elif lhs == "List[Row]" and rhs == "#PH#":
                        slot_dict[sketch_actions][action_ind] = lhs

            return id2prod, prod2id, sketch_action_seqs, slot_dict
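
An illustrative sketch of the returned structures, assuming "#PH#" marks a
placeholder production as in the checks above; the variable names are the
ones a caller might use, not part of the source.

id2prod, prod2id, sketch_seqs, slot_dict = get_sketch_prod_and_slot(
    examples, table_dict, sketch_list, sketch_action_list)

for sketch_actions in sketch_seqs:
    # slot_dict maps an action-sequence tuple to {action index: slot type},
    # e.g. {3: "Column", 7: "List[Row]"}.
    for action_ind, slot_type in slot_dict[sketch_actions].items():
        lhs, rhs = sketch_actions[action_ind].split(" -> ")
        assert lhs == slot_type and rhs == "#PH#"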
Example 3
    def forward(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> Optional[torch.Tensor]:
        # Returns the negative marginal log-likelihood over consistent
        # programs, or None when no sketch yields a consistent score.
        world = WikiTableAbstractLanguage(context)

        # encode questions
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(context.question_tokens, token_in_table_feat)

        sketch_lf2actions = self.sketch_lf2actions(world)
        consistent_scores = []
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf in sketch2program:
            if len(sketch2program[sketch_lf]) > self.CONSISTENT_INST_NUM_BOUND:
                continue
            sketch_actions = sketch_lf2actions[sketch_lf]
            seq_log_likeli = self.seq2seq(world, token_reps, token_encodes, sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                    candidate_rep_dic, sketch_actions)

            # only one path
            if len(_paths) == 1:
                consistent_scores.append(seq_log_likeli)
                continue

            _gold_scores = []
            for _path, _score in zip(_paths, _log_scores):
                assert _score is not None
                _path_lf = world.action_sequence_to_logical_form(_path)
                if _path_lf in sketch2program[sketch_lf]:
                    _gold_scores.append(_score) 
            
            # aggregate consistent instantiations
            if len(_gold_scores) > 0:
                _score = seq_log_likeli + log_sum_exp(_gold_scores) 
                if not torch.isnan(_score):
                    consistent_scores.append(_score)
                else:
                    logger.warning("NaN loss encountered!")

        if len(consistent_scores) > 0:
            return -1 * log_sum_exp(consistent_scores)
        else:
            return None
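
The loss above marginalizes log-probabilities with a log_sum_exp helper.
A minimal numerically stable sketch, assuming it accepts a list of 0-dim
score tensors as used above:

import torch

def log_sum_exp(scores):
    # Stable log(sum(exp(s))): subtract the max before exponentiating so
    # large negative log-probabilities do not underflow to zero.
    stacked = torch.stack(scores)
    max_score = stacked.max()
    return max_score + torch.log(torch.exp(stacked - max_score).sum())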
Example 4
    def compute_entropy(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> Dict:
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(context.question_tokens, token_in_table_feat)

        entropy = []
        sketch_lf2actions = self.sketch_lf2actions(world)
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf in sketch2program:
            sketch_actions = sketch_lf2actions[sketch_lf]    
            sketch_log_score = self.seq2seq(world, token_reps, token_encodes, sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                candidate_rep_dic, sketch_actions)

            # only one path
            if len(_paths) == 1:
                if not self.filter_program_by_execution(world, _paths[0]):
                    continue
                _path_lf = world.action_sequence_to_logical_form(_paths[0])
                _seq_score = sketch_log_score
                if _path_lf in sketch2program[sketch_lf]:
                    entropy.append(-1 * _seq_score * torch.exp(_seq_score))
                continue

            # multiple paths
            for _path, _score in zip(_paths, _log_scores):
                if not self.filter_program_by_execution(world, _path):
                    continue
                assert _score is not None
                _path_lf = world.action_sequence_to_logical_form(_path)
                _seq_score = _score + sketch_log_score
                if _path_lf in sketch2program[sketch_lf]:
                    entropy.append(-1 * _seq_score * torch.exp(_seq_score))
        
        if len(entropy) > 0:
            ret_dic["entropy"] = sum(entropy).cpu().item()
        return ret_dic
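
Each appended term is -log p * exp(log p), i.e. the -p log p contribution of
one consistent program, so the sum approximates the entropy of the model's
distribution restricted to consistent programs. A toy illustration with a
made-up score:

import torch

log_p = torch.tensor(-0.693)           # p is roughly 0.5
term = -1 * log_p * torch.exp(log_p)   # roughly 0.347, i.e. -p * log(p)
print(term.item())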
Example 5
    def evaluate(self,
            context: TableQuestionContext,
            sketch2program: Dict) -> Dict:
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(context.question_tokens, token_in_table_feat)

        sketch_actions_and_scores = self.seq2seq.beam_decode(world, 
            token_reps, token_encodes, self.EVAL_NUM_SKETCH_BOUND)
        
        max_score = None
        best_sketch_actions = None
        best_sketch_lf = None
        best_program_actions = None
        best_program_lf = None
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_actions, sketch_log_score in sketch_actions_and_scores:
            sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                candidate_rep_dic, sketch_actions)
            # logger.info(f"{sketch_lf}, score: {torch.exp(sketch_log_score)}")

            if self.__class__.__name__ == "ConcreteProgrammer":
                assert self._ConcreteProgrammer__cur_align_prob_log is not None
                align_prob_log = self._ConcreteProgrammer__cur_align_prob_log.squeeze()
                # print(f"Align matrix prob: {torch.exp(align_prob_log)}")
                sketch_log_score = sketch_log_score + align_prob_log
                self._ConcreteProgrammer__cur_align_prob_log = None

            # only one path
            if len(_paths) == 1:
                if not self.filter_program_by_execution(world, _paths[0]):
                    continue
                _path_lf = world.action_sequence_to_logical_form(_paths[0])
                _seq_score = sketch_log_score
                if max_score is None or _seq_score > max_score:
                    max_score = _seq_score
                    best_sketch_lf = sketch_lf
                    best_sketch_actions = sketch_actions
                    best_program_lf = _path_lf 
                    best_program_actions = _paths[0]
                continue

            # multiple paths
            for _path, _score in zip(_paths, _log_scores):
                if not self.filter_program_by_execution(world, _path):
                    continue
                assert _score is not None
                _path_lf = world.action_sequence_to_logical_form(_path)
                _seq_score = _score + sketch_log_score
                if max_score is None or _seq_score > max_score:
                    max_score = _seq_score
                    best_sketch_lf = sketch_lf
                    best_sketch_actions = sketch_actions
                    best_program_lf = _path_lf 
                    best_program_actions = _path
        
        # Fails if every decoded program was filtered out by execution.
        assert max_score is not None
        ret_dic["best_program_lf"] = best_program_lf
        ret_dic["best_program_actions"] = best_program_actions
        ret_dic["best_sketch_lf"] = best_sketch_lf
        ret_dic["best_sketch_actions"] = best_sketch_actions
        ret_dic["best_score"] = torch.exp(max_score)
        ret_dic["is_multi_col"] = check_multi_col(world, 
                        best_sketch_actions, best_program_actions)

        if best_sketch_lf in sketch2program:
            ret_dic["sketch_triggered"] = True
            if best_program_lf in sketch2program[best_sketch_lf]:
                ret_dic["lf_triggered"] = True
            else:
                ret_dic["lf_triggered"] = False
        else:
            ret_dic["sketch_triggered"] = False
            ret_dic["lf_triggered"] = False

        return ret_dic
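
A hedged evaluation-loop sketch built on the method above. The model handle
and the dev-set iterable are assumptions; the ret_dic keys come from the
snippet itself.

sketch_hits, lf_hits, total = 0, 0, 0
for context, sketch2program in dev_examples:   # hypothetical iterable
    ret = model.evaluate(context, sketch2program)
    sketch_hits += int(ret["sketch_triggered"])
    lf_hits += int(ret["lf_triggered"])
    total += 1
print(f"sketch acc: {sketch_hits / total:.3f}, "
      f"program acc: {lf_hits / total:.3f}")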