예제 #1
0
def get_sketch_prod(examples: List, table_dict: Dict) -> List:
    """
    Build the sketch-production vocabulary from the first usable example.

    Examples are scanned in order.  For each one a ``WikiSQLContext`` and a
    ``WikiSQLLanguage`` world are built; the example is usable when its
    context has at least two entries in ``column_types`` and non-empty
    ``_num2id`` and ``_entity2id``.  The sketch productions of the first
    usable world are collected and indexed, and the function returns
    immediately.

    Args:
        examples: iterable of example dicts; each must carry its table id
            under the ``"context"`` key.
        table_dict: mapping from table id to the processed table.

    Returns:
        ``(id2prod, prod2id)`` where ``id2prod`` is the sorted list of
        distinct sketch productions and ``prod2id`` maps each production to
        its index in that list.  ``None`` if no example is usable.
        Note that the ``List`` return annotation is looser than the actual
        return value (a 2-tuple or ``None``); it is kept so the signature
        stays unchanged.
    """
    for example in examples:
        processed_table = table_dict[example["context"]]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)

        grammar_is_complete = (len(context.column_types) >= 2
                               and len(context._num2id) > 0
                               and len(context._entity2id) > 0)
        if not grammar_is_complete:
            continue

        actions = world.get_nonterminal_productions()
        sketch_actions = world._get_sketch_productions(actions)

        # Index every distinct production.  Sorting makes the ids
        # reproducible: iterating a set of strings depends on string
        # hashing, so a plain list(set) could assign different ids in
        # different processes.
        id2prod = sorted(set().union(*sketch_actions.values()))
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        return id2prod, prod2id

    # No example exposed a complete grammar.
    return None
예제 #2
0
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """
    Index the sketch productions and locate the slots of the given sketches.

    Used for the pruned version.  Examples are scanned in order; the first
    one whose context has at least two entries in ``column_types`` and
    non-empty ``_num2id`` and ``_entity2id`` provides the world used for
    everything below, and the function returns immediately afterwards.

    Args:
        examples: iterable of example dicts; each must carry its table id
            under the ``"context"`` key.
        table_dict: mapping from table id to the processed table.
        sketch_list: sketch logical forms to keep, in the desired order.
        sketch_action_list: action sequences; each one is converted to its
            logical form with ``world.action_sequence_to_logical_form`` so
            the entries of ``sketch_list`` can be mapped back to actions.

    Returns:
        ``(id2prod, prod2id, sketch_action_seqs, slot_dict)`` where

        * ``id2prod`` is the sorted list of distinct sketch productions,
        * ``prod2id`` maps each production to its index in ``id2prod``,
        * ``sketch_action_seqs`` holds one action tuple per entry of
          ``sketch_list`` (same order),
        * ``slot_dict`` maps an action tuple to ``{action index: lhs}`` for
          every action whose right-hand side is the placeholder ``"#PH#"``
          and whose left-hand side is one of the slot types.  Sketches
          without any slot have no entry.

        ``None`` if no example is usable.

    Raises:
        KeyError: if an entry of ``sketch_list`` is not the logical form of
            any sequence in ``sketch_action_list``.
    """
    # Left-hand sides that count as a slot when expanded to "#PH#".
    slot_types = frozenset([
        "Column", "StringColumn", "NumberColumn", "ComparableColumn",
        "DateColumn", "str", "Number", "Date", "List[Row]"
    ])

    for example in examples:
        processed_table = table_dict[example["context"]]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)

        grammar_is_complete = (len(context.column_types) >= 2
                               and len(context._num2id) > 0
                               and len(context._entity2id) > 0)
        if not grammar_is_complete:
            continue

        actions = world.get_nonterminal_productions()
        productions_by_lhs = world._get_sketch_productions(actions)

        # Index every distinct production.  Sorting makes the ids
        # reproducible: iterating a set of strings depends on string
        # hashing, so a plain list(set) could assign different ids in
        # different processes.
        id2prod = sorted(set().union(*productions_by_lhs.values()))
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        # Logical form -> action sequence (a later duplicate overrides an
        # earlier one).
        sketch_lf2actions = {
            world.action_sequence_to_logical_form(action_seq): action_seq
            for action_seq in sketch_action_list
        }

        # Walk the sketches in the order given and record their slots.
        slot_dict = defaultdict(dict)
        sketch_action_seqs = []
        for sketch in sketch_list:
            action_seq = tuple(sketch_lf2actions[sketch])
            sketch_action_seqs.append(action_seq)

            for action_ind, action in enumerate(action_seq):
                assert action in prod2id, \
                    "sketch action missing from the production index"
                lhs, rhs = action.split(" -> ")
                if rhs == "#PH#" and lhs in slot_types:
                    slot_dict[action_seq][action_ind] = lhs

        return id2prod, prod2id, sketch_action_seqs, slot_dict

    # No example exposed a complete grammar.
    return None
