Exemple #1
0
def get_sketch_prod(examples: List, table_dict: Dict) -> tuple:
    """Build the sketch-production vocabulary from the first complete example.

    An example is "complete" when its table context has at least three column
    types and at least one number, one entity and one date. Only then does the
    sketch grammar derived from it contain every production type.

    Args:
        examples: examples carrying "context" (a table id), "tokens" and
            "entities".
        table_dict: maps a table id to a dict holding its "raw_lines".

    Returns:
        ``(id2prod, prod2id)`` for the first complete example. ``id2prod`` is
        the sorted list of all sketch productions and ``prod2id`` is its
        inverse mapping. Returns ``None`` if no example is complete.
    """
    for example in examples:
        table_lines = table_dict[example["context"]]["raw_lines"]
        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       tokenized_question)
        context.take_corenlp_entities(example["entities"])

        world = WikiTableAbstractLanguage(context)
        is_complete = (len(context.column_types) >= 3
                       and len(context._num2id) > 0
                       and len(context._entity2id) > 0
                       and len(context._date2id) > 0)
        if not is_complete:
            continue

        actions = world.get_nonterminal_productions()
        sketch_actions = world._get_sketch_productions(actions)

        # Union of every sketch production. It is sorted so that the id
        # assignment is deterministic: iteration order of a set of strings
        # depends on the per-process hash seed, which would otherwise give
        # a different vocabulary (and different embedding rows) on each run.
        # NOTE(review): checkpoints trained with the old hash-ordered mapping
        # will not line up with this ordering.
        id2prod = sorted(set().union(*sketch_actions.values()))
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}
        return id2prod, prod2id

    return None
Exemple #2
0
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """Build the sketch vocabulary and the slot positions of each sketch.

    The first "complete" example is used to obtain a world whose sketch
    grammar covers every production type. An example is complete when it has
    at least three column types and at least one number, one entity and one
    date.

    Args:
        examples: examples carrying "context" (a table id), "tokens" and
            "entities".
        table_dict: maps a table id to a dict holding its "raw_lines".
        sketch_list: sketch logical forms to index, in output order.
        sketch_action_list: action sequences. Each is converted to a logical
            form, and together they must cover every entry of ``sketch_list``.

    Returns:
        A tuple ``(id2prod, prod2id, sketch_action_seqs, slot_dict)``:

        - ``sketch_action_seqs[k]`` is the action tuple of ``sketch_list[k]``.
        - ``slot_dict[action_tuple][position]`` is the left-hand side of each
          ``"<type> -> #PH#"`` placeholder.

        Returns ``None`` if no example is complete.

    Raises:
        KeyError: if a sketch in ``sketch_list`` has no matching action
            sequence in ``sketch_action_list``.
    """
    # Left-hand sides whose "#PH#" placeholder marks a slot to be filled.
    slot_types = frozenset([
        "Column", "StringColumn", "NumberColumn", "ComparableColumn",
        "DateColumn", "str", "Number", "Date", "List[Row]",
    ])

    for example in examples:
        table_lines = table_dict[example["context"]]["raw_lines"]
        tokenized_question = [Token(token) for token in example["tokens"]]
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       tokenized_question)
        context.take_corenlp_entities(example["entities"])

        world = WikiTableAbstractLanguage(context)
        is_complete = (len(context.column_types) >= 3
                       and len(context._num2id) > 0
                       and len(context._entity2id) > 0
                       and len(context._date2id) > 0)
        if not is_complete:
            continue

        grammar = world._get_sketch_productions(
            world.get_nonterminal_productions())

        # Sorted for a deterministic vocabulary: iteration order of a set of
        # strings depends on the per-process hash seed.
        id2prod = sorted(set().union(*grammar.values()))
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        # Map logical form -> action sequence. Later duplicates overwrite
        # earlier ones.
        sketch_lf2actions = {}
        for actions in sketch_action_list:
            lf = world.action_sequence_to_logical_form(actions)
            sketch_lf2actions[lf] = actions

        # Keep sketch_list order, and record each placeholder's position and
        # type.
        slot_dict = defaultdict(dict)
        sketch_action_seqs = []
        for sketch in sketch_list:
            sketch_actions = tuple(sketch_lf2actions[sketch])
            sketch_action_seqs.append(sketch_actions)

            for action_ind, action in enumerate(sketch_actions):
                assert action in prod2id
                lhs, rhs = action.split(" -> ")
                if lhs in slot_types and rhs == "#PH#":
                    slot_dict[sketch_actions][action_ind] = lhs

        return id2prod, prod2id, sketch_action_seqs, slot_dict

    return None
