def _walk(self, actions: Dict, prod_score_dic: Dict,
          world: WikiTablesVariableFreeWorld) -> List:
    """
    Search the action space without data-selection operations such as lookup.
    The operations used reflect the semantics of a question; the search is
    more abstract and the space is much smaller.

    Returns a list of ``(action_history, score)`` tuples, one per completed
    derivation path, where the score is the sum of the per-production scores
    from ``prod_score_dic``.
    """
    multi_match = world.get_multi_match_mapping()
    # Frontier entries: (stack of nonterminals still to expand, action history,
    # accumulated score). Seeded with one path per valid starting type.
    frontier = [
        ([str(type_)],
         [f"{START_SYMBOL} -> {type_}"],
         prod_score_dic[f"{START_SYMBOL} -> {type_}"])
        for type_ in world.get_valid_starting_types()
    ]
    completed = []
    while frontier:
        expanded = []
        for stack, history, score in frontier:
            # Depth-first: expand the most recently pushed nonterminal.
            nonterminal = stack.pop()
            if nonterminal in multi_match:
                # A multi-match nonterminal may expand via itself or any of
                # its substitutes that actually has productions.
                candidates = []
                for substitute in [nonterminal] + multi_match[nonterminal]:
                    if substitute in actions:
                        candidates.extend(actions[substitute])
            elif nonterminal in actions:
                candidates = actions[nonterminal]
            else:
                # Dead end: no production expands this nonterminal.
                continue
            for action in candidates:
                # Forbid reusing an action on a path, except the two
                # row-select operations which may legitimately repeat.
                if action in history and action not in (
                        "e -> mul_row_select", "r -> one_row_select"):
                    continue
                new_stack = list(stack)
                # Push right-side parts in reverse so the leftmost child is
                # the next one popped (kept depth-first, left-to-right).
                for part in reversed(self._get_right_side_parts(action)):
                    if types.is_nonterminal(part):
                        new_stack.append(part)
                expanded.append((new_stack,
                                 history + [action],
                                 prod_score_dic[action] + score))
        frontier = []
        for stack, path, score in expanded:
            if not stack:
                # Empty stack means the path is complete. A path of length 2
                # is just "start -> <terminal>", which we discard.
                if len(path) > 2:
                    completed.append((path, score))
            elif len(path) < self._max_path_length:
                frontier.append((stack, path, score))
    return completed
def _score_prod(self, token_reps: torch.Tensor, actions: Dict,
                world: WikiTablesVariableFreeWorld) -> Dict:
    """
    Produce a scalar score for each production rule.

    Parameters
    ----------
    token_reps : ``torch.Tensor``, shape ``(sent_len, token_rnn_feat_size)``
        Contextualized representations of the question tokens.
    actions : ``Dict``
        Mapping from nonterminal to the list of production-rule strings that
        expand it. (The previous annotation ``List[str]`` was wrong: the body
        iterates ``actions`` as a dict, and sibling ``_walk`` annotates the
        same argument as ``Dict``.)
    world : ``WikiTablesVariableFreeWorld``
        Supplies the valid starting types for the synthetic start productions.

    Returns
    -------
    ``Dict`` mapping each production-rule string to a scalar tensor score.
    """
    sent_len, _token_rnn_feat_size = token_reps.size()
    # NOTE(review): assert is stripped under -O; kept as-is to preserve the
    # exception type callers may rely on.
    assert self.token_rnn_feat_size == _token_rnn_feat_size

    # Collect every candidate production plus the start productions, then
    # de-duplicate (ordering is irrelevant: the result is keyed by rule).
    action_list = []
    for productions in actions.values():
        action_list += productions
    action_list += [
        f"{START_SYMBOL} -> {type_}"
        for type_ in world.get_valid_starting_types()
    ]
    action_list = list(set(action_list))

    prod_num = len(action_list)
    prod_id_list = [self.prod2id[prod] for prod in action_list]
    prod_id_tensor = torch.LongTensor(prod_id_list)
    prod_id_vec = self.prod_embed(
        prod_id_tensor)  # prod_num * prod_embed_size

    # Attend from each production embedding over the question tokens and
    # build an attended token representation per production.
    token_mat = token_reps.unsqueeze(0).expand(prod_num, sent_len,
                                               self.token_rnn_feat_size)
    att_scores = self.att(prod_id_vec, token_mat)  # prod_num * sent_len
    att_rep_vec = torch.mm(att_scores, token_reps)  # prod_num * token_embed

    # Score each production from [embedding; attended representation].
    feat_vec = torch.cat([prod_id_vec, att_rep_vec], 1)
    hidden_vec = F.relu(self.hidden_func(feat_vec))  # fixed typo 'hiddden_vec'
    score_vec = self.score_func(hidden_vec)  # prod_num * 1
    score_vec = score_vec.squeeze(1)  # prod_num

    return {prod: score_vec[i] for i, prod in enumerate(action_list)}