예제 #1
0
def get_sketch_prod(examples: List,
                    table_dict: Dict) -> "Optional[Tuple[List, Dict]]":
    """Index the sketch productions of the first grammar-complete example.

    An example counts as grammar-complete when its context has at least two
    column types, at least one number and at least one entity; the sketch
    productions of the first such example are indexed and returned.

    Args:
        examples: raw example dicts; ``example["context"]`` is the table id.
        table_dict: maps a table id to its processed table.

    Returns:
        ``(id2prod, prod2id)``: the list of distinct sketch productions and
        the inverse mapping from production to its list index, or ``None``
        when no example is grammar-complete.
    """
    for example in examples:
        processed_table = table_dict[example["context"]]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)

        if (len(context.column_types) >= 2 and len(context._num2id) > 0
                and len(context._entity2id) > 0):
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)

            # Index all the possible actions. The incremental union is kept
            # as-is so the resulting order is unchanged.
            # NOTE(review): the order follows set iteration order, which for
            # str elements depends on PYTHONHASHSEED; persist id2prod if the
            # indices must match across processes.
            action_set = set()
            for productions in sketch_actions.values():
                action_set = action_set.union(set(productions))
            id2prod = list(action_set)
            prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

            return id2prod, prod2id
    return None
예제 #2
0
    def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
        """Greedily decode a sketch and program and check them against gold.

        Args:
            context: the question/table context to decode for.
            sketch2program: maps a sketch logical form to the program logical
                forms listed for it.

        Returns:
            A ``defaultdict(int)`` with ``best_program_lf``,
            ``best_program_actions``, ``best_sketch_lf``,
            ``best_sketch_actions``, plus ``sketch_triggered`` (the decoded
            sketch is a key of ``sketch2program``) and ``lf_triggered`` (the
            decoded program is also listed under that sketch).
        """
        world = WikiSQLLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps = self.encode_question(
            context.question_tokens, token_in_table_feat)
        candidate_rep_dic = self.construct_candidates(world, token_encodes)

        sketch_actions, program_actions = self.seq2seq.decode(
            world, token_reps, token_encodes, candidate_rep_dic)

        sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
        program_lf = world.action_sequence_to_logical_form(program_actions)
        sketch_triggered = sketch_lf in sketch2program
        lf_triggered = (sketch_triggered
                        and program_lf in sketch2program[sketch_lf])

        ret_dic["best_program_lf"] = program_lf
        ret_dic["best_program_actions"] = program_actions
        ret_dic["best_sketch_lf"] = sketch_lf
        ret_dic["best_sketch_actions"] = sketch_actions
        # Store the match flags; otherwise readers would only ever see the
        # defaultdict's default 0 for these keys.
        ret_dic["sketch_triggered"] = sketch_triggered
        ret_dic["lf_triggered"] = lf_triggered

        return ret_dic
예제 #3
0
 def filter_program_by_execution(self, world: WikiSQLLanguage,
                                 actions: List) -> bool:
     """Return whether ``actions`` executes in ``world`` without an error.

     Any ``Exception`` raised by ``world.execute_action_sequence`` marks the
     program as not executable. ``KeyboardInterrupt`` and ``SystemExit`` are
     not caught, so a user interrupt still stops the run.
     """
     try:
         world.execute_action_sequence(actions)
     except Exception:
         return False
     return True