Exemple #3
0
    def decode(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
               token_encodes: torch.Tensor):
        """Greedily decode the most probable sketch.

        Non-terminals are expanded depth-first (left-most first), starting from
        ``START_SYMBOL``. At each step the decoder LSTM attends over
        ``token_encodes``, and the highest-scoring production among those valid
        for the current non-terminal is chosen.

        Args:
            world: provides the sketch grammar via
                ``_get_sketch_productions(get_nonterminal_productions())``.
            token_rep: passed to ``_get_initial_state`` to seed the decoder.
            token_encodes: per-token encodings that are attended over.

        Returns:
            Tuple of the chosen productions, in decoding order.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        rnn_state = self._get_initial_state(token_rep)

        stack = [START_SYMBOL]
        history = []
        for _ in range(self._max_decoding_steps):
            if not stack:
                break

            cur_non_terminal = stack.pop()
            if cur_non_terminal not in action_dict:
                continue
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))
            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            # Use squeeze(0), not squeeze(). Assuming score_action returns
            # (1, num_actions), squeeze() would collapse a size-1 action axis
            # to a 0-d tensor and break the indexing below. The sibling
            # decoders also use squeeze(0).
            score_v = self.score_action(score_feat_v).squeeze(0)
            filter_score_v = torch.stack(
                [score_v[_id] for _id in candidate_ids], 0)
            prob_v = F.softmax(filter_score_v, dim=0)

            _, pred_id = torch.max(prob_v, dim=0)
            pred_id = pred_id.cpu().item()

            next_action_embed = self.sketch_embed.weight[
                candidate_ids[pred_id]].unsqueeze(0)
            rnn_state = RnnStatelet(next_hidden, next_memory,
                                    next_action_embed, None, None, None)

            prod = candidates[pred_id]
            history.append(prod)
            # Push only expandable symbols, right-most first so the left-most
            # is expanded next. Symbols outside action_dict would just be
            # popped and skipped, wasting a step of the decoding budget.
            stack.extend(nt for nt in reversed(self._get_right_side_parts(prod))
                         if nt in action_dict)

        return tuple(history)
Exemple #4
0
    def forward(self,
                context: TableQuestionContext,
                sketch2program: Dict) -> torch.Tensor:
        """Marginal negative log-likelihood of the consistent programs.

        Sketches with more than ``CONSISTENT_INST_NUM_BOUND`` consistent
        programs are skipped. For every other sketch in ``sketch2program``:

        - If it admits a single instantiation, it contributes its sketch
          log-likelihood alone.
        - Otherwise its log-likelihood is added to the log-sum-exp of the
          slot-filling scores of the instantiations whose logical form is
          listed as consistent.

        NaN totals are skipped with a warning.

        Returns:
            ``-log_sum_exp`` over the per-sketch scores, or ``None`` when no
            sketch contributed.
        """
        world = WikiTableAbstractLanguage(context)

        in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, in_table_feat)

        lf_to_actions = self.sketch_lf2actions(world)
        per_sketch_scores = []
        candidate_rep_dic = self.construct_candidates(world, token_encodes)

        for sketch_lf, consistent_lfs in sketch2program.items():
            if len(consistent_lfs) > self.CONSISTENT_INST_NUM_BOUND:
                continue

            sketch_actions = lf_to_actions[sketch_lf]
            sketch_log_likeli = self.seq2seq(world, token_reps, token_encodes,
                                             sketch_actions)
            paths, path_scores = self.slot_filling(world, token_encodes,
                                                   last_state,
                                                   candidate_rep_dic,
                                                   sketch_actions)

            if len(paths) == 1:
                # A single instantiation carries no slot score of its own.
                per_sketch_scores.append(sketch_log_likeli)
                continue

            gold_scores = []
            for path, path_score in zip(paths, path_scores):
                assert path_score is not None
                if world.action_sequence_to_logical_form(path) in consistent_lfs:
                    gold_scores.append(path_score)

            if not gold_scores:
                continue

            total = sketch_log_likeli + log_sum_exp(gold_scores)
            if torch.isnan(total):
                logger.warning("Nan loss founded!")
            else:
                per_sketch_scores.append(total)

        if not per_sketch_scores:
            return None
        return -1 * log_sum_exp(per_sketch_scores)
Exemple #5
0
    def compute_entropy(self,
                        context: TableQuestionContext,
                        sketch2program: Dict) -> Dict:
        """Accumulate ``-s * exp(s)`` over consistent, executable programs.

        Each sketch in ``sketch2program`` is instantiated with
        ``slot_filling``. Programs that fail ``filter_program_by_execution``
        are dropped. For each remaining program whose logical form is listed
        under that sketch, the term ``-s * exp(s)`` is accumulated, where
        ``s`` is the program's log-score. ``s`` is the sketch score, plus the
        slot score when the sketch has several instantiations.

        Returns:
            A ``defaultdict(int)``. Key ``"entropy"`` holds the summed value as
            a Python float when at least one program contributed; otherwise
            the key is absent and reads as 0.
        """
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, token_in_table_feat)

        entropy = []
        sketch_lf2actions = self.sketch_lf2actions(world)
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_lf in sketch2program:
            consistent_lfs = sketch2program[sketch_lf]
            sketch_actions = sketch_lf2actions[sketch_lf]
            sketch_log_score = self.seq2seq(world, token_reps, token_encodes,
                                            sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                    last_state,
                                                    candidate_rep_dic,
                                                    sketch_actions)

            # A sketch with a single instantiation has no slot score of its
            # own, so it is scored with the sketch score alone.
            single_path = len(_paths) == 1
            if single_path:
                scored_paths = [(_paths[0], None)]
            else:
                scored_paths = zip(_paths, _log_scores)

            for _path, _score in scored_paths:
                if not self.filter_program_by_execution(world, _path):
                    continue
                if single_path:
                    _seq_score = sketch_log_score
                else:
                    assert _score is not None
                    _seq_score = _score + sketch_log_score
                _path_lf = world.action_sequence_to_logical_form(_path)
                if _path_lf in consistent_lfs:
                    # -p * log p, where log p == _seq_score
                    entropy.append(-1 * _seq_score * torch.exp(_seq_score))

        if entropy:
            ret_dic["entropy"] = sum(entropy).cpu().item()
        return ret_dic
Exemple #6
0
 def filter_program_by_execution(self,
                                 world: WikiTableAbstractLanguage,
                                 actions: List):
     """Return whether executing ``actions`` in ``world`` yields a truthy result.

     An exception raised while executing (or while testing the truthiness of
     the result) marks the program as invalid and yields False.
     """
     try:
         # bool() stays inside the try: a result with an ambiguous truth
         # value (e.g. a multi-element array) also counts as a failure.
         return bool(world.execute_action_sequence(actions))
     except Exception:
         # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
         # are no longer swallowed. Any execution error still means "invalid".
         return False
Exemple #7
0
    def forward(self, world: WikiTableAbstractLanguage,
                token_rep: torch.Tensor, token_encodes: torch.Tensor,
                sketch_actions: List):
        """Log-likelihood of a gold sketch under teacher forcing.

        At each step the decoder scores the productions that are valid for the
        gold production's left-hand side, and records the log-probability of
        the gold one. The gold production, not the prediction, is fed back as
        the next decoder input.

        Args:
            world: provides the sketch grammar.
            token_rep: passed to ``_get_initial_state`` to seed the decoder.
            token_encodes: per-token encodings that are attended over.
            sketch_actions: gold sketch productions, each ``"lhs -> rhs"``.

        Returns:
            Sum of the per-step log-probabilities (``0`` if ``sketch_actions``
            is empty).
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        rnn_state = self._get_initial_state(token_rep)

        seq_likeli = []
        for prod in sketch_actions:
            left_side, _ = prod.split(" -> ")
            candidates = action_dict[left_side]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]
            prod_id = self.sketch_prod2id[prod]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))
            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            # Use squeeze(0), not squeeze(): this keeps the action axis even
            # if it has size 1 (assuming score_action returns
            # (1, num_actions)). The slot-aware decoders also use squeeze(0).
            score_v = self.score_action(score_feat_v).squeeze(0)
            filter_score_v = torch.stack(
                [score_v[_id] for _id in candidate_ids], 0)
            log_likeli = F.log_softmax(filter_score_v, dim=0)

            gold_id = candidate_ids.index(prod_id)
            seq_likeli.append(log_likeli[gold_id])

            # Teacher forcing: feed the gold production back in.
            next_action_embed = self.sketch_embed.weight[prod_id].unsqueeze(0)
            rnn_state = RnnStatelet(next_hidden, next_memory,
                                    next_action_embed, None, None, None)

        return sum(seq_likeli)
