def top_k_answer_unique(log_action_dist, action_space):
    """
    Get the top-k unique tail entities.
      - k = beam_size if the beam size is smaller than or equal to the beam action space size
      - k = beam_action_space_size otherwise
    :param log_action_dist: [batch_size*beam_size, action_space_size]
    :param action_space (r_space, e_space):
        r_space: [batch_size*beam_size, action_space_size]
        e_space: [batch_size*beam_size, action_space_size]
    :return: (next_r, next_e), action_prob, action_offset: [batch_size*k]
    """
    # batch_size, beam_size and kg are taken from the enclosing scope.
    full_size = len(log_action_dist)
    assert (full_size % batch_size == 0)
    last_k = int(full_size / batch_size)
    (r_space, e_space), _ = action_space
    action_space_size = r_space.size()[1]
    r_space = r_space.view(batch_size, -1)
    e_space = e_space.view(batch_size, -1)
    log_action_dist = log_action_dist.view(batch_size, -1)
    beam_action_space_size = log_action_dist.size()[1]
    assert (beam_action_space_size % action_space_size == 0)
    k = min(beam_size, beam_action_space_size)
    next_r_list, next_e_list = [], []
    log_action_prob_list = []
    action_offset_list = []
    for i in range(batch_size):
        log_action_dist_b = log_action_dist[i]
        r_space_b = r_space[i]
        e_space_b = e_space[i]
        unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
        # Keep the best-scoring action per unique tail entity.
        unique_log_action_dist, unique_idx = unique_max(
            unique_e_space_b, e_space_b, log_action_dist_b)
        k_prime = min(len(unique_e_space_b), k)
        top_unique_log_action_dist, top_unique_idx2 = torch.topk(
            unique_log_action_dist, k_prime)
        top_unique_idx = unique_idx[top_unique_idx2]
        top_unique_beam_offset = top_unique_idx // action_space_size
        top_r = r_space_b[top_unique_idx]
        top_e = e_space_b[top_unique_idx]
        next_r_list.append(top_r.unsqueeze(0))
        next_e_list.append(top_e.unsqueeze(0))
        log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
        top_unique_batch_offset = i * last_k
        top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
        action_offset_list.append(top_unique_action_offset.unsqueeze(0))
    next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
    next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
    log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
    action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
    return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)
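# A minimal, self-contained sketch (assumed semantics, not the repo's
# `unique_max` implementation) of the de-duplication step above: for each
# unique tail entity keep its best-scoring action, then take the top-k.
def _sketch_unique_entity_topk():
    import torch
    e_space_b = torch.tensor([3, 5, 3, 7, 5])                      # candidate tails
    log_action_dist_b = torch.tensor([-1.2, -0.3, -0.9, -2.0, -1.5])
    unique_e = torch.unique(e_space_b)                             # [3, 5, 7]
    # Best-scoring occurrence per unique entity.
    mask = unique_e.unsqueeze(1) == e_space_b.unsqueeze(0)         # [n_unique, n_actions]
    masked = torch.where(mask, log_action_dist_b, torch.tensor(-1e20))
    unique_scores, unique_idx = masked.max(dim=1)
    k_prime = min(len(unique_e), 2)                                # k = 2 here
    top_scores, top2 = torch.topk(unique_scores, k_prime)
    top_idx = unique_idx[top2]                                     # back into e_space_b
    return e_space_b[top_idx], top_scores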
def pad_and_cat_action_space(action_spaces, inv_offset):
    db_r_space, db_e_space, db_action_mask = [], [], []
    for (r_space, e_space), action_mask in action_spaces:
        db_r_space.append(r_space)
        db_e_space.append(e_space)
        db_action_mask.append(action_mask)
    r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
    e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
    action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
    action_space = ((r_space, e_space), action_mask)
    return action_space
def pad_and_cat_action_space(action_spaces: List[ActionSpace], inv_offset,
                             kg: KnowledgeGraph):
    db_r_space, db_e_space, db_action_mask = [], [], []
    forks = []
    for acsp in action_spaces:
        forks += acsp.forks
        db_r_space.append(acsp.r_space)
        db_e_space.append(acsp.e_space)
        db_action_mask.append(acsp.action_mask)
    r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
    e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
    action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
    action_space = ActionSpace(forks, r_space, e_space, action_mask)
    return action_space
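# The two variants above differ only in how they carry the action space; both
# rely on `ops.pad_and_cat`. A sketch of its assumed behaviour: right-pad each
# 2-D tensor to the widest action space, then concatenate along dim 0.
def _sketch_pad_and_cat():
    import torch
    import torch.nn.functional as F

    def pad_and_cat(tensors, padding_value):
        max_width = max(t.size(1) for t in tensors)
        return torch.cat(
            [F.pad(t, (0, max_width - t.size(1)), value=padding_value)
             for t in tensors], dim=0)

    a = torch.tensor([[1, 2, 3]])
    b = torch.tensor([[4, 5]])
    return pad_and_cat([a, b], padding_value=0)   # [[1, 2, 3], [4, 5, 0]]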
def virtual_step(self, e_set, r):
    """
    Given a set of entities (e_set), find the set of entities (e_set_out)
    which have at least one incoming edge labeled r whose source entity is
    in e_set.
    """
    batch_size = len(e_set)
    e_set_1D = e_set.view(-1)
    r_space = self.action_space[0][0][e_set_1D]
    e_space = self.action_space[0][1][e_set_1D]
    # Zero out tail entities whose edge label does not match r.
    e_space = (r_space.view(batch_size, -1) == r.unsqueeze(1)).long() \
        * e_space.view(batch_size, -1)
    e_set_out = []
    for i in range(len(e_space)):
        e_set_out_b = var_cuda(unique(e_space[i].data))
        e_set_out.append(e_set_out_b.unsqueeze(0))
    e_set_out = ops.pad_and_cat(e_set_out, padding_value=self.dummy_e)
    return e_set_out
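# Sketch of the masking trick in `virtual_step`: comparing the edge labels
# against the query relation zeroes out every tail entity not reached via r
# (index 0 is assumed to act as the dummy entity).
def _sketch_virtual_step_mask():
    import torch
    r_space = torch.tensor([[1, 2, 1], [2, 2, 3]])        # outgoing edge labels
    e_space = torch.tensor([[10, 11, 12], [13, 14, 15]])  # corresponding tails
    r = torch.tensor([1, 2])                              # query relation per row
    e_out = (r_space == r.unsqueeze(1)).long() * e_space
    # row 0 -> [10,  0, 12]; row 1 -> [13, 14,  0]
    return e_out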
def do_it_with_bucketing(
    self,
    X2,
    current_entity,
    kg,
    merge_aspace_batching_outcome,
    obs: Observation,
    policy_nn_fun,
):
    entropy_list = []
    references = []
    bucket_action_spaces, in_bucket_indices = self.get_action_space_in_buckets(
        current_entity, obs, kg)
    action_spaces = []
    action_dists = []
    for as_b, in_bucket in zip(bucket_action_spaces, in_bucket_indices):
        X2_b = X2[in_bucket, :]
        action_dist_b, entropy_b = policy_nn_fun(X2_b, as_b)
        references.extend(in_bucket)
        action_spaces.append(as_b)
        action_dists.append(action_dist_b)
        entropy_list.append(entropy_b)
    # Invert the bucket permutation to restore the original batch order.
    inv_offset = [
        i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
    ]
    entropy = torch.cat(entropy_list, dim=0)[inv_offset]
    action = BucketActions(action_spaces, action_dists, inv_offset, entropy)
    if merge_aspace_batching_outcome:
        action_space = pad_and_cat_action_space(bucket_action_spaces,
                                                inv_offset, kg)
        action_dist = ops.pad_and_cat(action.action_dists,
                                      padding_value=0)[inv_offset]
        action = BucketActions([action_space], [action_dist], None, entropy)
    return action
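# Sketch of the `inv_offset` idiom used above: bucketing permutes the batch,
# and argsort-ing the collected reference indices yields the permutation that
# restores the original row order.
def _sketch_inv_offset():
    import torch
    references = [2, 0, 3, 1]                   # original row ids, bucket order
    inv_offset = [i for i, _ in sorted(enumerate(references), key=lambda x: x[1])]
    bucketed = torch.tensor([20.0, 0.0, 30.0, 10.0])  # row j belongs to references[j]
    return bucketed[inv_offset]                 # tensor([ 0., 10., 20., 30.])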
def transit(self, e, obs, kg, kg_pred=None, fn_kg=None,
            use_action_space_bucketing=True,
            merge_aspace_batching_outcome=False,
            use_kg_pred=False, inference=False):
    """
    Compute the next action distribution based on
        (a) the current node (entity) in the KG and the query relation
        (b) the action history representation
    :param e: agent location (node) at step t.
    :param obs: agent observation at step t.
        e_s: source node
        q: query relation
        e_t: target node
        last_step: If set, the agent is carrying out the last step.
        last_r: label of the edge traversed in the previous step
        seen_nodes: nodes seen on the paths
    :param kg: Knowledge graph environment.
    :param use_action_space_bucketing: If set, group the action spaces of
        different nodes into buckets by their sizes.
    :param merge_aspace_batching_outcome: If set, merge the transition
        probability distributions generated for the different action space
        buckets into a single batch.
    :return:
        With aspace batching and without merging the outcomes:
            db_outcomes: (Dynamic Batch) (action_space, action_dist)
                action_space: (Batch) padded possible action indices
                action_dist: (Batch) distribution over actions.
            inv_offset: Indices to set the dynamic batching output back to
                the original order.
            entropy: (Batch) entropy of action distribution.
        Else:
            action_dist: (Batch) distribution over actions.
            entropy: (Batch) entropy of action distribution.
    """
    e_s, q, e_t, last_step, last_r, seen_nodes = obs

    # Representation of the current state (current node and other observations)
    Q = kg.get_relation_embeddings(q)
    H = self.path[-1][0][-1, :, :]
    if self.relation_only:
        X = torch.cat([H, Q], dim=-1)
    elif self.relation_only_in_path:
        E_s = kg.get_entity_embeddings(e_s)
        E = kg.get_entity_embeddings(e)
        X = torch.cat([E, H, E_s, Q], dim=-1)
    elif use_kg_pred:
        E = kg.get_entity_embeddings(e)
        X = torch.cat([E, H, Q, kg_pred], dim=-1)
        # X = torch.cat([E, H, Q, kg.get_entity_embeddings(e_t)], dim=-1)
    else:
        E = kg.get_entity_embeddings(e)
        X = torch.cat([E, H, Q], dim=-1)

    # MLP
    X = self.W1(X)
    X = F.relu(X)
    X = self.W1Dropout(X)
    X = self.W2(X)
    X2 = self.W2Dropout(X)

    relation_att = torch.matmul(self.W_att(X2),
                                kg.get_all_relation_embeddings().t())  # B x |R|
    # Trick -> mask SIM relation
    relation_att = torch.nn.functional.softmax(relation_att, dim=-1)

    def policy_nn_fun(X2, action_space):
        (r_space, e_space), action_mask = action_space
        A = self.get_action_embedding((r_space, e_space), kg)
        action_dist = F.softmax(
            torch.squeeze(A @ torch.unsqueeze(X2, 2), 2)
            - (1 - action_mask) * ops.HUGE_INT, dim=-1)
        # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
        return action_dist, ops.entropy(action_dist)

    def pad_and_cat_action_space(action_spaces, inv_offset):
        db_r_space, db_e_space, db_action_mask = [], [], []
        for (r_space, e_space), action_mask in action_spaces:
            db_r_space.append(r_space)
            db_e_space.append(e_space)
            db_action_mask.append(action_mask)
        r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
        e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
        action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
        action_space = ((r_space, e_space), action_mask)
        return action_space

    if use_action_space_bucketing:
        db_outcomes = []
        entropy_list = []
        references = []
        db_action_spaces, db_references = self.get_action_space_in_buckets(
            e, obs, kg, relation_att=relation_att, inference=inference)
        for action_space_b, reference_b in zip(db_action_spaces, db_references):
            X2_b = X2[reference_b, :]
            action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b)
            references.extend(reference_b)
            db_outcomes.append((action_space_b, action_dist_b))
            entropy_list.append(entropy_b)
        inv_offset = [
            i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
        ]
        entropy = torch.cat(entropy_list, dim=0)[inv_offset]
        if merge_aspace_batching_outcome:
            db_action_dist = []
            for _, action_dist in db_outcomes:
                db_action_dist.append(action_dist)
            action_space = pad_and_cat_action_space(db_action_spaces, inv_offset)
            action_dist = ops.pad_and_cat(db_action_dist,
                                          padding_value=0)[inv_offset]
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None
    else:
        action_space = self.get_action_space(e, obs, kg)
        action_dist, entropy = policy_nn_fun(X2, action_space)
        db_outcomes = [(action_space, action_dist)]
        inv_offset = None
    return db_outcomes, inv_offset, entropy
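# Sketch of the masked softmax inside `policy_nn_fun`: padded actions get a
# huge negative logit so they receive ~0 probability. ops.HUGE_INT is assumed
# to be a very large positive constant.
def _sketch_masked_softmax():
    import torch
    import torch.nn.functional as F
    HUGE_INT = 1e31                                    # assumed magnitude
    logits = torch.tensor([[2.0, 1.0, 0.5]])
    action_mask = torch.tensor([[1.0, 1.0, 0.0]])      # last action is padding
    action_dist = F.softmax(logits - (1 - action_mask) * HUGE_INT, dim=-1)
    return action_dist                                 # ~[[0.73, 0.27, 0.00]]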
def transit(self, e, obs, kg, mode, use_action_space_bucketing=True,
            merge_aspace_batching_outcome=False):
    """
    Compute the next action distribution based on
        (a) the current node (entity) in the KG and the query relation
        (b) the action history representation
    :param e: agent location (node) at step t.
    :param obs: agent observation at step t.
        e_s: source node
        q: query relation
        e_t: target node
        first_step: If set, the agent is carrying out the first step.
        last_step: If set, the agent is carrying out the last step.
        last_r: label of the edge traversed in the previous step
        seen_nodes: nodes seen on the paths
    :param kg: Knowledge graph environment.
    :param use_action_space_bucketing: If set, group the action spaces of
        different nodes into buckets by their sizes.
    :param merge_aspace_batching_outcome: If set, merge the transition
        probability distributions generated for the different action space
        buckets into a single batch.
    :return:
        With aspace batching and without merging the outcomes:
            db_outcomes: (Dynamic Batch) (action_space, action_dist)
                action_space: (Batch) padded possible action indices
                action_dist: (Batch) distribution over actions.
            inv_offset: Indices to set the dynamic batching output back to
                the original order.
            entropy: (Batch) entropy of action distribution.
        Else:
            action_dist: (Batch) distribution over actions.
            entropy: (Batch) entropy of action distribution.
    """
    e_s, emb_e_s, q, e_t, first_step, last_step, last_r, seen_nodes = obs

    # Representation of the current state (current node and other observations)
    Q = self.graph_transformer.dropout(self.graph_transformer.emb_r(q))
    H = self.path[-1][0][-1, :, :]
    if self.relation_only:
        X = torch.cat([H, Q], dim=-1)
    elif self.relation_only_in_path:
        E_s = emb_e_s
        if first_step:
            E = emb_e_s
        else:
            if mode == 'train':
                E, _ = self.graph_transformer(e, q, self.dg.training_graph,
                                              self.dg.seen_id2entity,
                                              self.bandwidth, mode)
            else:
                E, _ = self.graph_transformer(e, q, self.dg.eval_graph,
                                              self.dg.seen_id2entity,
                                              self.bandwidth, mode)
        if E.size()[0] != E_s.size(0):
            # Tile the source embeddings to match the expanded batch.
            expansion_size = int(E.size()[0] / E_s.size(0))
            E_s = E_s.unsqueeze(1).expand(-1, expansion_size, -1)
            E_s = torch.flatten(E_s, start_dim=0).view(-1, self.entity_dim)
        X = torch.cat([E, H, E_s, Q], dim=-1)
    else:
        if first_step:
            E = emb_e_s
        else:
            if mode == 'train':
                E, _ = self.graph_transformer(e, q, self.dg.training_graph,
                                              self.dg.seen_id2entity,
                                              self.bandwidth, mode)
            else:
                if self.args.inference:
                    E, _ = self.graph_transformer(e, q, self.dg.aux_graph,
                                                  self.dg.seen_id2entity,
                                                  self.bandwidth, 'test')
                else:
                    E, _ = self.graph_transformer(e, q, self.dg.eval_graph,
                                                  self.dg.seen_id2entity,
                                                  self.bandwidth, 'eval')
        X = torch.cat([E, H, Q], dim=-1)

    # MLP
    X = self.W1(X)
    X = F.relu(X)
    X = self.W1Dropout(X)
    X = self.W2(X)
    X2 = self.W2Dropout(X)

    def policy_nn_fun(X2, action_space):
        (r_space, e_space), action_mask = action_space
        A = self.get_action_embedding((r_space, e_space), kg)
        action_dist = F.softmax(
            torch.squeeze(A @ torch.unsqueeze(X2, 2), 2)
            - (1 - action_mask) * ops.HUGE_INT, dim=-1)
        # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
        return action_dist, ops.entropy(action_dist)

    def pad_and_cat_action_space(action_spaces, inv_offset):
        db_r_space, db_e_space, db_action_mask = [], [], []
        for (r_space, e_space), action_mask in action_spaces:
            db_r_space.append(r_space)
            db_e_space.append(e_space)
            db_action_mask.append(action_mask)
        r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
        e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
        action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
        action_space = ((r_space, e_space), action_mask)
        return action_space

    if use_action_space_bucketing:
        db_outcomes = []
        entropy_list = []
        references = []
        db_action_spaces, db_references = self.get_action_space_in_buckets(
            e, obs, kg)
        for action_space_b, reference_b in zip(db_action_spaces, db_references):
            X2_b = X2[reference_b, :]
            action_dist_b, entropy_b = policy_nn_fun(X2_b, action_space_b)
            references.extend(reference_b)
            db_outcomes.append((action_space_b, action_dist_b))
            entropy_list.append(entropy_b)
        inv_offset = [
            i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
        ]
        entropy = torch.cat(entropy_list, dim=0)[inv_offset]
        if merge_aspace_batching_outcome:
            db_action_dist = []
            for _, action_dist in db_outcomes:
                db_action_dist.append(action_dist)
            action_space = pad_and_cat_action_space(db_action_spaces, inv_offset)
            action_dist = ops.pad_and_cat(db_action_dist,
                                          padding_value=0)[inv_offset]
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None
    else:
        action_space = self.get_action_space(e, obs, kg)
        action_dist, entropy = policy_nn_fun(X2, action_space)
        db_outcomes = [(action_space, action_dist)]
        inv_offset = None
    return db_outcomes, inv_offset, entropy
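# Sketch of the source-embedding expansion in the relation_only_in_path
# branch above: when beam search has replicated the batch, E_s is tiled so
# every expanded row sees its own source embedding. Shapes are illustrative.
def _sketch_source_expansion():
    import torch
    entity_dim = 4
    E = torch.randn(6, entity_dim)     # expanded batch (2 rows x 3 beams)
    E_s = torch.randn(2, entity_dim)   # one source embedding per original row
    expansion_size = E.size(0) // E_s.size(0)
    E_s = E_s.unsqueeze(1).expand(-1, expansion_size, -1)
    E_s = torch.flatten(E_s, start_dim=0).view(-1, entity_dim)
    assert E_s.size() == E.size()
    return E_s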
def transit_fusion(self, e, obs, kg, fn, fn_kg, use_action_space_bucketing=True,
                   merge_aspace_batching_outcome=False):
    """
    Compute the next action distribution based on
        (a) the current node (entity) in the KG and the query relation
        (b) the action history representation
    :param e: agent location (node) at step t.
    :param obs: agent observation at step t.
        e_s: source node
        q: query relation
        e_t: target node
        last_step: If set, the agent is carrying out the last step.
        last_r: label of the edge traversed in the previous step
        seen_nodes: nodes seen on the paths
    :param kg: Knowledge graph environment.
    :param use_action_space_bucketing: If set, group the action spaces of
        different nodes into buckets by their sizes.
    :param merge_aspace_batching_outcome: If set, merge the transition
        probability distributions generated for the different action space
        buckets into a single batch.
    :return:
        With aspace batching and without merging the outcomes:
            db_outcomes: (Dynamic Batch) (action_space, action_dist)
                action_space: (Batch) padded possible action indices
                action_dist: (Batch) distribution over actions.
            inv_offset: Indices to set the dynamic batching output back to
                the original order.
            entropy: (Batch) entropy of action distribution.
        Else:
            action_dist: (Batch) distribution over actions.
            entropy: (Batch) entropy of action distribution.
    """
    e_s, q, e_t, last_step, last_r, seen_nodes = obs

    # Representation of the current state (current node and other observations)
    Q = kg.get_relation_embeddings(q)
    H = self.path[-1][0][-1, :, :]
    if self.using_attention:
        # Attention over the hidden states of all path steps, keyed by Q.
        ht_number = len(self.path)
        ht_list = []
        score_list = []
        score_list2 = []
        for one_ht in range(ht_number):
            ht = self.path[one_ht][0][-1, :, :]
            ht_list.append(ht)
            score_r_ht = (Q * ht).squeeze(1)
            score_list.append(score_r_ht)
        # Normalizer: sum of exponentiated scores over all steps.
        score_all = score_list[0].exp()
        for score_number in range(1, len(score_list)):
            score_all = score_all + score_list[score_number].exp()
        # Weight each hidden state by its normalized score.
        for ht_n in range(len(ht_list)):
            alpha_t_r2 = score_list[ht_n].exp() / score_all
            alpha_t_r = alpha_t_r2.mul(ht_list[ht_n])
            score_list2.append(alpha_t_r)
        # Sum the attention-weighted hidden states.
        score_all2 = score_list2[0]
        for score_number2 in range(1, len(score_list2)):
            score_all2 = score_all2 + score_list2[score_number2]
        H = score_all2
    if self.relation_only:
        X = torch.cat([H, Q], dim=-1)
    elif self.relation_only_in_path:
        E_s = kg.get_entity_embeddings(e_s)
        E = kg.get_entity_embeddings(e)
        X = torch.cat([E, H, E_s, Q], dim=-1)
    else:
        E = kg.get_entity_embeddings(e)
        X = torch.cat([E, H, Q], dim=-1)

    # MLP
    X = self.W1(X)
    X = F.relu(X)
    X = self.W1Dropout(X)
    X = self.W2(X)
    X2 = self.W2Dropout(X)

    def policy_nn_fun_fusion(e, X2, action_space, fn, fn_kg):
        (r_space, e_space), action_mask = action_space
        dim1, dim2 = e_space.size()
        # Repeat each head entity across its action space: [dim1] -> [dim1, dim2].
        e1 = e.unsqueeze(1).expand(dim1, dim2)
        e2 = e_space
        r = r_space
        # Score every (head, relation, tail) candidate with the fact network.
        em_score_can = []
        for i in range(dim1):
            e1_can = torch.squeeze(e1[i])
            r_can = torch.squeeze(r[i])
            e2_can = torch.squeeze(e2[i])
            em_score_can.append(
                torch.squeeze(fn.forward_fact(e1_can, r_can, e2_can, fn_kg)))
        em_score = em_score_can[0]
        for i in range(1, dim1):
            em_score = torch.cat((em_score, em_score_can[i]), 0)
        A = self.get_action_embedding((r_space, e_space), kg)
        action_dist = F.softmax(
            torch.squeeze(A @ torch.unsqueeze(X2, 2), 2)
            - (1 - action_mask) * ops.HUGE_INT, dim=-1)
        # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
        em_score = em_score.view(dim1, dim2)
        # Long/short-term decision.
        if self.long_short_term == 0:    # combine long- and short-term scores
            action_dist = action_dist * em_score
        elif self.long_short_term == 1:  # long-term (policy) only
            action_dist = action_dist
        elif self.long_short_term == 2:  # short-term (fact score) only
            action_dist = em_score
        return action_dist, ops.entropy(action_dist)

    def pad_and_cat_action_space(action_spaces, inv_offset):
        db_r_space, db_e_space, db_action_mask = [], [], []
        for (r_space, e_space), action_mask in action_spaces:
            db_r_space.append(r_space)
            db_e_space.append(e_space)
            db_action_mask.append(action_mask)
        r_space = ops.pad_and_cat(db_r_space, padding_value=kg.dummy_r)[inv_offset]
        e_space = ops.pad_and_cat(db_e_space, padding_value=kg.dummy_e)[inv_offset]
        action_mask = ops.pad_and_cat(db_action_mask, padding_value=0)[inv_offset]
        action_space = ((r_space, e_space), action_mask)
        return action_space

    if use_action_space_bucketing:
        db_outcomes = []
        entropy_list = []
        references = []
        db_e, db_action_spaces, db_references = self.get_action_space_in_buckets_fusion(
            e, obs, kg)
        for e_b, action_space_b, reference_b in zip(
                db_e, db_action_spaces, db_references):
            X2_b = X2[reference_b, :]
            action_dist_b, entropy_b = policy_nn_fun_fusion(
                e_b, X2_b, action_space_b, fn, fn_kg)
            references.extend(reference_b)
            db_outcomes.append((action_space_b, action_dist_b))
            entropy_list.append(entropy_b)
        inv_offset = [
            i for i, _ in sorted(enumerate(references), key=lambda x: x[1])
        ]
        entropy = torch.cat(entropy_list, dim=0)[inv_offset]
        if merge_aspace_batching_outcome:
            db_action_dist = []
            for _, action_dist in db_outcomes:
                db_action_dist.append(action_dist)
            action_space = pad_and_cat_action_space(db_action_spaces, inv_offset)
            action_dist = ops.pad_and_cat(db_action_dist,
                                          padding_value=0)[inv_offset]
            db_outcomes = [(action_space, action_dist)]
            inv_offset = None
    else:
        action_space = self.get_action_space(e, obs, kg)
        action_dist, entropy = policy_nn_fun_fusion(
            e, X2, action_space, fn, fn_kg)
        db_outcomes = [(action_space, action_dist)]
        inv_offset = None
    return db_outcomes, inv_offset, entropy
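# A vectorized sketch of the manual attention loop in `transit_fusion`: the
# element-wise Q*h_t scores are softmax-normalized over the path steps and
# used to weight and sum the hidden states. Shapes are illustrative.
def _sketch_path_attention():
    import torch
    path_len, batch, dim = 3, 2, 4
    H_all = torch.randn(path_len, batch, dim)   # one hidden state per step
    Q = torch.randn(batch, dim)                 # query relation embedding
    scores = H_all * Q.unsqueeze(0)             # element-wise, as in the loop
    alpha = torch.softmax(scores, dim=0)        # weights over path steps
    return (alpha * H_all).sum(0)               # [batch, dim]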
def top_k_action(log_action_dist, action_space, return_merge_scores=None):
    """
    Get the top-k actions.
      - k = beam_size if the beam size is smaller than or equal to the beam action space size
      - k = beam_action_space_size otherwise
    :param log_action_dist: [batch_size*k, action_space_size]
    :param action_space (r_space, e_space):
        r_space: [batch_size*k, action_space_size]
        e_space: [batch_size*k, action_space_size]
    :return: (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
    """
    # batch_size, beam_size and kg are taken from the enclosing scope.
    full_size = len(log_action_dist)
    assert (full_size % batch_size == 0)
    last_k = int(full_size / batch_size)
    (r_space, e_space), _ = action_space
    action_space_size = r_space.size()[1]
    # => [batch_size, k'*action_space_size]
    log_action_dist = log_action_dist.view(batch_size, -1)
    beam_action_space_size = log_action_dist.size()[1]
    k = min(beam_size, beam_action_space_size)

    if return_merge_scores is not None:
        if return_merge_scores == 'sum':
            reduce_method = torch.sum
        elif return_merge_scores == 'mean':
            reduce_method = torch.mean
        else:
            reduce_method = None

        # Every flat action index, one row per batch example.
        all_action_ind = torch.arange(beam_action_space_size).repeat(
            len(log_action_dist), 1).cuda()
        all_next_r = ops.batch_lookup(r_space.view(batch_size, -1), all_action_ind)
        all_next_e = ops.batch_lookup(e_space.view(batch_size, -1), all_action_ind)
        # Merge duplicate tail entities and reduce their scores.
        real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
            log_action_dist, all_next_e, all_next_r, method=reduce_method)

        next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
        for i in range(batch_size):
            k_prime = min(len(real_log_action_prob[i]), k)
            top_log_prob, top_ind = torch.topk(real_log_action_prob[i], k_prime)
            top_next_e, top_next_r, top_ind = (real_next_e[i][top_ind],
                                               real_next_r[i][top_ind],
                                               real_action_ind[i][top_ind])
            next_e_list.append(top_next_e.unsqueeze(0))
            next_r_list.append(top_next_r.unsqueeze(0))
            action_ind_list.append(top_ind.unsqueeze(0))
            log_action_prob_list.append(top_log_prob.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=0.0).view(-1)
        action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1).view(-1)
    else:
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)

    # [batch_size, k] => [batch_size*k]
    log_action_prob = log_action_prob.view(-1)

    # *** compute parent offset
    # [batch_size, k]
    action_beam_offset = action_ind // action_space_size
    # [batch_size, 1]
    action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
    # [batch_size, k] => [batch_size*k]
    action_offset = (action_batch_offset + action_beam_offset).view(-1)
    return (next_r, next_e), log_action_prob, action_offset