예제 #4
0
    def forward(self, context: WikiSQLContext,
                sketch2program: Dict) -> torch.Tensor:
        """Negative log marginal likelihood of the listed programs.

        For every sketch in ``sketch2program`` that lists at most
        ``self.CONSISTENT_INST_NUM_BOUND`` programs, each program is scored by
        ``self.seq2seq`` given the sketch's actions.

        Returns:
            ``-1 * log_sum_exp`` of the collected scores, or ``None`` when no
            score was collected.
        """
        world = WikiSQLLanguage(context)

        # encode questions
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps = self.encode_question(
            context.question_tokens, token_in_table_feat)

        sketch_lf2actions = self.sketch_lf2actions(world)
        consistent_scores = []
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf, program_lfs in sketch2program.items():
            # Check the bound first so skipped sketches need no cache entry.
            if len(program_lfs) > self.CONSISTENT_INST_NUM_BOUND:
                continue
            sketch_actions = sketch_lf2actions[sketch_lf]
            for program_lf in program_lfs:
                program_actions = world.logical_form_to_action_sequence(
                    program_lf)
                seq_log_likeli = self.seq2seq(world, token_reps, token_encodes,
                                              candidate_rep_dic,
                                              sketch_actions, program_actions)
                # Test for None explicitly. Tensor truthiness would drop an
                # exact 0.0 log-likelihood and force a device sync on GPU.
                if seq_log_likeli is not None:
                    consistent_scores.append(seq_log_likeli)

        if consistent_scores:
            return -1 * log_sum_exp(consistent_scores)
        return None
예제 #5
0
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """Index sketch productions and locate the slots of the given sketches.

    Used for the pruned version. The first grammar-complete example (at least
    two column types, at least one number and at least one entity) provides
    the production index. Every logical form in ``sketch_list`` is then
    resolved to its action sequence through ``sketch_action_list``, and its
    ``<type> -> #PH#`` placeholder positions are recorded.

    Args:
        examples: raw example dicts; ``example["context"]`` is the table id.
        table_dict: maps a table id to its processed table.
        sketch_list: sketch logical forms to resolve.
        sketch_action_list: candidate sketch action sequences.

    Returns:
        ``(id2prod, prod2id, sketch_action_seqs, slot_dict)``, or ``None``
        when no example is grammar-complete.
        - ``sketch_action_seqs`` holds the action tuples in ``sketch_list``
          order.
        - ``slot_dict`` maps such a tuple to ``{position: slot type}``.

    Raises:
        KeyError: if a sketch in ``sketch_list`` matches none of
            ``sketch_action_list``.
        AssertionError: if a sketch action is missing from the index.
    """
    # Left-hand sides whose "-> #PH#" production marks a fillable slot.
    slot_lhs_types = frozenset({
        "Column", "StringColumn", "NumberColumn", "ComparableColumn",
        "DateColumn", "str", "Number", "Date", "List[Row]"
    })

    for example in examples:
        processed_table = table_dict[example["context"]]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)

        if not (len(context.column_types) >= 2 and len(context._num2id) > 0
                and len(context._entity2id) > 0):
            continue

        actions = world.get_nonterminal_productions()
        sketch_prods = world._get_sketch_productions(actions)

        # index all the possible actions (same union sequence as before, so
        # the index order is unchanged)
        action_set = set()
        for prods in sketch_prods.values():
            action_set = action_set.union(set(prods))
        id2prod = list(action_set)
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        # lf to actions (a later sequence with the same lf wins)
        sketch_lf2actions = dict()
        for seq in sketch_action_list:
            sketch_lf2actions[world.action_sequence_to_logical_form(seq)] = seq

        # Resolve each sketch (kept in sketch_list order) and record its slots.
        slot_dict = defaultdict(dict)
        sketch_action_seqs = []
        for sketch in sketch_list:
            seq = tuple(sketch_lf2actions[sketch])
            sketch_action_seqs.append(seq)

            for action_ind, action in enumerate(seq):
                assert action in prod2id
                lhs, rhs = action.split(" -> ")
                if rhs == "#PH#" and lhs in slot_lhs_types:
                    slot_dict[seq][action_ind] = lhs

        return id2prod, prod2id, sketch_action_seqs, slot_dict
    return None