Exemple #8
0
def check_multi_col(world: WikiTableAbstractLanguage,
                    sketch_actions: List,
                    program_actions: List) -> bool:
    """Tell whether a program's column slot refers outside its row filters.

    The sketch's slots are split into two groups: row-selection slots (left
    side ``List[Row]``) and all other slots. The result is True as soon as the
    first right-side part of some other-slot filler does not occur as a
    substring of the ``"_"``-joined filler of some row slot.

    The result is False otherwise, including when either group of slots is
    empty.
    """
    slot_to_action = gen_slot2action_dic(
        world, world.get_nonterminal_productions(),
        sketch_actions, program_actions)

    row_fillers = []
    other_fillers = []
    for slot_idx, filler in slot_to_action.items():
        is_row_slot = get_left_side_part(sketch_actions[slot_idx]) == "List[Row]"
        (row_fillers if is_row_slot else other_fillers).append(filler)

    if not row_fillers or not other_fillers:
        return False

    # Lazily evaluated, in the same nesting order as a column-outer,
    # row-inner loop.
    col_names = (get_right_side_parts(filler)[0] for filler in other_fillers)
    return any(name not in "_".join(row)
               for name in col_names
               for row in row_fillers)
    
Exemple #9
0
    def evaluate(self,
                 context: TableQuestionContext,
                 sketch2program: Dict) -> Dict:
        """Decode the best executable program and check it against gold data.

        Up to ``EVAL_NUM_SKETCH_BOUND`` sketches are beam-decoded, and each one
        is instantiated with ``slot_filling``. The program kept is the one
        that passes ``filter_program_by_execution`` with the highest
        log-score. The log-score is the sketch score, plus the slot score when
        the sketch has several instantiations. On ties the earliest candidate
        is kept.

        Args:
            context: table/question context of the example.
            sketch2program: maps gold sketch logical forms to collections of
                consistent program logical forms.

        Returns:
            A ``defaultdict(int)`` with the keys ``best_program_lf``,
            ``best_program_actions``, ``best_sketch_lf``,
            ``best_sketch_actions``, ``best_score`` (``exp`` of the best
            log-score), ``is_multi_col``, ``sketch_triggered`` and
            ``lf_triggered``.

        Raises:
            AssertionError: if no decoded program executes successfully.
        """
        world = WikiTableAbstractLanguage(context)
        ret_dic = defaultdict(int)

        # encode question and offline sketches
        token_in_table_feat = context.question_in_table_feat
        token_encodes, token_reps, last_state = self.encode_question(
            context.question_tokens, token_in_table_feat)

        sketch_actions_and_scores = self.seq2seq.beam_decode(
            world, token_reps, token_encodes, self.EVAL_NUM_SKETCH_BOUND)

        # (score, sketch_lf, sketch_actions, program_lf, program_actions)
        best = None
        candidate_rep_dic = self.construct_candidates(world, token_encodes)
        for sketch_actions, sketch_log_score in sketch_actions_and_scores:
            sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
            _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                    last_state,
                                                    candidate_rep_dic,
                                                    sketch_actions)

            if self.__class__.__name__ == "ConcreteProgrammer":
                # NOTE(review): ConcreteProgrammer presumably stores an
                # alignment log-prob during slot_filling. It is folded into
                # the sketch score here and then cleared.
                assert self._ConcreteProgrammer__cur_align_prob_log is not None
                align_prob_log = self._ConcreteProgrammer__cur_align_prob_log.squeeze()
                sketch_log_score = sketch_log_score + align_prob_log
                self._ConcreteProgrammer__cur_align_prob_log = None

            # A single instantiation has no slot score of its own, so it is
            # scored with the sketch score alone.
            single_path = len(_paths) == 1
            if single_path:
                scored_paths = [(_paths[0], None)]
            else:
                scored_paths = zip(_paths, _log_scores)

            for _path, _score in scored_paths:
                if not self.filter_program_by_execution(world, _path):
                    continue
                if single_path:
                    _seq_score = sketch_log_score
                else:
                    assert _score is not None
                    _seq_score = _score + sketch_log_score
                _path_lf = world.action_sequence_to_logical_form(_path)
                # Strict '>' keeps the earliest candidate on ties.
                if best is None or _seq_score > best[0]:
                    best = (_seq_score, sketch_lf, sketch_actions,
                            _path_lf, _path)

        assert best is not None, "no decoded program executed successfully"
        (max_score, best_sketch_lf, best_sketch_actions,
         best_program_lf, best_program_actions) = best

        ret_dic["best_program_lf"] = best_program_lf
        ret_dic["best_program_actions"] = best_program_actions
        ret_dic["best_sketch_lf"] = best_sketch_lf
        ret_dic["best_sketch_actions"] = best_sketch_actions
        ret_dic["best_score"] = torch.exp(max_score)
        ret_dic["is_multi_col"] = check_multi_col(world,
                                                  best_sketch_actions,
                                                  best_program_actions)

        sketch_triggered = best_sketch_lf in sketch2program
        ret_dic["sketch_triggered"] = sketch_triggered
        ret_dic["lf_triggered"] = (
            sketch_triggered
            and best_program_lf in sketch2program[best_sketch_lf])

        return ret_dic