예제 #3
0
    def decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
               token_encodes: torch.Tensor):
        """
        Greedily decode one sketch action sequence.

        Starting from ``START_SYMBOL``, the symbol on top of the stack is
        expanded with the highest-scoring sketch production available for
        it, and the right-side parts of that production are pushed back on
        the stack.  Decoding stops when the stack is empty or after
        ``self._max_decoding_steps`` iterations.  Symbols that have no
        sketch production are popped and skipped; such a skip still uses up
        one iteration.

        Args:
            world: language object supplying the sketch productions.
            token_rep: passed to ``self._get_initial_state`` to build the
                initial decoder state.
            token_encodes: 2-D tensor attended over at every step
                (presumably sentence length x hidden size -- confirm).

        Returns:
            Tuple of the chosen productions, in the order they were applied.
        """
        productions_by_lhs = world._get_sketch_productions(
            world.get_nonterminal_productions())
        state = self._get_initial_state(token_rep)

        pending = [START_SYMBOL]
        chosen = []
        for _ in range(self._max_decoding_steps):
            if not pending:
                break

            symbol = pending.pop()
            if symbol not in productions_by_lhs:
                continue
            options = productions_by_lhs[symbol]
            option_ids = [self.sketch_prod2id[option] for option in options]

            # One decoder step, fed with the previously chosen action.
            prev_hidden, prev_memory = state.hidden_state, state.memory_cell
            hidden, memory = self.decoder_lstm(
                state.previous_action_embedding, (prev_hidden, prev_memory))

            # Dot-product attention over the token encodings; the softmax
            # runs over dim 0 of token_encodes.
            attn_logits = torch.mm(token_encodes, hidden.transpose(0, 1))
            attn_weights = F.softmax(attn_logits, dim=0)
            attended = torch.mm(attn_weights.transpose(0, 1), token_encodes)

            # Score all productions, then keep only the ones that are legal
            # expansions of the current symbol.
            all_scores = self.score_action(
                torch.cat([hidden, attended], 1)).squeeze()
            option_scores = torch.stack(
                [all_scores[prod_id] for prod_id in option_ids], 0)
            option_probs = F.softmax(option_scores, dim=0)

            _, best = torch.max(option_probs, dim=0)
            best = best.cpu().item()

            chosen_embed = self.sketch_embed.weight[
                option_ids[best]].unsqueeze(0)
            state = RnnStatelet(hidden, memory, chosen_embed, None, None,
                                None)

            production = options[best]
            chosen.append(production)
            # Push right-side parts reversed so the leftmost part is
            # expanded first.
            rhs_parts = self._get_right_side_parts(production)
            pending.extend(reversed(rhs_parts))

        return tuple(chosen)