예제 #6
0
    def decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
               token_encodes: torch.Tensor):
        """Greedily decode the most probable sketch action sequence.

        A stack of pending non-terminals is expanded one at a time. For each,
        the best-scoring production allowed for it is chosen. Decoding stops
        when the stack is empty or after ``self._max_decoding_steps`` steps.

        Args:
            world: language whose sketch grammar constrains the choices.
            token_rep: input to ``self._get_initial_state`` for the decoder.
            token_encodes: question encodings attended over at every step
                (multiplied with ``torch.mm``; presumably ``sent_len x dim``).

        Returns:
            The chosen productions, in derivation order, as a tuple.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        rnn_state = self._get_initial_state(token_rep)

        stack = [START_SYMBOL]
        history = []
        for _ in range(self._max_decoding_steps):
            if not stack:
                break

            cur_non_terminal = stack.pop()
            if cur_non_terminal not in action_dict:
                continue
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding,
                (rnn_state.hidden_state, rnn_state.memory_cell))
            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            # score only the productions allowed for this non-terminal
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze()
            filter_score_v = torch.stack(
                [score_v[_id] for _id in candidate_ids], 0)
            prob_v = F.softmax(filter_score_v, dim=0)

            _, pred_id = torch.max(prob_v, dim=0)
            pred_id = pred_id.cpu().item()

            next_action_embed = self.sketch_embed.weight[
                candidate_ids[pred_id]].unsqueeze(0)
            rnn_state = RnnStatelet(next_hidden, next_memory,
                                    next_action_embed, None, None, None)

            prod = candidates[pred_id]
            history.append(prod)
            # Push right-hand-side parts in reverse so the leftmost is popped
            # next. As in beam_decode, parts without sketch productions are
            # not pushed, so they do not use up decoding steps.
            non_terminals = self._get_right_side_parts(prod)
            stack.extend(nt for nt in reversed(non_terminals)
                         if nt in action_dict)

        return tuple(history)
예제 #7
0
    def forward(self, context: WikiSQLContext,
                sketch2program: Dict) -> torch.Tensor:
        """Negative marginal log-likelihood over consistent sketch/programs.

        Sketches in ``sketch2program`` listing more than
        ``self.CONSISTENT_INST_NUM_BOUND`` programs are skipped. For each
        remaining sketch:
        - the sketch log-likelihood comes from ``self.seq2seq``;
        - if slot filling yields exactly one path, that score is used alone;
        - otherwise it is combined (via ``log_sum_exp``) with the slot scores
          of the instantiations whose logical form is listed for the sketch.
        NaN combinations are logged and dropped.

        Returns:
            ``-1 * log_sum_exp`` of the collected scores, or ``None`` if
            nothing was collected.
        """
        world = WikiSQLLanguage(context)

        # encode questions
        question_feats = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, question_feats)

        lf_to_actions = self.sketch_lf2actions(world)
        collected = []
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf in sketch2program:
            gold_programs = sketch2program[sketch_lf]
            if len(gold_programs) > self.CONSISTENT_INST_NUM_BOUND:
                continue

            sketch_actions = lf_to_actions[sketch_lf]
            sketch_score = self.seq2seq(world, token_reps, token_encodes,
                                        sketch_actions)
            paths, path_scores = self.slot_filling(world, token_encodes,
                                                   last_state,
                                                   candidate_rep_dic,
                                                   sketch_actions)

            if len(paths) == 1:
                # nothing to choose between: the sketch score stands alone
                collected.append(sketch_score)
                continue

            matched = []
            for path, path_score in zip(paths, path_scores):
                assert path_score is not None
                if world.action_sequence_to_logical_form(path) in gold_programs:
                    matched.append(path_score)
            if not matched:
                continue

            # aggregate consistent instantiations
            total = sketch_score + log_sum_exp(matched)
            if torch.isnan(total):
                logger.warning("Nan loss founded!")
            else:
                collected.append(total)

        return -1 * log_sum_exp(collected) if collected else None
예제 #8
0
    def forward(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, sketch_actions: List):
        """Log-likelihood of a given sketch action sequence (teacher forcing).

        At each step the decoder scores only the productions sharing the gold
        production's left-hand side, takes a log-softmax over them and keeps
        the gold entry. The gold production's embedding is fed back as the
        next input.

        Returns:
            The sum of the per-step gold log-probabilities (``0`` for an
            empty sequence).
        """
        productions = world._get_sketch_productions(
            world.get_nonterminal_productions())
        state = self._get_initial_state(token_rep)

        step_log_probs = []
        for gold_prod in sketch_actions:
            lhs, _rhs = gold_prod.split(" -> ")
            options = productions[lhs]
            option_ids = [self.sketch_prod2id[p] for p in options]

            hidden, memory = self.decoder_lstm(
                state.previous_action_embedding,
                (state.hidden_state, state.memory_cell))

            # attention over the question tokens
            att_logits = torch.mm(token_encodes, hidden.transpose(0, 1))
            att_weights = F.softmax(att_logits, dim=0)
            att_context = torch.mm(att_weights.transpose(0, 1), token_encodes)

            all_scores = self.score_action(
                torch.cat([hidden, att_context], 1)).squeeze()
            option_scores = torch.stack([all_scores[j] for j in option_ids], 0)
            option_log_probs = F.log_softmax(option_scores, dim=0)

            gold_prod_id = self.sketch_prod2id[gold_prod]
            step_log_probs.append(
                option_log_probs[option_ids.index(gold_prod_id)])

            gold_embed = self.sketch_embed.weight[gold_prod_id].unsqueeze(0)
            state = RnnStatelet(hidden, memory, gold_embed, None, None, None)

        return sum(step_log_probs)
예제 #9
0
def check_multi_col(world: WikiSQLLanguage, sketch_actions: List,
                    program_actions: List) -> bool:
    """Tell whether a filled column slot names a column missing from a row slot.

    Slot fillings (from ``gen_slot2action_dic``) are split by the left-hand
    side of the sketch production at the slot: ``List[Row]`` slots versus
    all other slots.

    Returns:
        ``False`` if either group is empty. Otherwise ``True`` as soon as a
        non-row slot's column name (first right-hand-side part of its
        filling) is not a substring of some underscore-joined row filling,
        and ``False`` if no such pair exists.
    """
    slot_fillings = gen_slot2action_dic(world,
                                        world.get_nonterminal_productions(),
                                        sketch_actions, program_actions)

    row_fillings = []
    col_fillings = []
    for slot_idx, filling in slot_fillings.items():
        if get_left_side_part(sketch_actions[slot_idx]) == "List[Row]":
            row_fillings.append(filling)
        else:
            col_fillings.append(filling)

    if not row_fillings or not col_fillings:
        return False

    for col_filling in col_fillings:
        col_name = get_right_side_parts(col_filling)[0]
        if any(col_name not in "_".join(row) for row in row_fillings):
            return True
    return False
예제 #10
0
    def beam_decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                    token_encodes: torch.Tensor, beam_size: int):
        """Beam-search the highest-scoring sketch action sequences.

        A hypothesis is ``(stack, history, rnn_state, score)``:
        - ``stack`` holds pending expandable non-terminals;
        - ``history`` holds the productions chosen so far;
        - ``score`` is their summed log-probability (``None`` before the
          first choice).
        Each step expands every hypothesis with at most ``beam_size`` of its
        best productions. Finished hypotheses are collected unless their
        logical form is the bare ``"#PH#"``.

        Returns:
            A list of ``(history, score)`` pairs sorted by decreasing score,
            at most ``beam_size`` long.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        initial_rnn_state = self._get_initial_state(token_rep)

        # (stack, history, rnn_state, seq_score)
        incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
        completed = []

        for _ in range(self._max_decoding_steps):
            if not incomplete:
                # Every hypothesis finished or was dropped; further steps
                # could not change ``completed``.
                break

            next_paths = []
            for stack, history, rnn_state, seq_score in incomplete:
                cur_non_terminal = stack.pop()
                if cur_non_terminal not in action_dict:
                    continue
                candidates = action_dict[cur_non_terminal]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

                next_hidden, next_memory = self.decoder_lstm(
                    rnn_state.previous_action_embedding,
                    (rnn_state.hidden_state, rnn_state.memory_cell))
                hidden_tran = next_hidden.transpose(0, 1)
                att_feat_v = torch.mm(token_encodes,
                                      hidden_tran)  # sent_len * 1
                att_v = F.softmax(att_feat_v, dim=0)
                att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

                score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
                score_v = self.score_action(score_feat_v).squeeze()
                filter_score_v = torch.stack(
                    [score_v[_id] for _id in candidate_ids], 0)
                log_prob_v = F.log_softmax(filter_score_v, dim=0)

                pred_logits, pred_ids = torch.topk(
                    log_prob_v, min(beam_size, log_prob_v.size()[0]), dim=0)

                for _logit, _idx in zip(pred_logits, pred_ids):
                    next_action_embed = self.sketch_embed.weight[
                        candidate_ids[_idx]].unsqueeze(0)
                    # A distinct name, so the hypothesis' own state is not
                    # shadowed by its children's state.
                    next_state = RnnStatelet(next_hidden, next_memory,
                                             next_action_embed, None, None,
                                             None)

                    prod = candidates[_idx]
                    _history = history + [prod]
                    _stack = stack[:]
                    for ac in reversed(self._get_right_side_parts(prod)):
                        if ac in action_dict:
                            _stack.append(ac)
                    if seq_score is None:
                        _score = _logit
                    else:
                        _score = _logit + seq_score

                    next_paths.append((_stack, _history, next_state, _score))

            incomplete = []
            for stack, history, rnn_state, seq_score in next_paths:
                if stack:
                    incomplete.append((stack, history, rnn_state, seq_score))
                elif world.action_sequence_to_logical_form(history) != "#PH#":
                    completed.append((history, seq_score))

            if len(completed) > beam_size:
                break

            if len(incomplete) > beam_size:
                incomplete = sorted(incomplete, key=lambda x: -x[3])
                incomplete = incomplete[:beam_size]

        # Best first on every exit path, not only after the early break.
        completed = sorted(completed, key=lambda x: -x[1])
        return completed[:beam_size]