Exemple #10
0
    def slot_filling(self,
                     world: WikiTableAbstractLanguage,
                     token_encodes: torch.Tensor,
                     token_state: torch.Tensor,
                     candidate_rep_dic: Dict,
                     sketch_actions: List):
        """Enumerate the instantiations of a sketch and score each one.

        The method works in two steps:

        1. Score the candidates of every slot with ``collect_candidate_scores``.
        2. Recursively expand the sketch, replacing each slot position (from
           ``world.get_slot_dict``) with one of its candidate actions.

        Rules for filling a slot:

        - A slot with a single candidate is filled without adding a score.
        - A slot with more than ``CANDIDATE_ACTION_NUM_BOUND`` candidates only
          expands its top-scoring ones.
        - If a slot type is missing from ``candidate_rep_dic``, no paths are
          produced.

        Returns:
            ``(possible_paths, path_scores)``, two parallel lists of action
            sequences and their summed slot scores. A path's score is ``None``
            when none of its slots had more than one candidate.
        """
        slot_dict = world.get_slot_dict(sketch_actions)
        sketch_encodes, _ = self.encode_sketch(sketch_actions, token_state)
        candidate_score_dic = self.collect_candidate_scores(
            world, token_encodes, candidate_rep_dic, sketch_encodes, slot_dict)

        possible_paths = []
        path_scores = []

        def extended(prefix, action):
            # A candidate may be a single action or a list of actions, e.g. a
            # multi-step row selection. Lists are spliced in.
            new_prefix = prefix[:]
            if isinstance(action, list):
                new_prefix += action
            else:
                new_prefix.append(action)
            return new_prefix

        def recur_compute(prefix, score, i):
            if i == len(sketch_actions):
                possible_paths.append(prefix)
                path_scores.append(score)
                return

            if i not in slot_dict:
                # A regular sketch production is copied verbatim.
                recur_compute(prefix + [sketch_actions[i]], score, i + 1)
                return

            slot_type = slot_dict[i]
            if slot_type not in candidate_rep_dic:
                return  # this sketch does not apply here
            _, candidate_actions = candidate_rep_dic[slot_type]

            if len(candidate_actions) == 1:
                recur_compute(extended(prefix, candidate_actions[0]), score,
                              i + 1)
                return

            if len(candidate_actions) > self.CANDIDATE_ACTION_NUM_BOUND:
                _, top_k = torch.topk(candidate_score_dic[i],
                                      self.CANDIDATE_ACTION_NUM_BOUND, dim=0)
                ac_idxs = top_k.cpu().numpy()
            else:
                ac_idxs = range(len(candidate_actions))

            slot_scores = candidate_score_dic[i]
            for ac_ind in ac_idxs:
                # Use `is not None`, not truthiness. An accumulated log-score
                # of exactly 0.0 is falsy and would silently drop the earlier
                # slot scores from the autograd graph.
                if score is not None:
                    new_score = score + slot_scores[ac_ind]
                else:
                    new_score = slot_scores[ac_ind]
                recur_compute(extended(prefix, candidate_actions[ac_ind]),
                              new_score, i + 1)

        recur_compute([], None, 0)
        return possible_paths, path_scores