예제 #4
0
    def forward(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, sketch_actions: List):
        """
        Score a given sequence of sketch actions.

        The decoder is run with teacher forcing: at every step the given
        production is fed as the previous action for the next step, and its
        log-probability among the productions sharing its left-hand side is
        recorded.

        Args:
            world: language object supplying the sketch productions.
            token_rep: passed to ``self._get_initial_state`` to build the
                initial decoder state.
            token_encodes: 2-D tensor attended over at every step
                (presumably sentence length x hidden size -- confirm).
            sketch_actions: productions of the form ``"lhs -> rhs"``.

        Returns:
            Sum of the per-step log-probabilities (the integer ``0`` when
            ``sketch_actions`` is empty).
        """
        productions_by_lhs = world._get_sketch_productions(
            world.get_nonterminal_productions())
        state = self._get_initial_state(token_rep)

        step_log_probs = []
        for production in sketch_actions:
            lhs, _rhs = production.split(" -> ")
            options = productions_by_lhs[lhs]
            option_ids = [self.sketch_prod2id[option] for option in options]

            # One decoder step, fed with the previous gold action.
            prev_hidden, prev_memory = state.hidden_state, state.memory_cell
            hidden, memory = self.decoder_lstm(
                state.previous_action_embedding, (prev_hidden, prev_memory))

            # Dot-product attention over the token encodings; the softmax
            # runs over dim 0 of token_encodes.
            attn_logits = torch.mm(token_encodes, hidden.transpose(0, 1))
            attn_weights = F.softmax(attn_logits, dim=0)
            attended = torch.mm(attn_weights.transpose(0, 1), token_encodes)

            # Normalise only over the productions that can expand ``lhs``.
            all_scores = self.score_action(
                torch.cat([hidden, attended], 1)).squeeze()
            option_scores = torch.stack(
                [all_scores[prod_id] for prod_id in option_ids], 0)
            option_log_probs = F.log_softmax(option_scores, dim=0)

            gold_prod_id = self.sketch_prod2id[production]
            gold_position = option_ids.index(gold_prod_id)
            step_log_probs.append(option_log_probs[gold_position])

            gold_embed = self.sketch_embed.weight[
                self.sketch_prod2id[production]].unsqueeze(0)
            state = RnnStatelet(hidden, memory, gold_embed, None, None, None)

        return sum(step_log_probs)
예제 #5
0
def check_multi_col(world: WikiSQLLanguage, sketch_actions: List,
                    program_actions: List) -> bool:
    """
    Tell whether some column slot uses a name absent from some row slot.

    The slot fillers come from ``gen_slot2action_dic``.  Fillers of slots
    whose sketch action has ``"List[Row]"`` on its left side are treated as
    row fillers, all others as column fillers.

    Returns:
        ``True`` if there is a column filler and a row filler such that the
        first right-side part of the column filler is not a substring of the
        row filler joined with ``"_"``.  ``False`` otherwise, including when
        there are no row fillers or no column fillers at all.
    """
    slot_dic = gen_slot2action_dic(world, world.get_nonterminal_productions(),
                                   sketch_actions, program_actions)

    # Split the fillers by the kind of slot they fill.
    row_fillers, col_fillers = [], []
    for slot_idx in slot_dic:
        fills_row = get_left_side_part(sketch_actions[slot_idx]) == "List[Row]"
        target = row_fillers if fills_row else col_fillers
        target.append(slot_dic[slot_idx])

    if not row_fillers or not col_fillers:
        return False

    for col_filler in col_fillers:
        column = get_right_side_parts(col_filler)[0]
        if any(column not in "_".join(row_filler)
               for row_filler in row_fillers):
            return True
    return False