예제 #11
0
    def compute_entropy(self, context: WikiSQLContext, sketch2program: Dict,
                        keywords: List) -> Dict:
        """Measure how much consistent-program mass keyword-matching programs get.

        For each sketch in ``sketch2program``:
        - every executable instantiation from slot filling whose logical form
          is listed for that sketch is scored;
        - the score is the sketch score alone for a single path, else the
          slot score plus the sketch score;
        - all such scores form one pool, and those also passing
          ``self.filter_by_keywords`` form a second pool.

        Returns:
            A ``defaultdict(int)`` with ``triggered`` (whether the keyword
            pool is non-empty). When triggered, it also holds:
            - ``entropy``: log_sum_exp(all) - log_sum_exp(keyword);
            - ``is_correct``: the best keyword score is not below the best
              overall score;
            - ``proportion``: exp(best keyword score) / sum(exp(all scores)).
        """
        world = WikiSQLLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, token_in_table_feat)

        consist_prob_logs = []  # every consistent, executable program
        keyword_prob_logs = []  # the subset that also passes the keyword filter

        def _record(path, seq_score, gold_lfs):
            """Add ``seq_score`` to the pools if ``path`` is a listed program."""
            path_lf = world.action_sequence_to_logical_form(path)
            if path_lf in gold_lfs:
                if self.filter_by_keywords(path_lf, keywords):
                    keyword_prob_logs.append(seq_score)
                consist_prob_logs.append(seq_score)

        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        sketch_lf2actions = self.sketch_lf2actions(world)
        for sketch_lf in sketch2program:
            sketch_actions = sketch_lf2actions[sketch_lf]
            sketch_log_score = self.seq2seq(world, token_reps, token_encodes,
                                            sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                    last_state,
                                                    candidate_rep_dic,
                                                    sketch_actions)

            # only one path: the sketch score is the program score
            if len(_paths) == 1:
                if self.filter_program_by_execution(world, _paths[0]):
                    _record(_paths[0], sketch_log_score,
                            sketch2program[sketch_lf])
                continue

            # multiple paths: add each instantiation's slot score
            for _path, _score in zip(_paths, _log_scores):
                if self.filter_program_by_execution(world, _path):
                    _record(_path, _score + sketch_log_score,
                            sketch2program[sketch_lf])

        if keyword_prob_logs:
            total_log = log_sum_exp(consist_prob_logs)
            best_keyword_log = max(keyword_prob_logs)
            ret_dic["entropy"] = total_log - log_sum_exp(keyword_prob_logs)
            ret_dic["triggered"] = True
            ret_dic["is_correct"] = not max(consist_prob_logs) > best_keyword_log
            # exp(best - logsumexp(all)) equals exp(best) / sum(exp(all)), but
            # it does not underflow to 0/0 for very negative log-scores.
            # Assumes log_sum_exp(xs) == log(sum(exp(xs))), as its use for
            # "entropy" implies.
            ret_dic["proportion"] = torch.exp(best_keyword_log - total_log)
        else:
            ret_dic["triggered"] = False
        return ret_dic
