Example #1
0
    def _walk(self, actions: Dict, prod_score_dic: Dict,
              world: WikiTablesVariableFreeWorld) -> List:
        """
        Enumerate complete derivations by repeatedly expanding non-terminals.

        Every starting type of ``world`` seeds one path. In each round, every
        incomplete path pops the last non-terminal pushed onto its buffer and is
        extended by one production for each applicable action. Within a path the
        expansion is therefore depth-first (leftmost child first), and across paths
        all paths grow by exactly one action per round.

        The original note describes this as a search over an action space without
        data-selection operations such as lookup, so the space is more abstract and
        much smaller. This method does not filter any productions itself. It expands
        whatever ``actions`` contains, so presumably the caller passes an already
        reduced map. TODO confirm against the caller.

        Args:
            actions: maps a non-terminal to the list of production strings that
                expand it.
            prod_score_dic: maps each production string, including the
                ``START_SYMBOL -> type`` productions, to a score. Scores are summed
                along a path with ``+``.
            world: provides ``get_valid_starting_types()`` and
                ``get_multi_match_mapping()``.

        Returns:
            A list of ``(action_history, summed_score)`` tuples, one per completed
            derivation with more than two actions.
        """
        # Each entry: (non-terminal buffer, action history, accumulated score).
        # Each path owns its own buffer list, so popping from it in place is safe.
        incomplete_paths = [([str(type_)], [f"{START_SYMBOL} -> {type_}"],
                             prod_score_dic[f"{START_SYMBOL} -> {type_}"])
                            for type_ in world.get_valid_starting_types()]

        multi_match_substitutions = world.get_multi_match_mapping()
        # Productions allowed to occur more than once in a single path. Every other
        # production may be used at most once per path. This is built once here
        # instead of once per candidate action.
        repeatable_actions = frozenset(
            ["e -> mul_row_select", "r -> one_row_select"])

        completed_paths = []
        while incomplete_paths:
            next_paths = []
            for nonterminal_buffer, history, cur_score in incomplete_paths:
                # Expand the most recently pushed non-terminal of this path.
                nonterminal = nonterminal_buffer.pop()
                if nonterminal in multi_match_substitutions:
                    # The non-terminal may also be rewritten using the productions
                    # of any of its multi-match substitutes.
                    next_actions = []
                    for candidate in ([nonterminal] +
                                      multi_match_substitutions[nonterminal]):
                        if candidate in actions:
                            next_actions.extend(actions[candidate])
                elif nonterminal in actions:
                    next_actions = actions[nonterminal]  # read-only iteration below
                else:
                    # Dead end: nothing expands this non-terminal, so the path is dropped.
                    continue

                for action in next_actions:
                    if action not in repeatable_actions and action in history:
                        continue
                    new_history = history + [action]
                    new_nonterminal_buffer = nonterminal_buffer[:]
                    # The buffer is consumed from the end, so push the right-hand
                    # side in reverse. The leftmost child then lands on top and is
                    # expanded next.
                    for right_side_part in reversed(
                            self._get_right_side_parts(action)):
                        if types.is_nonterminal(right_side_part):
                            new_nonterminal_buffer.append(right_side_part)
                    new_prod_score = prod_score_dic[action] + cur_score
                    next_paths.append(
                        (new_nonterminal_buffer, new_history, new_prod_score))

            incomplete_paths = []
            for nonterminal_buffer, path, score in next_paths:
                if not nonterminal_buffer:
                    # An empty buffer means the derivation is complete. Two-action
                    # paths are discarded; per the original note these are trivial
                    # "start -> string" derivations.
                    if len(path) > 2:
                        completed_paths.append((path, score))
                elif len(path) < self._max_path_length:
                    incomplete_paths.append((nonterminal_buffer, path, score))
                # Otherwise the path reached the length limit unfinished and is dropped.
        return completed_paths
Example #2
0
    def _score_prod(self, token_reps: torch.Tensor,
                    actions: Dict[str, List[str]],
                    world: WikiTablesVariableFreeWorld) -> Dict:
        """
        Produce a scalar score for every production rule that can occur in a search.

        The candidate productions are every production listed in ``actions`` plus
        one ``START_SYMBOL -> type`` production per valid starting type, with
        duplicates removed. Each production id is embedded and attends over the
        token representations. The embedding and its attended context are then
        concatenated and passed through a ReLU hidden layer and a scoring layer.

        Args:
            token_reps: token features of shape ``(sent_len, token_rnn_feat_size)``.
            actions: presumably maps a non-terminal to the production strings that
                expand it. Only the values are used here; the body iterates
                ``actions.items()``/``values()``, so this is a mapping, not a list.
            world: supplies the starting types via ``get_valid_starting_types()``.

        Returns:
            A dict mapping each production string to its score, taken as one
            element of the scoring layer's output. It is a 0-d tensor, assuming
            ``score_func`` yields shape ``(prod_num, 1)``.
        """
        sent_len, feat_size = token_reps.size()
        assert self.token_rnn_feat_size == feat_size, (
            f"expected token feature size {self.token_rnn_feat_size}, "
            f"got {feat_size}")

        unique_actions = {prod for prods in actions.values() for prod in prods}
        unique_actions.update(f"{START_SYMBOL} -> {type_}"
                              for type_ in world.get_valid_starting_types())
        # Sort so the batch row order is the same on every run. The iteration order
        # of a set of strings depends on the interpreter's hash seed.
        action_list = sorted(unique_actions)
        prod_num = len(action_list)

        # Create the ids on the inputs' device. torch.LongTensor always allocates on
        # CPU, which breaks the embedding lookup when the model runs on GPU.
        prod_id_tensor = torch.tensor(
            [self.prod2id[prod] for prod in action_list],
            dtype=torch.long, device=token_reps.device)
        prod_id_vec = self.prod_embed(prod_id_tensor)  # prod_num * prod_embed_size
        token_mat = token_reps.unsqueeze(0).expand(prod_num, sent_len,
                                                   self.token_rnn_feat_size)

        att_scores = self.att(prod_id_vec, token_mat)  # prod_num * sent_len
        # Attention-weighted sum of token features: prod_num * token_rnn_feat_size
        att_rep_vec = torch.mm(att_scores, token_reps)
        feat_vec = torch.cat([prod_id_vec, att_rep_vec], 1)
        hidden_vec = F.relu(self.hidden_func(feat_vec))
        score_vec = self.score_func(hidden_vec).squeeze(1)  # prod_num * 1 -> prod_num

        return {prod: score_vec[i] for i, prod in enumerate(action_list)}