예제 #6
0
    def beam_decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                    token_encodes: torch.Tensor, beam_size: int):
        """
        Beam-search decode sketch action sequences.

        Every hypothesis keeps a stack of symbols still to expand, the
        productions applied so far, the decoder state and the accumulated
        log-probability.  At each step the top symbol of every live
        hypothesis is expanded with up to ``beam_size`` of its best
        productions.  Hypotheses whose stack becomes empty are moved to the
        completed list, unless their logical form is just ``"#PH#"``.

        The search stops when more than ``beam_size`` hypotheses are
        completed, when no live hypothesis remains, or after
        ``self._max_decoding_steps`` steps.

        Args:
            world: language object supplying the sketch productions and the
                action-sequence to logical-form conversion.
            token_rep: passed to ``self._get_initial_state`` to build the
                initial decoder state.
            token_encodes: 2-D tensor attended over at every step
                (presumably sentence length x hidden size -- confirm).
            beam_size: maximum number of live hypotheses kept per step and
                maximum number of results returned.

        Returns:
            List of at most ``beam_size`` pairs ``(history, score)`` sorted
            by decreasing score, where ``history`` is the list of applied
            productions and ``score`` the summed log-probability.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        initial_rnn_state = self._get_initial_state(token_rep)

        # Each hypothesis: (stack, history, rnn_state, seq_score)
        incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
        completed = []

        for _ in range(self._max_decoding_steps):
            next_paths = []
            for stack, history, rnn_state, seq_score in incomplete:
                cur_non_terminal = stack.pop()
                if cur_non_terminal not in action_dict:
                    # Nothing can expand this symbol: the hypothesis is
                    # dropped.
                    continue
                candidates = action_dict[cur_non_terminal]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

                # One decoder step, shared by all expansions of this
                # hypothesis.
                cur_hidden, cur_memory = (rnn_state.hidden_state,
                                          rnn_state.memory_cell)
                next_hidden, next_memory = self.decoder_lstm(
                    rnn_state.previous_action_embedding,
                    (cur_hidden, cur_memory))

                # Dot-product attention over the token encodings; the
                # softmax runs over dim 0 of token_encodes.
                hidden_tran = next_hidden.transpose(0, 1)
                att_feat_v = torch.mm(token_encodes, hidden_tran)
                att_v = F.softmax(att_feat_v, dim=0)
                att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

                # Score all productions, normalise over the legal ones.
                score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
                score_v = self.score_action(score_feat_v).squeeze()
                filter_score_v = torch.stack(
                    [score_v[_id] for _id in candidate_ids], 0)
                log_prob_v = F.log_softmax(filter_score_v, dim=0)

                num_expansions = min(beam_size, log_prob_v.size()[0])
                pred_log_probs, pred_ids = torch.topk(log_prob_v,
                                                      num_expansions,
                                                      dim=0)

                # Plain ints are used to index the Python lists below.
                for _log_prob, _idx in zip(pred_log_probs, pred_ids.tolist()):
                    next_action_embed = self.sketch_embed.weight[
                        candidate_ids[_idx]].unsqueeze(0)
                    # A fresh name keeps the outer loop variable intact.
                    next_rnn_state = RnnStatelet(next_hidden, next_memory,
                                                 next_action_embed, None,
                                                 None, None)

                    prod = candidates[_idx]
                    _history = history + [prod]
                    # Only symbols that can still be expanded are pushed,
                    # reversed so the leftmost one is expanded first.
                    non_terminals = self._get_right_side_parts(prod)
                    _stack = stack + [
                        ac for ac in reversed(non_terminals)
                        if ac in action_dict
                    ]
                    if seq_score is None:
                        _score = _log_prob
                    else:
                        _score = _log_prob + seq_score

                    next_paths.append(
                        (_stack, _history, next_rnn_state, _score))

            # Split the expanded hypotheses into finished and live ones.
            incomplete = []
            for stack, history, rnn_state, seq_score in next_paths:
                if len(stack) == 0:
                    if world.action_sequence_to_logical_form(
                            history) != "#PH#":
                        completed.append((history, seq_score))
                else:
                    incomplete.append((stack, history, rnn_state, seq_score))

            if len(completed) > beam_size:
                break

            if not incomplete:
                # Every hypothesis is finished or dropped; further steps
                # would do nothing.
                break

            if len(incomplete) > beam_size:
                incomplete = sorted(incomplete, key=lambda x: -x[3])
                incomplete = incomplete[:beam_size]

        # Rank on every exit path so callers always get the best sequences
        # first, not only when the early-stop condition was hit.
        completed = sorted(completed, key=lambda x: -x[1])
        return completed[:beam_size]