예제 #12
0
    def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
        """Pick the best executable program over beam-decoded sketches.

        Sketches come from ``self.seq2seq.beam_decode`` with beam size
        ``self.EVAL_NUM_SKETCH_BOUND``. Each sketch is instantiated by
        ``self.slot_filling``, and every instantiation that executes is
        scored: the slot score plus the sketch score, or the sketch score
        alone when slot filling returns a single path. For
        ``ConcreteProgrammer`` instances, the pending alignment
        log-probability is added to the sketch score and then cleared.

        Returns:
            A ``defaultdict(int)`` holding:
            - the best program and sketch logical forms and action sequences;
            - ``best_score``: exp of the best log score;
            - ``is_multi_col``;
            - ``sketch_triggered`` / ``lf_triggered``, checked against
              ``sketch2program``.

        Raises:
            AssertionError: if no candidate program executes.
        """
        world = WikiSQLLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        question_in_table = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, question_in_table)

        beam = self.seq2seq.beam_decode(world, token_reps, token_encodes,
                                        self.EVAL_NUM_SKETCH_BOUND)

        max_score = None
        # (sketch_lf, sketch_actions, program_lf, program_actions)
        best = (None, None, None, None)
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_actions, sketch_log_score in beam:
            sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
            paths, path_log_scores = self.slot_filling(world, token_encodes,
                                                       last_state,
                                                       candidate_rep_dic,
                                                       sketch_actions)

            if self.__class__.__name__ == "ConcreteProgrammer":
                pending_align = self._ConcreteProgrammer__cur_align_prob_log
                assert pending_align is not None
                sketch_log_score = sketch_log_score + pending_align.squeeze()
                self._ConcreteProgrammer__cur_align_prob_log = None

            if len(paths) == 1:
                # a single instantiation: rank it by the sketch score alone
                only_path = paths[0]
                if self.filter_program_by_execution(world, only_path):
                    program_lf = world.action_sequence_to_logical_form(
                        only_path)
                    if max_score is None or sketch_log_score > max_score:
                        max_score = sketch_log_score
                        best = (sketch_lf, sketch_actions, program_lf,
                                only_path)
                continue

            for path, path_log_score in zip(paths, path_log_scores):
                if not self.filter_program_by_execution(world, path):
                    continue
                assert path_log_score is not None
                program_lf = world.action_sequence_to_logical_form(path)
                total = path_log_score + sketch_log_score
                if max_score is None or total > max_score:
                    max_score = total
                    best = (sketch_lf, sketch_actions, program_lf, path)

        assert max_score is not None
        (best_sketch_lf, best_sketch_actions, best_program_lf,
         best_program_actions) = best

        ret_dic["best_program_lf"] = best_program_lf
        ret_dic["best_program_actions"] = best_program_actions
        ret_dic["best_sketch_lf"] = best_sketch_lf
        ret_dic["best_sketch_actions"] = best_sketch_actions
        ret_dic["best_score"] = torch.exp(max_score)
        ret_dic["is_multi_col"] = check_multi_col(world, best_sketch_actions,
                                                  best_program_actions)

        sketch_hit = best_sketch_lf in sketch2program
        ret_dic["sketch_triggered"] = sketch_hit
        ret_dic["lf_triggered"] = (
            sketch_hit and best_program_lf in sketch2program[best_sketch_lf])

        return ret_dic
