def get_sketch_prod(examples: List, table_dict: Dict) -> List:
    """
    Find an example whose table covers all three column types (and has numbers,
    entities and dates) so that the grammar is complete; index its sketch
    productions and return (id2prod, prod2id).
    """
    for example in examples:
        table_id = example["context"]
        table_lines = table_dict[table_id]["raw_lines"]
        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
        context.take_corenlp_entities(example["entities"])
        world = WikiTableAbstractLanguage(context)
        if len(context.column_types) >= 3 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0 and len(context._date2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)

            # index all the possible actions
            action_set = set()
            for k, v in sketch_actions.items():
                action_set = action_set.union(set(v))
            id2prod = list(action_set)
            prod2id = {v: k for k, v in enumerate(id2prod)}
            return id2prod, prod2id
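
# The productions indexed above are plain strings of the form "lhs -> rhs", and
# id2prod / prod2id are inverse mappings over them. A minimal, hypothetical
# illustration (the real production strings come from world._get_sketch_productions):
def _demo_prod_indexing():
    id2prod = ["List[Row] -> #PH#", "StringColumn -> #PH#", "@start@ -> Number"]
    prod2id = {prod: idx for idx, prod in enumerate(id2prod)}
    assert prod2id["StringColumn -> #PH#"] == 1
    assert id2prod[prod2id["List[Row] -> #PH#"]] == "List[Row] -> #PH#"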
def get_sketch_prod_and_slot(examples: List, table_dict: Dict, sketch_list: List,
                             sketch_action_list: List):
    """
    If a table contains all three types of columns, the grammar is complete.
    Also return the sketch action sequences and their slot positions.
    """
    for example in examples:
        table_id = example["context"]
        table_lines = table_dict[table_id]["raw_lines"]
        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
        context.take_corenlp_entities(example["entities"])
        # anonymize numbers and dates
        # context.annoymized_tokens = example["tmp_tokens"]
        world = WikiTableAbstractLanguage(context)
        if len(context.column_types) >= 3 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0 and len(context._date2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)

            # index all the possible actions
            action_set = set()
            for k, v in sketch_actions.items():
                action_set = action_set.union(set(v))
            id2prod = list(action_set)
            prod2id = {v: k for k, v in enumerate(id2prod)}
            # return id2prod, prod2id

            # lf to actions
            sketch_lf2actions = dict()
            for sketch_actions in sketch_action_list:
                lf = world.action_sequence_to_logical_form(sketch_actions)
                sketch_lf2actions[lf] = sketch_actions

            # sort by length in decreasing order
            slot_dict = defaultdict(dict)
            sketch_action_seqs = []
            for sketch in sketch_list:
                sketch_actions = sketch_lf2actions[sketch]
                sketch_actions = tuple(sketch_actions)
                sketch_action_seqs.append(sketch_actions)

                for action_ind, action in enumerate(sketch_actions):
                    assert action in prod2id
                    lhs, rhs = action.split(" -> ")
                    if lhs in ["Column", "StringColumn", "NumberColumn", "ComparableColumn",
                               "DateColumn", "str", "Number", "Date"] and rhs == "#PH#":
                        slot_dict[sketch_actions][action_ind] = lhs
                    elif lhs == "List[Row]" and rhs == "#PH#":
                        slot_dict[sketch_actions][action_ind] = lhs

            return id2prod, prod2id, sketch_action_seqs, slot_dict
def decode(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
           token_encodes: torch.Tensor):
    """
    Greedily decode the most probable sketch action sequence
    for the encoded question.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    stack = [START_SYMBOL]
    history = []
    rnn_state = initial_rnn_state
    for i in range(self._max_decoding_steps):
        if len(stack) == 0:
            break
        cur_non_terminal = stack.pop()
        if cur_non_terminal not in action_dict:
            continue

        candidates = action_dict[cur_non_terminal]
        candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

        cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
        next_hidden, next_memory = self.decoder_lstm(
            rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

        # attend over the question tokens
        hidden_tran = next_hidden.transpose(0, 1)
        att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
        att_v = F.softmax(att_feat_v, dim=0)
        att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

        score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
        score_v = self.score_action(score_feat_v).squeeze()
        filter_score_v_list = [score_v[_id] for _id in candidate_ids]
        filter_score_v = torch.stack(filter_score_v_list, 0)
        prob_v = F.softmax(filter_score_v, dim=0)
        _, pred_id = torch.max(prob_v, dim=0)
        pred_id = pred_id.cpu().item()

        next_action_embed = self.sketch_embed.weight[
            candidate_ids[pred_id]].unsqueeze(0)
        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

        prod = candidates[pred_id]
        history.append(prod)
        non_terminals = self._get_right_side_parts(prod)
        stack += list(reversed(non_terminals))

    return tuple(history)
def forward(self, context: TableQuestionContext, sketch2program: Dict) -> torch.Tensor:
    world = WikiTableAbstractLanguage(context)

    # encode question
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)

    sketch_lf2actions = self.sketch_lf2actions(world)

    consistent_scores = []
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    for sketch_lf in sketch2program:
        if len(sketch2program[sketch_lf]) > self.CONSISTENT_INST_NUM_BOUND:
            continue
        sketch_actions = sketch_lf2actions[sketch_lf]
        seq_log_likeli = self.seq2seq(world, token_reps, token_encodes, sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                                                candidate_rep_dic, sketch_actions)

        # only one path
        if len(_paths) == 1:
            consistent_scores.append(seq_log_likeli)
            continue

        _gold_scores = []
        for _path, _score in zip(_paths, _log_scores):
            assert _score is not None
            _path_lf = world.action_sequence_to_logical_form(_path)
            if _path_lf in sketch2program[sketch_lf]:
                _gold_scores.append(_score)

        # aggregate consistent instantiations
        if len(_gold_scores) > 0:
            _score = seq_log_likeli + log_sum_exp(_gold_scores)
            if not torch.isnan(_score):
                consistent_scores.append(_score)
            else:
                logger.warning("NaN loss found!")

    if len(consistent_scores) > 0:
        return -1 * log_sum_exp(consistent_scores)
    else:
        return None
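
# log_sum_exp is used above but not defined in this section; a minimal sketch of the
# assumed helper, aggregating a list of 0-dim log-score tensors with torch.logsumexp
# (the repository's own implementation may differ).
def log_sum_exp(log_scores: List[torch.Tensor]) -> torch.Tensor:
    """Numerically stable log(sum(exp(s_i))) over a list of scalar log scores."""
    return torch.logsumexp(torch.stack(log_scores, dim=0), dim=0)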
def compute_entropy(self, context: TableQuestionContext, sketch2program: Dict) -> Dict:
    world = WikiTableAbstractLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)

    entropy = []
    sketch_lf2actions = self.sketch_lf2actions(world)
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    for sketch_lf in sketch2program:
        sketch_actions = sketch_lf2actions[sketch_lf]
        sketch_log_score = self.seq2seq(world, token_reps, token_encodes, sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                                                candidate_rep_dic, sketch_actions)

        # only one path
        if len(_paths) == 1:
            if not self.filter_program_by_execution(world, _paths[0]):
                continue
            _path_lf = world.action_sequence_to_logical_form(_paths[0])
            _seq_score = sketch_log_score
            if _path_lf in sketch2program[sketch_lf]:
                entropy.append(-1 * _seq_score * torch.exp(_seq_score))
            continue

        # multiple paths
        for _path, _score in zip(_paths, _log_scores):
            if not self.filter_program_by_execution(world, _path):
                continue
            assert _score is not None
            _path_lf = world.action_sequence_to_logical_form(_path)
            _seq_score = _score + sketch_log_score
            if _path_lf in sketch2program[sketch_lf]:
                entropy.append(-1 * _seq_score * torch.exp(_seq_score))

    if len(entropy) > 0:
        ret_dic["entropy"] = sum(entropy).cpu().item()
    return ret_dic
def filter_program_by_execution(self, world: WikiTableAbstractLanguage, actions: List) -> bool:
    """Keep only programs that execute successfully and yield a non-empty result."""
    try:
        ret = world.execute_action_sequence(actions)
        return bool(ret)
    except Exception:
        return False
def forward(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
            token_encodes: torch.Tensor, sketch_actions: List):
    """
    Input: a sequence of sketch actions
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    seq_likeli = []
    rnn_state = initial_rnn_state
    for i, prod in enumerate(sketch_actions):
        left_side, _ = prod.split(" -> ")
        candidates = action_dict[left_side]
        candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

        cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
        next_hidden, next_memory = self.decoder_lstm(
            rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

        hidden_tran = next_hidden.transpose(0, 1)
        att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
        att_v = F.softmax(att_feat_v, dim=0)
        att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

        score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
        score_v = self.score_action(score_feat_v).squeeze()
        filter_score_v_list = [score_v[_id] for _id in candidate_ids]
        filter_score_v = torch.stack(filter_score_v_list, 0)
        log_likeli = F.log_softmax(filter_score_v, dim=0)

        gold_id = candidate_ids.index(self.sketch_prod2id[prod])
        seq_likeli.append(log_likeli[gold_id])

        next_action_embed = self.sketch_embed.weight[
            self.sketch_prod2id[prod]].unsqueeze(0)
        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

    return sum(seq_likeli)
def check_multi_col(world: WikiTableAbstractLanguage, sketch_actions: List,
                    program_actions: List) -> bool:
    prod_dic = world.get_nonterminal_productions()
    slot_dic = gen_slot2action_dic(world, prod_dic, sketch_actions, program_actions)

    row_slot_acs = []
    col_slot_acs = []
    for idx in slot_dic:
        slot_type = get_left_side_part(sketch_actions[idx])
        if slot_type == "List[Row]":
            row_slot_acs.append(slot_dic[idx])
        else:
            col_slot_acs.append(slot_dic[idx])

    if len(row_slot_acs) == 0 or len(col_slot_acs) == 0:
        return False

    # the program uses multiple columns if a column slot mentions a column
    # that never appears in any of the row-selection slots
    for col_slot_ac in col_slot_acs:
        col_name = get_right_side_parts(col_slot_ac)[0]
        for row_slot_ac in row_slot_acs:
            if col_name not in "_".join(row_slot_ac):
                return True
    return False
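
# get_left_side_part and get_right_side_parts are helpers defined elsewhere; a minimal
# sketch consistent with the "lhs -> rhs" string convention used throughout this file
# (the actual implementations may handle composite right-hand sides differently).
def get_left_side_part(production: str) -> str:
    """Left-hand side of a production such as 'List[Row] -> #PH#'."""
    return production.split(" -> ")[0]


def get_right_side_parts(production: str) -> List[str]:
    """Right-hand-side symbols; bracketed right sides such as '[<Row:Row>, Row]'
    are split into their components, bare right sides are returned as one item."""
    _, right_side = production.split(" -> ")
    if right_side.startswith("[") and right_side.endswith("]"):
        return [part.strip() for part in right_side[1:-1].split(",")]
    return [right_side]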
def evaluate(self, context: TableQuestionContext, sketch2program: Dict) -> Dict:
    world = WikiTableAbstractLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)
    sketch_actions_and_scores = self.seq2seq.beam_decode(
        world, token_reps, token_encodes, self.EVAL_NUM_SKETCH_BOUND)

    max_score = None
    best_sketch_actions = None
    best_sketch_lf = None
    best_program_actions = None
    best_program_lf = None
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    for sketch_actions, sketch_log_score in sketch_actions_and_scores:
        sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes, last_state,
                                                candidate_rep_dic, sketch_actions)
        # logger.info(f"{sketch_lf}, score: {torch.exp(sketch_log_score)}")

        if self.__class__.__name__ == "ConcreteProgrammer":
            assert self._ConcreteProgrammer__cur_align_prob_log is not None
            align_prob_log = self._ConcreteProgrammer__cur_align_prob_log.squeeze()
            # print(f"Align matrix prob: {torch.exp(align_prob_log)}")
            sketch_log_score = sketch_log_score + align_prob_log
            self._ConcreteProgrammer__cur_align_prob_log = None

        # only one path
        if len(_paths) == 1:
            if not self.filter_program_by_execution(world, _paths[0]):
                continue
            _path_lf = world.action_sequence_to_logical_form(_paths[0])
            _seq_score = sketch_log_score
            if max_score is None or _seq_score > max_score:
                max_score = _seq_score
                best_sketch_lf = sketch_lf
                best_sketch_actions = sketch_actions
                best_program_lf = _path_lf
                best_program_actions = _paths[0]
            continue

        # multiple paths
        for _path, _score in zip(_paths, _log_scores):
            if not self.filter_program_by_execution(world, _path):
                continue
            assert _score is not None
            _path_lf = world.action_sequence_to_logical_form(_path)
            _seq_score = _score + sketch_log_score
            if max_score is None or _seq_score > max_score:
                max_score = _seq_score
                best_sketch_lf = sketch_lf
                best_sketch_actions = sketch_actions
                best_program_lf = _path_lf
                best_program_actions = _path

    assert max_score is not None
    ret_dic["best_program_lf"] = best_program_lf
    ret_dic["best_program_actions"] = best_program_actions
    ret_dic["best_sketch_lf"] = best_sketch_lf
    ret_dic["best_sketch_actions"] = best_sketch_actions
    ret_dic["best_score"] = torch.exp(max_score)
    ret_dic["is_multi_col"] = check_multi_col(world, best_sketch_actions,
                                              best_program_actions)

    if best_sketch_lf in sketch2program:
        ret_dic["sketch_triggered"] = True
        if best_program_lf in sketch2program[best_sketch_lf]:
            ret_dic["lf_triggered"] = True
        else:
            ret_dic["lf_triggered"] = False
    else:
        ret_dic["sketch_triggered"] = False
        ret_dic["lf_triggered"] = False

    return ret_dic
def slot_filling(self, world: WikiTableAbstractLanguage, token_encodes: torch.Tensor,
                 token_state: torch.Tensor, candidate_rep_dic: Dict, sketch_actions: List):
    """
    1) collect scores for each individual slot
    2) find all the paths recursively
    """
    slot_dict = world.get_slot_dict(sketch_actions)
    sketch_encodes, sketch_rep = self.encode_sketch(sketch_actions, token_state)
    candidate_score_dic = self.collect_candidate_scores(
        world, token_encodes, candidate_rep_dic, sketch_encodes, slot_dict)

    possible_paths = []
    path_scores = []

    def recur_compute(prefix, score, i):
        if i == len(sketch_actions):
            possible_paths.append(prefix)
            path_scores.append(score)
            return
        if i in slot_dict:
            _slot_type = slot_dict[i]
            if _slot_type not in candidate_rep_dic:
                return  # this sketch does not apply here
            slot_rep = sketch_encodes[i]
            candidate_v, candidate_actions = candidate_rep_dic[_slot_type]

            if len(candidate_actions) == 1:
                action = candidate_actions[0]
                new_prefix = prefix[:]
                if isinstance(action, list):
                    new_prefix += action
                else:
                    new_prefix.append(action)
                recur_compute(new_prefix, score, i + 1)
                return

            if len(candidate_actions) > self.CANDIDATE_ACTION_NUM_BOUND:
                _, top_k = torch.topk(candidate_score_dic[i],
                                      self.CANDIDATE_ACTION_NUM_BOUND, dim=0)
                ac_idxs = top_k.cpu().numpy()
            else:
                ac_idxs = range(len(candidate_actions))

            for ac_ind in ac_idxs:
                action = candidate_actions[ac_ind]
                new_prefix = prefix[:]
                if score is not None:
                    new_score = score + candidate_score_dic[i][ac_ind]
                else:
                    new_score = candidate_score_dic[i][ac_ind]
                if isinstance(action, list):
                    new_prefix += action
                else:
                    new_prefix.append(action)
                recur_compute(new_prefix, new_score, i + 1)
        else:
            new_prefix = prefix[:]
            new_prefix.append(sketch_actions[i])
            recur_compute(new_prefix, score, i + 1)

    recur_compute([], None, 0)
    return possible_paths, path_scores
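
# To make the recursion in slot_filling concrete: each placeholder position branches
# over that slot's candidates, so the enumerated paths form the Cartesian product of
# the per-slot candidates (truncated to the top CANDIDATE_ACTION_NUM_BOUND per slot).
# The slot positions and candidate actions below are hypothetical, for illustration only.
def _demo_slot_enumeration():
    from itertools import product
    slot_candidates = {
        2: ["StringColumn -> string_column:name", "StringColumn -> string_column:year"],
        5: ["List[Row] -> all_rows"],
    }
    paths = list(product(*slot_candidates.values()))
    assert len(paths) == 2  # 2 candidates for slot 2 x 1 candidate for slot 5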
def sketch_lf2actions(self, world: WikiTableAbstractLanguage):
    lf2actions = dict()
    for actions in self.sketch_actions_cache:
        lf = world.action_sequence_to_logical_form(actions)
        lf2actions[lf] = actions
    return lf2actions
def beam_decode(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, beam_size: int):
    """
    Beam-search decode the top-k most probable sketch action sequences.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]  # stack, history, rnn_state, score
    completed = []
    for i in range(self._max_decoding_steps):
        next_paths = []
        for stack, history, rnn_state, seq_score in incomplete:
            cur_non_terminal = stack.pop()
            if cur_non_terminal not in action_dict:
                continue
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze()
            filter_score_v_list = [score_v[_id] for _id in candidate_ids]
            filter_score_v = torch.stack(filter_score_v_list, 0)
            prob_v = F.log_softmax(filter_score_v, dim=0)
            pred_logits, pred_ids = torch.topk(prob_v,
                                               min(beam_size, prob_v.size()[0]), dim=0)

            for _logits, _idx in zip(pred_logits, pred_ids):
                next_action_embed = self.sketch_embed.weight[
                    candidate_ids[_idx]].unsqueeze(0)
                rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                        None, None, None)
                prod = candidates[_idx]

                _history = history[:]
                _history.append(prod)
                non_terminals = self._get_right_side_parts(prod)
                _stack = stack[:]
                for ac in reversed(non_terminals):
                    if ac in action_dict:
                        _stack.append(ac)
                if seq_score is None:
                    _score = _logits
                else:
                    _score = _logits + seq_score
                next_paths.append((_stack, _history, rnn_state, _score))

        incomplete = []
        for stack, history, rnn_state, seq_score in next_paths:
            if len(stack) == 0:
                if world.action_sequence_to_logical_form(history) != "#PH#":
                    completed.append((history, seq_score))
            else:
                incomplete.append((stack, history, rnn_state, seq_score))

        if len(completed) > beam_size:
            completed = sorted(completed, key=lambda x: -x[1])
            completed = completed[:beam_size]
            break
        if len(incomplete) > beam_size:
            incomplete = sorted(incomplete, key=lambda x: -x[3])
            incomplete = incomplete[:beam_size]

    return completed
def forward(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
            token_encodes: torch.Tensor, candidate_rep_dic: Dict,
            sketch_actions: List, program_actions: List):
    """
    Compute the log-likelihood of a program given its sketch action sequence.
    """
    prod_action_dict = world.get_nonterminal_productions()
    sketch_action_dict = world._get_sketch_productions(prod_action_dict)
    initial_rnn_state = self._get_initial_state(token_rep)
    slot2action_dic = gen_slot2action_dic(world, prod_action_dict,
                                          sketch_actions, program_actions)
    program_ref_actions = program_actions[:]

    seq_likeli = []
    rnn_state = initial_rnn_state
    for i, prod in enumerate(sketch_actions):
        left_side, right_side = prod.split(" -> ")

        if right_side != "#PH#":
            candidates = sketch_action_dict[left_side]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze(0)
            filter_score_v_list = [score_v[_id] for _id in candidate_ids]
            filter_score_v = torch.stack(filter_score_v_list, 0)
            log_likeli = F.log_softmax(filter_score_v, dim=0)

            gold_id = candidate_ids.index(self.sketch_prod2id[prod])
            seq_likeli.append(log_likeli[gold_id])

            next_action_embed = self.sketch_embed.weight[
                self.sketch_prod2id[prod]].unsqueeze(0)
        else:
            assert left_side == "List[Row]" or "Column" in left_side
            assert i in slot2action_dic
            candidate_v, candidate_actions = candidate_rep_dic[left_side]
            try:
                gold_id = candidate_actions.index(slot2action_dic[i])
            except ValueError:
                # not included, e.g. and/or are order-invariant
                return None

            # fit into memory for some extreme cases
            if len(candidate_actions) > 256:
                _s = max(0, gold_id - 128)
                _e = min(gold_id + 128, len(candidate_actions))
                candidate_v = candidate_v[_s:_e]
                candidate_actions = candidate_actions[_s:_e]
                gold_id = candidate_actions.index(slot2action_dic[i])
                assert gold_id >= 0

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)

            num_candidate = candidate_v.size()[0]
            if left_side == "List[Row]":
                score_feat_v = score_feat_v.expand(num_candidate, -1)
                att_over_sel = self.row2score(candidate_v, score_feat_v).squeeze(1)
                att_over_sel = F.log_softmax(att_over_sel, dim=0)
                seq_likeli.append(att_over_sel[gold_id])
                next_action_embed = self.row2action(
                    candidate_v[gold_id]).unsqueeze(0)
            else:
                score_feat_v = score_feat_v.expand(num_candidate, -1)
                att_over_col = self.col2score(candidate_v, score_feat_v).squeeze(1)
                att_over_col = F.log_softmax(att_over_col, dim=0)
                seq_likeli.append(att_over_col[gold_id])
                next_action_embed = self.col2action(
                    candidate_v[gold_id]).unsqueeze(0)

        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

    return sum(seq_likeli)
def decode(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
           token_encodes: torch.Tensor, candidate_rep_dic: Dict):
    """
    Greedily decode the most probable program, together with the
    sketch it instantiates.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    stack = [START_SYMBOL]
    history = []
    sketch_history = []
    rnn_state = initial_rnn_state
    for i in range(self._max_decoding_steps):
        if len(stack) == 0:
            break
        cur_non_terminal = stack.pop()
        if cur_non_terminal not in action_dict:
            continue

        if cur_non_terminal == "List[Row]":
            # fill the row slot with a complete row-selection sub-program
            candidate_v, candidate_actions = candidate_rep_dic[cur_non_terminal]
            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)

            num_candidate = candidate_v.size()[0]
            score_feat_v = score_feat_v.expand(num_candidate, -1)
            att_over_sel = self.row2score(candidate_v, score_feat_v).squeeze(1)
            att_over_sel = F.softmax(att_over_sel, dim=0)
            _, pred_id = torch.max(att_over_sel, dim=0)
            pred_id = pred_id.cpu().item()

            next_action_embed = self.row2action(candidate_v[pred_id]).unsqueeze(0)
            history += candidate_actions[pred_id]
            sketch_history.append("List[Row] -> #PH#")
        elif cur_non_terminal.endswith("Column"):
            # fill a column slot with a single column production
            candidate_v, candidate_actions = candidate_rep_dic[cur_non_terminal]
            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)

            num_candidate = candidate_v.size()[0]
            score_feat_v = score_feat_v.expand(num_candidate, -1)
            att_over_col = self.col2score(candidate_v, score_feat_v).squeeze(1)
            att_over_col = F.softmax(att_over_col, dim=0)
            _, pred_id = torch.max(att_over_col, dim=0)
            pred_id = pred_id.cpu().item()

            next_action_embed = self.col2action(candidate_v[pred_id]).unsqueeze(0)
            history.append(candidate_actions[pred_id])
            sketch_history.append(f"{cur_non_terminal} -> #PH#")
        else:
            # expand a grammar (sketch) production
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze(0)
            filter_score_v_list = [score_v[_id] for _id in candidate_ids]
            filter_score_v = torch.stack(filter_score_v_list, 0)
            prob_v = F.softmax(filter_score_v, dim=0)
            _, pred_id = torch.max(prob_v, dim=0)
            pred_id = pred_id.cpu().item()

            next_action_embed = self.sketch_embed.weight[
                candidate_ids[pred_id]].unsqueeze(0)
            prod = candidates[pred_id]
            history.append(prod)
            sketch_history.append(prod)

            # only grammar productions introduce new non-terminals to expand
            non_terminals = get_right_side_parts(prod)
            for _a in reversed(non_terminals):
                if _a in action_dict:
                    stack.append(_a)

        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

    return tuple(sketch_history), tuple(history)