Exemple #11
0
 def sketch_lf2actions(self, world: WikiTableAbstractLanguage):
     """Index the cached sketch action sequences by their logical form.

     Iterates ``self.sketch_actions_cache`` in order and keys each sequence by
     ``world.action_sequence_to_logical_form``. When two sequences share a
     logical form, the later one wins.
     """
     return {
         world.action_sequence_to_logical_form(seq): seq
         for seq in self.sketch_actions_cache
     }
Exemple #12
0
    def beam_decode(self, world: WikiTableAbstractLanguage,
                    token_rep: torch.Tensor, token_encodes: torch.Tensor,
                    beam_size: int):
        """Beam-search the most probable sketches.

        At every step, the top non-terminal of each live hypothesis is
        expanded with its ``beam_size`` best productions. A hypothesis whose
        stack empties is completed, unless its whole logical form is the bare
        placeholder ``"#PH#"``.

        Search stops when any of the following happens:

        - more than ``beam_size`` sketches have been completed;
        - no live hypothesis remains;
        - ``_max_decoding_steps`` is reached.

        Args:
            world: provides the sketch grammar.
            token_rep: passed to ``_get_initial_state`` to seed the decoder.
            token_encodes: per-token encodings that are attended over.
            beam_size: number of expansions per hypothesis and number of
                hypotheses kept.

        Returns:
            A list of at most ``beam_size`` ``(history, log_score)`` pairs,
            sorted by decreasing score.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        initial_rnn_state = self._get_initial_state(token_rep)

        # Each hypothesis is (stack, history, rnn_state, log_score). The score
        # is None until the first production has been chosen.
        incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
        completed = []

        for _ in range(self._max_decoding_steps):
            if not incomplete:
                # Every hypothesis has finished or died, so further steps
                # would be no-ops.
                break

            next_paths = []
            for stack, history, rnn_state, seq_score in incomplete:
                cur_non_terminal = stack.pop()
                if cur_non_terminal not in action_dict:
                    continue
                candidates = action_dict[cur_non_terminal]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

                cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
                next_hidden, next_memory = self.decoder_lstm(
                    rnn_state.previous_action_embedding,
                    (cur_hidden, cur_memory))
                hidden_tran = next_hidden.transpose(0, 1)
                att_feat_v = torch.mm(token_encodes,
                                      hidden_tran)  # sent_len * 1
                att_v = F.softmax(att_feat_v, dim=0)
                att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

                score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
                # Use squeeze(0), not squeeze(), to keep the action axis even
                # if it has size 1.
                score_v = self.score_action(score_feat_v).squeeze(0)
                filter_score_v = torch.stack(
                    [score_v[_id] for _id in candidate_ids], 0)
                log_prob_v = F.log_softmax(filter_score_v, dim=0)

                pred_logits, pred_ids = torch.topk(
                    log_prob_v, min(beam_size, log_prob_v.size()[0]), dim=0)

                for _logits, _idx in zip(pred_logits, pred_ids):
                    idx = int(_idx)
                    prod = candidates[idx]
                    next_action_embed = self.sketch_embed.weight[
                        candidate_ids[idx]].unsqueeze(0)
                    next_state = RnnStatelet(next_hidden, next_memory,
                                             next_action_embed, None, None,
                                             None)

                    # Push only expandable symbols, right-most first, so the
                    # left-most one is expanded next.
                    _stack = stack[:]
                    _stack.extend(
                        ac for ac in reversed(self._get_right_side_parts(prod))
                        if ac in action_dict)
                    if seq_score is None:
                        _score = _logits
                    else:
                        _score = _logits + seq_score

                    next_paths.append((_stack, history + [prod], next_state,
                                       _score))

            incomplete = []
            for stack, history, rnn_state, seq_score in next_paths:
                if stack:
                    incomplete.append((stack, history, rnn_state, seq_score))
                elif world.action_sequence_to_logical_form(history) != "#PH#":
                    # Skip the degenerate sketch that is a bare placeholder.
                    completed.append((history, seq_score))

            if len(completed) > beam_size:
                completed = sorted(completed, key=lambda x: -x[1])[:beam_size]
                break

            if len(incomplete) > beam_size:
                incomplete = sorted(incomplete, key=lambda x: -x[3])[:beam_size]

        # Sort on every exit path: the loop may end without taking the break
        # above, in which case `completed` is still in insertion order.
        return sorted(completed, key=lambda x: -x[1])
Exemple #13
0
    def forward(self, world: WikiTableAbstractLanguage,
                token_rep: torch.Tensor, token_encodes: torch.Tensor,
                candidate_rep_dic: torch.Tensor, sketch_actions: List,
                program_actions: List):
        """Log-likelihood of a gold program under teacher forcing.

        Walks the gold sketch step by step:

        - Ordinary productions are scored against the sketch grammar.
        - Placeholder productions (``"<type> -> #PH#"``) are scored against
          the slot candidates in ``candidate_rep_dic``. The gold filler comes
          from ``program_actions`` via ``gen_slot2action_dic``.

        The gold choice is always fed back as the next decoder input.

        Args:
            world: provides the grammar and the sketch grammar.
            token_rep: passed to ``_get_initial_state`` to seed the decoder.
            token_encodes: per-token encodings that are attended over.
            candidate_rep_dic: despite the annotation, used as a mapping from
                a slot type to ``(candidate_vectors, candidate_actions)``.
            sketch_actions: gold sketch productions, each ``"lhs -> rhs"``.
            program_actions: the gold program that the sketch instantiates.

        Returns:
            The sum of the per-step log-probabilities. Returns ``None`` when a
            gold slot filler is not among the candidates (e.g. and/or
            arguments given in a different order).
        """
        prod_action_dict = world.get_nonterminal_productions()
        sketch_action_dict = world._get_sketch_productions(prod_action_dict)
        initial_rnn_state = self._get_initial_state(token_rep)
        slot2action_dic = gen_slot2action_dic(world, prod_action_dict,
                                              sketch_actions, program_actions)

        def attend(state):
            # Advance the decoder LSTM one step and attend over the question
            # tokens. Returns (hidden, memory, [hidden; attended context]).
            next_hidden, next_memory = self.decoder_lstm(
                state.previous_action_embedding,
                (state.hidden_state, state.memory_cell))
            att_feat_v = torch.mm(token_encodes,
                                  next_hidden.transpose(0, 1))  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)
            return next_hidden, next_memory, torch.cat([next_hidden, att_ret_v], 1)

        seq_likeli = []
        rnn_state = initial_rnn_state
        for i, prod in enumerate(sketch_actions):
            left_side, right_side = prod.split(" -> ")

            if right_side != "#PH#":
                candidates = sketch_action_dict[left_side]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]
                prod_id = self.sketch_prod2id[prod]

                next_hidden, next_memory, score_feat_v = attend(rnn_state)
                score_v = self.score_action(score_feat_v).squeeze(0)
                filter_score_v = torch.stack(
                    [score_v[_id] for _id in candidate_ids], 0)
                log_likeli = F.log_softmax(filter_score_v, dim=0)

                gold_id = candidate_ids.index(prod_id)
                seq_likeli.append(log_likeli[gold_id])
                next_action_embed = self.sketch_embed.weight[prod_id].unsqueeze(0)

            else:
                assert left_side == "List[Row]" or "Column" in left_side
                assert i in slot2action_dic

                candidate_v, candidate_actions = candidate_rep_dic[left_side]
                try:
                    gold_id = candidate_actions.index(slot2action_dic[i])
                except ValueError:
                    # not included, e.g and/or are order-invariant
                    return None

                # Bound memory for extreme cases by keeping a window of at
                # most 256 candidates around the gold one. Because gold_id is
                # the first occurrence, its position in the window is simply
                # gold_id - start.
                if len(candidate_actions) > 256:
                    _s = max(0, gold_id - 128)
                    _e = min(gold_id + 128, len(candidate_actions))
                    candidate_v = candidate_v[_s:_e]
                    candidate_actions = candidate_actions[_s:_e]
                    gold_id -= _s

                next_hidden, next_memory, score_feat_v = attend(rnn_state)
                score_feat_v = score_feat_v.expand(candidate_v.size()[0], -1)

                if left_side == "List[Row]":
                    scorer, to_action = self.row2score, self.row2action
                else:
                    scorer, to_action = self.col2score, self.col2action

                log_probs = F.log_softmax(
                    scorer(candidate_v, score_feat_v).squeeze(1), dim=0)
                seq_likeli.append(log_probs[gold_id])
                next_action_embed = to_action(candidate_v[gold_id]).unsqueeze(0)

            rnn_state = RnnStatelet(next_hidden, next_memory,
                                    next_action_embed, None, None, None)

        return sum(seq_likeli)
Exemple #14
0
    def decode(self, world: WikiTableAbstractLanguage, token_rep: torch.Tensor,
               token_encodes: torch.Tensor, candidate_rep_dic: torch.Tensor):
        """Greedily decode a full program together with its sketch.

        Non-terminals are expanded depth-first from ``START_SYMBOL``:

        - ``List[Row]`` and ``...Column`` non-terminals are filled by picking
          one of their candidates in ``candidate_rep_dic``. Despite the
          annotation, it is used as a mapping from a type to
          ``(candidate_vectors, candidate_actions)``.
        - Every other non-terminal is expanded with its best-scoring sketch
          production.

        Returns:
            ``(sketch_history, history)``. ``sketch_history`` is the sketch,
            with ``"<type> -> #PH#"`` at each slot. ``history`` is the program,
            with the chosen fillers in place.
        """
        action_dict = world._get_sketch_productions(
            world.get_nonterminal_productions())
        rnn_state = self._get_initial_state(token_rep)

        def step(state):
            # One decoder LSTM step plus attention over the question tokens.
            hidden, memory = self.decoder_lstm(
                state.previous_action_embedding,
                (state.hidden_state, state.memory_cell))
            attn = F.softmax(torch.mm(token_encodes, hidden.transpose(0, 1)),
                             dim=0)  # sent_len * 1
            attended = torch.mm(attn.transpose(0, 1), token_encodes)
            return hidden, memory, torch.cat([hidden, attended], 1)

        def best_index(scores):
            _, best = torch.max(F.softmax(scores, dim=0), dim=0)
            return best.cpu().item()

        stack = [START_SYMBOL]
        history = []
        sketch_history = []
        for _ in range(self._max_decoding_steps):
            if not stack:
                break

            symbol = stack.pop()
            if symbol not in action_dict:
                continue

            is_row_slot = symbol == "List[Row]"
            if is_row_slot or symbol.endswith("Column"):
                candidate_v, candidate_actions = candidate_rep_dic[symbol]
                next_hidden, next_memory, feat = step(rnn_state)
                feat = feat.expand(candidate_v.size()[0], -1)

                if is_row_slot:
                    scorer, to_action = self.row2score, self.row2action
                else:
                    scorer, to_action = self.col2score, self.col2action
                pred_id = best_index(scorer(candidate_v, feat).squeeze(1))
                next_action_embed = to_action(candidate_v[pred_id]).unsqueeze(0)

                if is_row_slot:
                    # A row filler is a sequence of actions; splice it in.
                    history.extend(candidate_actions[pred_id])
                    sketch_history.append("List[Row] -> #PH#")
                else:
                    history.append(candidate_actions[pred_id])
                    sketch_history.append(f"{symbol} -> #PH#")
            else:
                candidates = action_dict[symbol]
                candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

                next_hidden, next_memory, feat = step(rnn_state)
                score_v = self.score_action(feat).squeeze(0)
                pred_id = best_index(
                    torch.stack([score_v[_id] for _id in candidate_ids], 0))
                next_action_embed = self.sketch_embed.weight[
                    candidate_ids[pred_id]].unsqueeze(0)

                prod = candidates[pred_id]
                history.append(prod)
                sketch_history.append(prod)
                stack.extend(part for part in reversed(get_right_side_parts(prod))
                             if part in action_dict)

            rnn_state = RnnStatelet(next_hidden, next_memory,
                                    next_action_embed, None, None, None)

        return tuple(sketch_history), tuple(history)