예제 #13
0
    def slot_filling(self, world: WikiSQLLanguage, token_encodes: torch.Tensor,
                     token_state: torch.Tensor, candidate_rep_dic: Dict,
                     sketch_actions: List):
        """Enumerate the programs obtained by filling the sketch's slots.

        1) collect scores for each individual slot 2) find all the paths
        recursively.
        - Non-slot actions are copied through unchanged.
        - A slot whose type has no candidates in ``candidate_rep_dic`` makes
          the sketch inapplicable, so no path results.
        - A slot with one candidate is filled without changing the score.
        - Otherwise every candidate is tried; only the top
          ``self.CANDIDATE_ACTION_NUM_BOUND`` are tried when there are more.
        - A candidate that is a list contributes all of its actions.

        Returns:
            ``(possible_paths, path_scores)``: parallel lists of filled action
            lists and their summed candidate scores. A score is ``None`` when
            no slot along the path had a real choice.
        """
        slot_dict = world.get_slot_dict(sketch_actions)
        sketch_encodes, _ = self.encode_sketch(sketch_actions, token_state)
        candidate_score_dic = self.collect_candidate_scores(
            world, token_encodes, candidate_rep_dic, sketch_encodes, slot_dict)

        possible_paths = []
        path_scores = []

        def _extend(prefix, action):
            """Return a copy of ``prefix`` extended with ``action``."""
            new_prefix = prefix[:]
            if isinstance(action, list):
                new_prefix += action
            else:
                new_prefix.append(action)
            return new_prefix

        def recur_compute(prefix, score, i):
            if i == len(sketch_actions):
                possible_paths.append(prefix)
                path_scores.append(score)
                return

            if i not in slot_dict:
                # fixed sketch action: copy it through
                recur_compute(prefix + [sketch_actions[i]], score, i + 1)
                return

            slot_type = slot_dict[i]
            if slot_type not in candidate_rep_dic:
                return  # this sketch does not apply here
            _, candidate_actions = candidate_rep_dic[slot_type]

            if len(candidate_actions) == 1:
                # a forced choice leaves the score unchanged
                recur_compute(_extend(prefix, candidate_actions[0]), score,
                              i + 1)
                return

            if len(candidate_actions) > self.CANDIDATE_ACTION_NUM_BOUND:
                _, top_k = torch.topk(candidate_score_dic[i],
                                      self.CANDIDATE_ACTION_NUM_BOUND,
                                      dim=0)
                ac_idxs = top_k.cpu().numpy()
            else:
                ac_idxs = range(len(candidate_actions))

            for ac_ind in ac_idxs:
                ac_score = candidate_score_dic[i][ac_ind]
                # Compare with None instead of using tensor truthiness, which
                # forces a device sync and fails for multi-element tensors.
                if score is None:
                    new_score = ac_score
                else:
                    new_score = score + ac_score
                recur_compute(_extend(prefix, candidate_actions[ac_ind]),
                              new_score, i + 1)

        recur_compute([], None, 0)
        return possible_paths, path_scores
예제 #14
0
 def sketch_lf2actions(self, world: WikiSQLLanguage):
     """Map each cached sketch's logical form to its action sequence.

     Sequences in ``self.sketch_actions_cache`` that share a logical form
     keep the last one seen.
     """
     return {
         world.action_sequence_to_logical_form(seq): seq
         for seq in self.sketch_actions_cache
     }