Example #1
def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)
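A note on this example: the function relies on batch_size, beam_size, kg, var_cuda, and unique_max from its enclosing beam-search scope, so it is not standalone. The core technique is entity deduplication across beams: when several beams reach the same entity, only the highest-scoring occurrence survives. A self-contained sketch of that idea in plain torch (unique_max in the snippet presumably does something equivalent):

import torch

# One batch example: 6 candidate actions across beams; several actions lead to
# the same target entity, and we keep the max score per unique entity.
e_space_b = torch.tensor([5, 7, 5, 9, 7, 5])            # target entity per action
log_action_dist_b = torch.tensor([-1.0, -0.5, -0.2, -2.0, -3.0, -0.9])

unique_e = torch.unique(e_space_b)                      # tensor([5, 7, 9])
mask = unique_e.unsqueeze(1) == e_space_b.unsqueeze(0)  # [num_unique, num_actions]
masked = torch.where(mask, log_action_dist_b.unsqueeze(0),
                     torch.full_like(log_action_dist_b, -1e20).unsqueeze(0))
unique_scores, unique_idx = masked.max(dim=1)
# unique_scores == [-0.2, -0.5, -2.0]; in the snippet, unique_idx // action_space_size
# then recovers which beam each surviving action came from.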
Example #2
def pad_and_cat_action_space(action_spaces, inv_offset):
    db_r_space, db_e_space, db_action_mask = [], [], []
    for (r_space, e_space), action_mask in action_spaces:
        db_r_space.append(r_space)
        db_e_space.append(e_space)
        db_action_mask.append(action_mask)
    r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
    e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
    action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
    action_space = ((r_space, e_space), action_mask)
    return action_space
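For orientation: inv_offset is the argsort of the concatenated bucket references, so indexing the padded-and-concatenated tensors with it restores the original batch order. A minimal pure-Python illustration with made-up values:

# Rows were processed in bucket order; `references` records each row's
# original position in the batch.
references = [2, 0, 3, 1]
inv_offset = [i for i, _ in sorted(enumerate(references), key=lambda x: x[1])]
# inv_offset == [1, 3, 0, 2]
bucket_output = ['row2', 'row0', 'row3', 'row1']
assert [bucket_output[i] for i in inv_offset] == ['row0', 'row1', 'row2', 'row3']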
Example #3
def pad_and_cat_action_space(action_spaces: List[ActionSpace], inv_offset,
                             kg: KnowledgeGraph):
    db_r_space, db_e_space, db_action_mask = [], [], []
    forks = []
    for acsp in action_spaces:
        forks += acsp.forks
        db_r_space.append(acsp.r_space)
        db_e_space.append(acsp.e_space)
        db_action_mask.append(acsp.action_mask)
    r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
    e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
    action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
    action_space = ActionSpace(forks, r_space, e_space, action_mask)
    return action_space
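This variant assumes an ActionSpace container with forks, r_space, e_space, and action_mask fields. Its definition is not part of the snippet; a NamedTuple along the following lines would satisfy the usage (an assumption, not the project's actual class):

from typing import List, NamedTuple
import torch

class ActionSpace(NamedTuple):
    forks: List                 # bookkeeping carried through bucketing (assumed)
    r_space: torch.Tensor       # [batch, space] candidate relation ids
    e_space: torch.Tensor       # [batch, space] candidate target entity ids
    action_mask: torch.Tensor   # [batch, space] 1 for valid actions, 0 for padding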
Example #4
    def virtual_step(self, e_set, r):
        """
        Given a set of entities (e_set), find the set of entities (e_set_out) that have at least one
        incoming edge labeled r whose source entity is in e_set.
        """
        batch_size = len(e_set)
        e_set_1D = e_set.view(-1)
        r_space = self.action_space[0][0][e_set_1D]
        e_space = self.action_space[0][1][e_set_1D]
        e_space = (r_space.view(batch_size, -1) == r.unsqueeze(1)).long() * e_space.view(batch_size, -1)
        e_set_out = []
        for i in range(len(e_space)):
            e_set_out_b = var_cuda(unique(e_space[i].data))
            e_set_out.append(e_set_out_b.unsqueeze(0))
        e_set_out = ops.pad_and_cat(e_set_out, padding_value=self.dummy_e)
        return e_set_out
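The key line in virtual_step is the relation filter: entries of e_space whose edge label does not match r are zeroed out, with entity id 0 acting as a dummy. A tiny standalone illustration with made-up ids:

import torch

r_space = torch.tensor([[3, 5, 3], [5, 5, 7]])        # edge label per candidate action
e_space = torch.tensor([[10, 11, 12], [13, 14, 15]])  # target entity per candidate action
r = torch.tensor([3, 5])                              # query relation per example

reached = (r_space == r.unsqueeze(1)).long() * e_space
# reached == [[10, 0, 12], [13, 14, 0]]: only edges labeled r survive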
Example #5
    def do_it_with_bucketing(
        self,
        X2,
        current_entity,
        kg,
        merge_aspace_batching_outcome,
        obs: Observation,
        policy_nn_fun,
    ):
        entropy_list = []
        references = []
        bucket_action_spaces, in_bucket_indices = self.get_action_space_in_buckets(
            current_entity, obs, kg)
        action_spaces = []
        action_dists = []

        for as_b, in_bucket in zip(bucket_action_spaces,
                                   in_bucket_indices):
            X2_b = X2[in_bucket, :]
            action_dist_b, entropy_b = policy_nn_fun(X2_b, as_b)
            references.extend(in_bucket)
            action_spaces.append(as_b)
            action_dists.append(action_dist_b)
            entropy_list.append(entropy_b)
        inv_offset = [
            i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
        ]
        entropy = torch.cat(entropy_list, dim=0)[inv_offset]
        action = BucketActions(action_spaces, action_dists, inv_offset,
                               entropy)

        if merge_aspace_batching_outcome:
            action_space = pad_and_cat_action_space(bucket_action_spaces,
                                                    inv_offset, kg)
            action_dist = ops.pad_and_cat(action.action_dists,
                                          padding_value=0)[inv_offset]
            action = BucketActions([action_space], [action_dist], None,
                                   entropy)
        return action
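BucketActions is likewise not defined in the snippet; judging from its two call sites, it bundles the per-bucket action spaces and distributions with the reordering index and entropy. A plausible NamedTuple (an assumption):

from typing import List, NamedTuple, Optional
import torch

class BucketActions(NamedTuple):
    action_spaces: List               # one action space per bucket
    action_dists: List[torch.Tensor]  # one distribution per bucket
    inv_offset: Optional[List[int]]   # None once buckets are merged back
    entropy: torch.Tensor             # policy entropy per batch row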
Example #6
    def transit(self,
                e,
                obs,
                kg,
                kg_pred=None,
                fn_kg=None,
                use_action_space_bucketing=True,
                merge_aspace_batching_outcome=False,
                use_kg_pred=False,
                inference=False):
        """
        Compute the next action distribution based on
            (a) the current node (entity) in KG and the query relation
            (b) action history representation
        :param e: agent location (node) at step t.
        :param obs: agent observation at step t.
            e_s: source node
            q: query relation
            e_t: target node
            last_step: If set, the agent is carrying out the last step.
            last_r: label of edge traversed in the previous step
            seen_nodes: nodes seen on the paths
        :param kg: Knowledge graph environment.
        :param use_action_space_bucketing: If set, group the action spaces of different nodes
            into buckets by size.
        :param merge_aspace_batching_outcome: If set, merge the transition probability distributions
            generated for the different action space buckets into a single batch.
        :return:
            With aspace batching and without merging the outcomes:
                db_outcomes: (Dynamic Batch) (action_space, action_dist)
                    action_space: (Batch) padded possible action indices
                    action_dist: (Batch) distribution over actions.
                inv_offset: Indices to set the dynamic batching output back to the original order.
                entropy: (Batch) entropy of action distribution.
            Else:
                action_dist: (Batch) distribution over actions.
                entropy: (Batch) entropy of action distribution.
        """
        e_s, q, e_t, last_step, last_r, seen_nodes = obs

        # Representation of the current state (current node and other observations)
        Q = kg.get_relation_embeddings(q)
        H = self.path[-1][0][-1, :, :]
        if self.relation_only:
            X = torch.cat([H, Q], dim=-1)
        elif self.relation_only_in_path:
            E_s = kg.get_entity_embeddings(e_s)
            E = kg.get_entity_embeddings(e)
            X = torch.cat([E, H, E_s, Q], dim=-1)
        elif use_kg_pred:
            E = kg.get_entity_embeddings(e)
            X = torch.cat([E, H, Q, kg_pred], dim=-1)
            # X = torch.cat([E, H, Q, kg.get_entity_embeddings(e_t)], dim=-1)
        else:
            E = kg.get_entity_embeddings(e)
            X = torch.cat([E, H, Q], dim=-1)

        # MLP
        X = self.W1(X)
        X = F.relu(X)
        X = self.W1Dropout(X)
        X = self.W2(X)
        X2 = self.W2Dropout(X)
        relation_att = torch.matmul(self.W_att(X2),
                                    kg.get_all_relation_embeddings().t())
        # B x |R|
        # Trick -> mask SIM relation
        relation_att = torch.nn.functional.softmax(relation_att, dim=-1)

        def policy_nn_fun(X2, action_space):
            (r_space, e_space), action_mask = action_space
            A = self.get_action_embedding((r_space, e_space), kg)
            action_dist = F.softmax(
                torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
                (1 - action_mask) * ops.HUGE_INT,
                dim=-1)
            # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
            return action_dist, ops.entropy(action_dist)

        def pad_and_cat_action_space(action_spaces, inv_offset):
            db_r_space, db_e_space, db_action_mask = [], [], []
            for (r_space, e_space), action_mask in action_spaces:
                db_r_space.append(r_space)
                db_e_space.append(e_space)
                db_action_mask.append(action_mask)
            r_space = ops.pad_and_cat(db_r_space,
                                      padding_value=kg.dummy_r)[inv_offset]
            e_space = ops.pad_and_cat(db_e_space,
                                      padding_value=kg.dummy_e)[inv_offset]
            action_mask = ops.pad_and_cat(db_action_mask,
                                          padding_value=0)[inv_offset]
            action_space = ((r_space, e_space), action_mask)
            return action_space

        if use_action_space_bucketing:
            """
            """
            db_outcomes = []
            entropy_list = []
            references = []
            db_action_spaces, db_references = self.get_action_space_in_buckets(
                e, obs, kg, relation_att=relation_att, inference=inference)
            for action_space_b, reference_b in zip(db_action_spaces,
                                                   db_references):
                X2_b = X2[reference_b, :]
                action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b)
                references.extend(reference_b)
                db_outcomes.append((action_space_b, action_dist_b))
                entropy_list.append(entropy_b)
            inv_offset = [
                i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
            ]
            entropy = torch.cat(entropy_list, dim=0)[inv_offset]
            if merge_aspace_batching_outcome:
                db_action_dist = []
                for _, action_dist in db_outcomes:
                    db_action_dist.append(action_dist)
                action_space = pad_and_cat_action_space(
                    db_action_spaces, inv_offset)
                action_dist = ops.pad_and_cat(db_action_dist,
                                              padding_value=0)[inv_offset]
                db_outcomes = [(action_space, action_dist)]
                inv_offset = None
        else:
            action_space = self.get_action_space(e, obs, kg)
            action_dist, entropy = policy_nn_fun(X2, action_space)
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None

        return db_outcomes, inv_offset, entropy
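The masking inside policy_nn_fun is the usual trick of pushing invalid logits toward negative infinity before normalizing, so padded actions get numerically zero probability. A self-contained illustration, with 1e31 as a stand-in for ops.HUGE_INT:

import torch
import torch.nn.functional as F

HUGE_INT = 1e31
logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
action_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # last two actions are padding

action_dist = F.softmax(logits - (1 - action_mask) * HUGE_INT, dim=-1)
# masked entries get ~0 probability; the rest renormalize among themselves:
# tensor([[0.7311, 0.2689, 0.0000, 0.0000]])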
Example #7
    def transit(self,
                e,
                obs,
                kg,
                mode,
                use_action_space_bucketing=True,
                merge_aspace_batching_outcome=False):
        """
        Compute the next action distribution based on
            (a) the current node (entity) in KG and the query relation
            (b) action history representation
        :param e: agent location (node) at step t.
        :param obs: agent observation at step t.
            e_s: source node
            q: query relation
            e_t: target node
            first_step: If set, the agent is carrying out the first step.
            last_step: If set, the agent is carrying out the last step.
            last_r: label of edge traversed in the previous step
            seen_nodes: nodes seen on the paths
        :param kg: Knowledge graph environment.
        :param use_action_space_bucketing: If set, group the action spaces of different nodes
            into buckets by size.
        :param merge_aspace_batching_outcome: If set, merge the transition probability distributions
            generated for the different action space buckets into a single batch.
        :return:
            With aspace batching and without merging the outcomes:
                db_outcomes: (Dynamic Batch) (action_space, action_dist)
                    action_space: (Batch) padded possible action indices
                    action_dist: (Batch) distribution over actions.
                inv_offset: Indices to set the dynamic batching output back to the original order.
                entropy: (Batch) entropy of action distribution.
            Else:
                action_dist: (Batch) distribution over actions.
                entropy: (Batch) entropy of action distribution.
        """

        e_s, emb_e_s, q, e_t, first_step, last_step, last_r, seen_nodes = obs

        # Representation of the current state (current node and other observations)
        Q = self.graph_transformer.dropout(self.graph_transformer.emb_r(q))
        H = self.path[-1][0][-1, :, :]
        if self.relation_only:
            X = torch.cat([H, Q], dim=-1)
        elif self.relation_only_in_path:
            E_s = emb_e_s
            if first_step:
                E = emb_e_s
            else:
                if mode == 'train':
                    E, _ = self.graph_transformer(e, q, self.dg.training_graph,
                                                  self.dg.seen_id2entity,
                                                  self.bandwidth, mode)
                else:
                    E, _ = self.graph_transformer(e, q, self.dg.eval_graph,
                                                  self.dg.seen_id2entity,
                                                  self.bandwidth, mode)

            if E.size()[0] != E_s.size(0):
                expansion_size = int(E.size()[0] / E_s.size(0))
                E_s = E_s.unsqueeze(1).expand(-1, expansion_size, -1)
                E_s = torch.flatten(E_s, start_dim=0).view(-1, self.entity_dim)
            X = torch.cat([E, H, E_s, Q], dim=-1)
        else:
            if first_step:
                E = emb_e_s
            else:
                if mode == 'train':
                    E, _ = self.graph_transformer(e, q, self.dg.training_graph,
                                                  self.dg.seen_id2entity,
                                                  self.bandwidth, mode)
                else:
                    if self.args.inference:
                        E, _ = self.graph_transformer(e, q, self.dg.aux_graph,
                                                      self.dg.seen_id2entity,
                                                      self.bandwidth, 'test')
                    else:
                        E, _ = self.graph_transformer(e, q, self.dg.eval_graph,
                                                      self.dg.seen_id2entity,
                                                      self.bandwidth, 'eval')

            X = torch.cat([E, H, Q], dim=-1)

        # MLP
        X = self.W1(X)
        X = F.relu(X)
        X = self.W1Dropout(X)
        X = self.W2(X)
        X2 = self.W2Dropout(X)

        def policy_nn_fun(X2, action_space):
            (r_space, e_space), action_mask = action_space
            A = self.get_action_embedding((r_space, e_space), kg)
            action_dist = F.softmax(
                torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
                (1 - action_mask) * ops.HUGE_INT,
                dim=-1)
            # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
            return action_dist, ops.entropy(action_dist)

        def pad_and_cat_action_space(action_spaces, inv_offset):
            db_r_space, db_e_space, db_action_mask = [], [], []
            for (r_space, e_space), action_mask in action_spaces:
                db_r_space.append(r_space)
                db_e_space.append(e_space)
                db_action_mask.append(action_mask)
            r_space = ops.pad_and_cat(db_r_space,
                                      padding_value=kg.dummy_r)[inv_offset]
            e_space = ops.pad_and_cat(db_e_space,
                                      padding_value=kg.dummy_e)[inv_offset]
            action_mask = ops.pad_and_cat(db_action_mask,
                                          padding_value=0)[inv_offset]
            action_space = ((r_space, e_space), action_mask)
            return action_space

        if use_action_space_bucketing:
            db_outcomes = []
            entropy_list = []
            references = []
            db_action_spaces, db_references = self.get_action_space_in_buckets(
                e, obs, kg)
            for action_space_b, reference_b in zip(db_action_spaces,
                                                   db_references):
                X2_b = X2[reference_b, :]
                action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b)
                references.extend(reference_b)
                db_outcomes.append((action_space_b, action_dist_b))
                entropy_list.append(entropy_b)
            inv_offset = [
                i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
            ]
            entropy = torch.cat(entropy_list, dim=0)[inv_offset]
            if merge_aspace_batching_outcome:
                db_action_dist = []
                for _, action_dist in db_outcomes:
                    db_action_dist.append(action_dist)
                action_space = pad_and_cat_action_space(
                    db_action_spaces, inv_offset)
                action_dist = ops.pad_and_cat(db_action_dist,
                                              padding_value=0)[inv_offset]
                db_outcomes = [(action_space, action_dist)]
                inv_offset = None
        else:
            action_space = self.get_action_space(e, obs, kg)
            action_dist, entropy = policy_nn_fun(X2, action_space)
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None
        return db_outcomes, inv_offset, entropy
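One detail specific to this variant: when the graph transformer returns more rows than there are source entities (e.g. batch rows replicated per rollout), E_s is tiled by expansion_size to match. The same effect, illustrated standalone with reshape in place of the flatten-then-view above:

import torch

entity_dim = 4
E_s = torch.randn(2, entity_dim)  # one embedding per source entity
E = torch.randn(6, entity_dim)    # model output: 3 rows per source entity

expansion_size = E.size(0) // E_s.size(0)  # 3
E_s_tiled = E_s.unsqueeze(1).expand(-1, expansion_size, -1).reshape(-1, entity_dim)
assert E_s_tiled.size() == E.size()
# rows 0-2 are copies of E_s[0], rows 3-5 are copies of E_s[1]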
Example #8
    def transit_fusion(self,
                       e,
                       obs,
                       kg,
                       fn,
                       fn_kg,
                       use_action_space_bucketing=True,
                       merge_aspace_batching_outcome=False):
        """
        Compute the next action distribution based on
            (a) the current node (entity) in KG and the query relation
            (b) action history representation
        :param e: agent location (node) at step t.
        :param obs: agent observation at step t.
            e_s: source node
            q: query relation
            e_t: target node
            last_step: If set, the agent is carrying out the last step.
            last_r: label of edge traversed in the previous step
            seen_nodes: nodes seen on the paths
        :param kg: Knowledge graph environment.
        :param use_action_space_bucketing: If set, group the action spaces of different nodes
            into buckets by size.
        :param merge_aspace_batching_outcome: If set, merge the transition probability distributions
            generated for the different action space buckets into a single batch.
        :return:
            With aspace batching and without merging the outcomes:
                db_outcomes: (Dynamic Batch) (action_space, action_dist)
                    action_space: (Batch) padded possible action indices
                    action_dist: (Batch) distribution over actions.
                inv_offset: Indices to set the dynamic batching output back to the original order.
                entropy: (Batch) entropy of action distribution.
            Else:
                action_dist: (Batch) distribution over actions.
                entropy: (Batch) entropy of action distribution.
        """
        e_s, q, e_t, last_step, last_r, seen_nodes = obs

        # Representation of the current state (current node and other observations)
        Q = kg.get_relation_embeddings(q)
        H = self.path[-1][0][-1, :, :]
        if self.using_attention:
            # Attention over the path history: score each hidden state on the
            # path against the query embedding, softmax the scores, and use
            # the weighted sum of hidden states as H.
            ht_list = [self.path[t][0][-1, :, :] for t in range(len(self.path))]
            score_list = [(Q * ht).squeeze(1) for ht in ht_list]
            score_all = score_list[0].exp()
            for score in score_list[1:]:
                score_all = score_all + score.exp()
            weighted_list = []
            for ht, score in zip(ht_list, score_list):
                alpha_t_r = (score.exp() / score_all).mul(ht)
                weighted_list.append(alpha_t_r)
            H = weighted_list[0]
            for weighted in weighted_list[1:]:
                H = H + weighted
        if self.relation_only:
            X = torch.cat([H, Q], dim=-1)
        elif self.relation_only_in_path:
            E_s = kg.get_entity_embeddings(e_s)
            E = kg.get_entity_embeddings(e)
            X = torch.cat([E, H, E_s, Q], dim=-1)
        else:
            E = kg.get_entity_embeddings(e)
            X = torch.cat([E, H, Q], dim=-1)

        # MLP
        X = self.W1(X)
        X = F.relu(X)
        X = self.W1Dropout(X)
        X = self.W2(X)
        X2 = self.W2Dropout(X)

        def policy_nn_fun_fusion(e, X2, action_space, fn, fn_kg):
            (r_space, e_space), action_mask = action_space
            dim1, dim2 = e_space.size()
            e1 = e
            for i in range(dim2 - 1):
                e1 = torch.cat((e1, e), 0)
            e1 = e1.view(dim1, dim2)
            e2 = e_space
            r = r_space
            em_score_can = []
            for i in range(dim1):
                e1_can = torch.squeeze(e1[i])
                r_can = torch.squeeze(r[i])
                e2_can = torch.squeeze(e2[i])
                em_score_can.append(
                    torch.squeeze(fn.forward_fact(e1_can, r_can, e2_can,
                                                  fn_kg)))

            em_score = torch.cat(em_score_can, 0)
            A = self.get_action_embedding((r_space, e_space), kg)
            action_dist = F.softmax(
                torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
                (1 - action_mask) * ops.HUGE_INT,
                dim=-1)
            # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
            #em_score=torch.squeeze(em_score)
            em_score = em_score.view(dim1, dim2)
            # long + short term decision: fuse policy and embedding scores
            if self.long_short_term == 0:
                action_dist = action_dist * em_score
            # only the long-term (policy network) score
            elif self.long_short_term == 1:
                pass
            # only the short-term (embedding) score
            elif self.long_short_term == 2:
                action_dist = em_score
            return action_dist, ops.entropy(action_dist)

        def pad_and_cat_action_space(action_spaces, inv_offset):
            db_r_space, db_e_space, db_action_mask = [], [], []
            for (r_space, e_space), action_mask in action_spaces:
                db_r_space.append(r_space)
                db_e_space.append(e_space)
                db_action_mask.append(action_mask)
            r_space = ops.pad_and_cat(db_r_space,
                                      padding_value=kg.dummy_r)[inv_offset]
            e_space = ops.pad_and_cat(db_e_space,
                                      padding_value=kg.dummy_e)[inv_offset]
            action_mask = ops.pad_and_cat(db_action_mask,
                                          padding_value=0)[inv_offset]
            action_space = ((r_space, e_space), action_mask)
            return action_space

        if use_action_space_bucketing:
            """

            """
            db_outcomes = []
            entropy_list = []
            references = []
            db_e, db_action_spaces, db_references = self.get_action_space_in_buckets_fusion(
                e, obs, kg)
            for e_b, action_space_b, reference_b in zip(
                    db_e, db_action_spaces, db_references):
                X2_b = X2[reference_b, :]
                action_dist_b, entropy_b = policy_nn_fun_fusion(
                    e_b, X2_b, action_space_b, fn, fn_kg)
                references.extend(reference_b)
                db_outcomes.append((action_space_b, action_dist_b))
                entropy_list.append(entropy_b)
            inv_offset = [
                i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
            ]
            entropy = torch.cat(entropy_list, dim=0)[inv_offset]
            if merge_aspace_batching_outcome:
                db_action_dist = []
                for _, action_dist in db_outcomes:
                    db_action_dist.append(action_dist)
                action_space = pad_and_cat_action_space(
                    db_action_spaces, inv_offset)
                action_dist = ops.pad_and_cat(db_action_dist,
                                              padding_value=0)[inv_offset]
                db_outcomes = [(action_space, action_dist)]
                inv_offset = None
        else:
            action_space = self.get_action_space(e, obs, kg)
            action_dist, entropy = policy_nn_fun_fusion(
                e, X2, action_space, fn, fn_kg)
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None

        return db_outcomes, inv_offset, entropy
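transit_fusion's long_short_term switch blends the policy distribution with embedding-based scores from fn.forward_fact. A compact numeric illustration of the three modes (note the fused result is no longer normalized):

import torch

action_dist = torch.tensor([0.6, 0.3, 0.1])  # policy (long-term) distribution
em_score = torch.tensor([0.2, 0.9, 0.4])     # embedding (short-term) scores

fused = action_dist * em_score  # long_short_term == 0: elementwise product
# fused == [0.12, 0.27, 0.04]
long_only = action_dist         # long_short_term == 1
short_only = em_score           # long_short_term == 2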
Example #9
    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            # one row of flat action indices [0, ..., beam_action_space_size-1] per example
            all_action_ind = torch.arange(beam_action_space_size).unsqueeze(0).repeat(
                len(log_action_dist), 1).cuda()
            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1), all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1), all_action_ind)

            # print ("DEBUG all_next_e:", all_next_e.shape, all_next_e)
            # print ("DEBUG all_next_r:", all_next_r.shape, all_next_r)

            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
                log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i], k_prime)
                top_next_e, top_next_r, top_ind = real_next_e[i][top_ind], real_next_r[i][top_ind], real_action_ind[i][top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))

            # print("DEBUG -->next_e_list:", next_e_list, next_e_list)
            # print("DEBUG -->next_r_list:", next_r_list, next_r_list)

            next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
            log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1).view(-1)

            # print("DEBUG next_r, next_e:", next_e.shape, next_r.shape)

            # next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r, padding_dim=0).view(-1)
            # next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e, padding_dim=0).view(-1)
            # log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=0.0, padding_dim=0).view(-1)
            # action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1, padding_dim=0).view(-1)
        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)

        # print ("log_action_dist:", log_action_dist)
        #old start
        # log_action_prob, action_ind = torch.topk(log_action_dist, k)
        # next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        # next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
        #old end

        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset
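The parent-offset arithmetic at the end is worth a worked example: action_ind indexes the flattened [k' * action_space_size] beam space, so integer division by action_space_size recovers the beam each surviving action came from, and adding i * last_k converts that into a row index of the pre-expansion [batch_size * last_k] tensors. Illustrative numbers:

import torch

batch_size, last_k, action_space_size = 2, 3, 4
action_ind = torch.tensor([[9, 2], [5, 11]])  # hypothetical top-k picks

action_beam_offset = action_ind // action_space_size                    # [[2, 0], [1, 2]]
action_batch_offset = (torch.arange(batch_size) * last_k).unsqueeze(1)  # [[0], [3]]
action_offset = (action_batch_offset + action_beam_offset).view(-1)
# action_offset == [2, 0, 4, 5]: rows of the previous step's flattened beams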