def bulid_model_rl(self):
    """Create the state encoder, the policy/value layers and episode buffers.

    Reads ``self.args``, ``self.embed_size``, ``self.hidden_sizes`` and
    ``self.act_dim``; everything it builds is stored on ``self``.
    """
    emb = self.embed_size
    sizes = self.hidden_sizes

    # Layers are created in a fixed order: each ``nn.Linear`` draws from the
    # global RNG for its initial weights, so reordering changes the init.
    self.state_lstm = KGState_LSTM(self.args, history_len=1)

    self.state_prop_l1 = nn.Linear(4 * emb, 2 * emb)
    self.state_prop_l2 = nn.Linear(2 * emb, 1)

    self.transfor_state = nn.Linear(2 * emb, 2 * emb)
    self.state_tr_query = nn.Linear(emb * 3, emb)

    self.l1 = nn.Linear(2 * emb, sizes[1])
    self.l2 = nn.Linear(sizes[0], sizes[1])

    self.actor = nn.Linear(sizes[1], self.act_dim)
    self.critic = nn.Linear(sizes[1], 1)

    # Per-episode buffers.
    self.saved_actions, self.rewards, self.entropy = [], [], []
class ActorCritic_lstm_mf_base_test(nn.Module):
    """Actor-critic policy over knowledge-graph paths with an LSTM state encoder.

    The module keeps per-episode buffers (``saved_actions``, ``rewards``,
    ``entropy``) that ``select_action`` fills and ``update`` consumes.

    ``args.envir`` selects how the type of the next node is resolved and which
    knowledge-embedding module is created ('p1' or 'p2'). For any other value
    ``self.kg_emb`` is never set and ``_get_next_node_type`` stays the no-op
    placeholder defined on the class.
    """

    def __init__(self, args, user_triplet_set, rela_2_index, act_dim,
                 gamma=0.99, hidden_sizes=None):
        """Store configuration, create the embeddings and build the layers.

        Args:
            args: configuration object; the attributes read here are
                ``device``, ``sub_batch_size``, ``p_hop``, ``n_memory``,
                ``cast_st_save``, ``embed_size``, ``user_o``, ``h0_embbed``,
                ``envir`` and ``dataset``.
            user_triplet_set: stored as-is; ``reset`` takes ``max()`` of it.
            rela_2_index: stored as-is.
            act_dim: number of action slots (actions are padded up to it).
            gamma: reward discount factor.
            hidden_sizes: two layer widths; defaults to ``[512, 256]``.
        """
        super(ActorCritic_lstm_mf_base_test, self).__init__()
        self.args = args
        self.act_dim = act_dim
        self.device = args.device
        self.sub_batch_size = args.sub_batch_size
        self.gamma = gamma
        self.p_hop = args.p_hop
        # A ``None`` sentinel avoids sharing one mutable default list
        # between instances.
        self.hidden_sizes = [512, 256] if hidden_sizes is None else hidden_sizes
        self.n_memory = args.n_memory
        self.cast_st_save = args.cast_st_save
        self.embed_size = args.embed_size
        self.user_o = args.user_o
        self.h0_embbed = args.h0_embbed
        self.user_triplet_set = user_triplet_set
        self.rela_2_index = rela_2_index

        if self.args.envir == 'p1':
            self._get_next_node_type = self._get_next_node_type_meta
            self.kg_emb = KnowledgeEmbedding_memory(args)
        elif self.args.envir == 'p2':
            self._get_next_node_type = self._get_next_node_type_graph
            self.kg_emb = KnowledgeEmbedding_memory_graph(args)

        dataset = load_dataset_core(args, args.dataset)
        self.et_idx2ty = dataset.et_idx2ty
        self.entity_list = dataset.entity_list
        self.rela_list = dataset.rela_list

        self.bulid_model_rl()

    def _get_next_node_type_meta(self, curr_node_type, next_relation,
                                 next_entity_id):
        """Resolve the next node type from the ``KG_RELATION`` table."""
        return KG_RELATION[curr_node_type][next_relation]

    def _get_next_node_type_graph(self, curr_node_type, next_relation,
                                  next_entity_id):
        """Resolve the next node type from the entity-id -> type table."""
        return self.et_idx2ty[next_entity_id]

    def bulid_model_rl(self):
        """Create the LSTM state encoder, the layers and the episode buffers."""
        # Creation order is kept fixed: every layer consumes global RNG state
        # when initialising its weights.
        self.state_lstm = KGState_LSTM(self.args, history_len=1)

        self.state_prop_l1 = nn.Linear(4 * self.embed_size,
                                       2 * self.embed_size)
        self.state_prop_l2 = nn.Linear(2 * self.embed_size, 1)

        self.transfor_state = nn.Linear(2 * self.embed_size,
                                        2 * self.embed_size)
        self.state_tr_query = nn.Linear(self.embed_size * 3, self.embed_size)

        self.l1 = nn.Linear(2 * self.embed_size, self.hidden_sizes[1])
        self.l2 = nn.Linear(self.hidden_sizes[0], self.hidden_sizes[1])

        self.actor = nn.Linear(self.hidden_sizes[1], self.act_dim)
        self.critic = nn.Linear(self.hidden_sizes[1], 1)

        self.saved_actions = []
        self.rewards = []
        self.entropy = []

    def forward(self, inputs):
        """Score the candidate actions and estimate the state value.

        Args:
            inputs: tuple ``(state, batch_next_action_emb, act_mask)``.
                ``act_mask`` is a boolean tensor; ``False`` entries are
                masked out of the softmax.

        Returns:
            ``(act_probs, state_values)``: the softmax over the per-action
            dot products of the state with each action embedding, and the
            critic output.
        """
        state, batch_next_action_emb, act_mask = inputs
        # Collapse singleton dims and re-add one leading dim.
        # NOTE(review): this looks like it only supports one sample per call;
        # confirm against the caller.
        state = state.squeeze().unsqueeze(0)

        # Actor: dot product between the state and every action embedding.
        state_tr = state.unsqueeze(1).repeat(
            1, batch_next_action_emb.shape[1], 1)
        probs = (state_tr * batch_next_action_emb).sum(-1)
        probs = probs.masked_fill(~act_mask, value=-1e10)
        act_probs = F.softmax(probs, dim=-1)

        # Critic. ``F.dropout`` defaults to ``training=True``; pass the module
        # flag so that dropout is disabled under ``eval()``.
        x = self.l1(state)
        x = F.dropout(F.elu(x), p=0.4, training=self.training)
        state_values = self.critic(x)
        return act_probs, state_values

    def select_action(self, batch_state, batch_next_action_emb,
                      batch_act_mask, device):
        """Sample one action per row and record it for the next ``update``.

        Appends ``SavedAction(log_prob, value)`` to ``self.saved_actions`` and
        the distribution entropy to ``self.entropy``.

        Returns:
            The sampled action indices as a plain Python list.
        """
        act_mask = torch.BoolTensor(batch_act_mask).to(device)
        probs, value = self((batch_state, batch_next_action_emb, act_mask))
        m = Categorical(probs)
        acts = m.sample()
        # [CAVEAT] If the sampled action is outside the action space, fall
        # back to the first action.
        valid_idx = act_mask.gather(1, acts.view(-1, 1)).view(-1)
        acts[valid_idx == 0] = 0
        self.saved_actions.append(SavedAction(m.log_prob(acts), value))
        self.entropy.append(m.entropy())
        return acts.cpu().numpy().tolist()

    def reset(self, uids=None):
        """Start a new batch of episodes for ``uids``.

        Every uid is repeated ``sub_batch_size`` times, the dummy relation
        table is re-created and the LSTM hidden state is re-initialised.
        """
        self.uids = [uid for uid in uids for _ in range(self.sub_batch_size)]

        # Build the tensor in a local first. Assigning a plain tensor to
        # ``self.dummy_rela`` raises TypeError once that name has been
        # registered as a parameter, which happens whenever ``.to()`` below
        # is a no-op and hands back the ``nn.Parameter`` itself.
        dummy_rela = torch.ones(
            max(self.user_triplet_set) * 10 + 1, 1, self.embed_size)
        # NOTE(review): the table is re-created on every reset, so an
        # optimizer built earlier presumably never updates it; confirm
        # whether it is meant to be learned.
        self.dummy_rela = nn.Parameter(
            dummy_rela, requires_grad=True).to(self.device)

        self.prev_state_h, self.prev_state_c = \
            self.state_lstm.set_up_hidden_state(len(self.uids))

    def update_path_info(self, up_date_hop):
        """Keep only the rows listed in ``up_date_hop``.

        Applies the same row selection (in the given order) to ``self.uids``
        and to dimension 1 of both LSTM state tensors.
        """
        self.uids = [self.uids[row] for row in up_date_hop]

        kept_h = [self.prev_state_h[:, row, :].unsqueeze(1)
                  for row in up_date_hop]
        kept_c = [self.prev_state_c[:, row, :].unsqueeze(1)
                  for row in up_date_hop]
        self.prev_state_h = th.cat(kept_h, 1).to(self.device)
        self.prev_state_c = th.cat(kept_c, 1).to(self.device)

    def generate_st_emb(self, batch_path, up_date_hop=None):
        """Encode the current step of every path and advance the LSTM.

        Args:
            batch_path: one path per row.
            up_date_hop: optional row indices passed to ``update_path_info``
                before encoding.

        Returns:
            The LSTM output for this step.
        """
        # ``is not None`` rather than ``!= None``: the latter is evaluated
        # element-wise for array-like arguments.
        if up_date_hop is not None:
            self.update_path_info(up_date_hop)

        self.current_step = {}

        all_state = th.cat([
            self._get_state_update(index, path).unsqueeze(0)
            for index, path in enumerate(batch_path)
        ], 0)
        state_output, self.prev_state_h, self.prev_state_c = self.state_lstm(
            all_state, self.prev_state_h, self.prev_state_c)
        return state_output

    def action_encoder(self, relation_emb, entitiy_emb):
        """Concatenate relation and entity embeddings on the last dim."""
        return th.cat([relation_emb, entitiy_emb], -1)

    def _get_state_update(self, index, path):
        """Return the step embedding ``[relation, node]`` for one path.

        For a path of length 1 the node is the user embedding and the relation
        is the row of ``self.dummy_rela`` selected by the user id; otherwise
        the last relation and the current node of the path are used.
        """
        if len(path) == 1:
            # Explicit ``== True`` kept on purpose: the flag comes from
            # ``args`` and its type is not visible here.
            if self.user_o == True:  # noqa: E712
                user_embed = self.global_user[index, :].unsqueeze(0)
            else:
                user_embed = self.kg_emb.lookup_emb(
                    USER,
                    type_index=torch.LongTensor([path[0][-1]]).to(
                        self.device))[0].unsqueeze(0)
            dummy_rela = self.dummy_rela[path[0][-1], :, :]
            return self.action_encoder(dummy_rela, user_embed)

        last_relation, curr_node_type, curr_node_id = path[-1]
        curr_node_embed = self.kg_emb.lookup_emb(
            curr_node_type,
            type_index=torch.LongTensor([curr_node_id]).to(
                self.device))[0].unsqueeze(0)
        last_relation_embed = self.kg_emb.lookup_rela_emb(
            last_relation)[0].unsqueeze(0)
        return self.action_encoder(last_relation_embed, curr_node_embed)

    @staticmethod
    def _lookup_or_keep(mapping, key):
        """Return ``mapping[key]`` when present, otherwise ``key`` itself."""
        return mapping[key] if key in mapping else key

    def _describe_first_sample(self, batch_path, batch_curr_actions):
        """Fill ``self.current_step`` with readable info about sample 0.

        Only the first hop of the first path and the candidate actions of the
        first sample are translated through ``args.rela_2_name`` and
        ``args.index_2_entity``. The inputs are not modified.
        """
        rela_2_name = self.args.rela_2_name
        index_2_entity = self.args.index_2_entity

        self.current_step['user'] = str(self.uids[0])

        head = list(batch_path[0][0])
        head[0] = self._lookup_or_keep(rela_2_name, head[0])
        head[2] = self._lookup_or_keep(index_2_entity, str(head[2]))
        readable_path = [head] + list(batch_path[0][1:])
        b_path = [','.join(str(uni) for uni in hop) for hop in readable_path]

        b_action = []
        for action in batch_curr_actions[0]:
            readable = list(action)
            readable[0] = self._lookup_or_keep(rela_2_name, readable[0])
            readable[1] = self._lookup_or_keep(index_2_entity,
                                               str(readable[1]))
            b_action.append(', '.join(str(uni) for uni in readable))

        self.current_step["path"] = 'path = ' + ', next = '.join(b_path)
        self.current_step["actions"] = b_action

    def generate_act_emb(self, batch_path, batch_curr_actions):
        """Embed the candidate actions of every path.

        Also records a readable description of the first sample in
        ``self.current_step`` (keys ``user``, ``path``, ``actions``).

        Returns:
            The per-path action embeddings stacked along a new first dim.
        """
        self._describe_first_sample(batch_path, batch_curr_actions)

        all_action_set = [
            self._get_actions(index, curr_path, curr_actions)
            for index, (curr_path, curr_actions) in enumerate(
                zip(batch_path, batch_curr_actions))
        ]
        return th.cat(
            [action_set[1].unsqueeze(0) for action_set in all_action_set], 0)

    def _get_actions(self, index, curr_path, curr_actions):
        """Embed the candidate actions of one path, padded to ``act_dim``.

        Returns:
            ``[enti_emb, next_action_state]`` where ``next_action_state`` is
            the concatenation ``[entity, relation]`` on the last dim.
        """
        # NOTE(review): the order here is [entity, relation] while
        # ``action_encoder`` uses [relation, entity]; confirm this is intended.
        last_relation, curr_node_type, curr_node_id = curr_path[-1]
        entities_embs = []
        relation_embs = []
        for action_set in curr_actions:
            if action_set[0] == SELF_LOOP:
                next_node_type = curr_node_type
            else:
                next_node_type = self._get_next_node_type(
                    curr_node_type, action_set[0], action_set[1])
            entities_embs.append(self.kg_emb.lookup_emb(
                next_node_type,
                type_index=torch.LongTensor([action_set[1]]).to(self.device)))
            relation_embs.append(self.kg_emb.lookup_rela_emb(action_set[0]))

        # Pad both lists with the PADDING relation embedding up to act_dim.
        pad_emb = self.kg_emb.lookup_rela_emb(PADDING)
        n_missing = self.act_dim - len(entities_embs)
        entities_embs.extend(pad_emb for _ in range(n_missing))
        relation_embs.extend(pad_emb for _ in range(n_missing))

        enti_emb = th.cat(entities_embs, 0)
        rela_emb = th.cat(relation_embs, 0)
        next_action_state = th.cat([enti_emb, rela_emb], -1)
        return [enti_emb, next_action_state]

    def _get_next_node_type(self, curr_node_type, next_relation,
                            next_entity_id):
        """Placeholder; ``__init__`` rebinds it per instance for 'p1'/'p2'."""
        pass

    def _clear_episode_buffers(self):
        """Empty the reward / action / entropy buffers in place."""
        del self.rewards[:]
        del self.saved_actions[:]
        del self.entropy[:]

    def update(self, optimizer, env_model, device, ent_weight, step):
        """Run one actor-critic optimisation step over the buffered episode.

        Args:
            optimizer: optimizer to step.
            env_model: unused here.
            device: device the reward tensor is moved to.
            ent_weight: weight of the entropy term in the loss.
            step: unused here.

        Returns:
            ``(loss, actor_loss, critic_loss, entropy_loss)`` as floats; four
            zeros when no reward has been recorded.
        """
        if len(self.rewards) <= 0:
            self._clear_episode_buffers()
            # Same arity as the normal return so callers can always unpack
            # four values.
            return 0.0, 0.0, 0.0, 0.0

        batch_rewards = np.vstack(self.rewards).T  # [bs, #steps]
        batch_rewards = torch.FloatTensor(batch_rewards).to(device)
        num_steps = batch_rewards.shape[1]
        # Discounted return, accumulated from the last step backwards.
        for i in range(1, num_steps):
            batch_rewards[:, num_steps - i - 1] += \
                self.gamma * batch_rewards[:, num_steps - i]

        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        for i in range(num_steps):
            log_prob, value = self.saved_actions[i]
            advantage = batch_rewards[:, i] - value.squeeze(1)
            actor_loss += -log_prob * advantage.detach()
            critic_loss += advantage.pow(2)
            entropy_loss += -self.entropy[i]
        actor_loss = actor_loss.mean()
        critic_loss = critic_loss.mean()
        entropy_loss = entropy_loss.mean()
        loss = actor_loss + critic_loss + ent_weight * entropy_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        self._clear_episode_buffers()

        return (loss.item(), actor_loss.item(), critic_loss.item(),
                entropy_loss.item())

    def _case_study_header(self):
        """Return the leading text chunks (separator, user, path, separator)."""
        return [
            "*" * 50, '\n',
            'user = ' + self.current_step['user'], '\n',
            self.current_step['path'], '\n',
            "*" * 50, '\n',
        ]

    def _case_study_footer(self):
        """Return the trailing chunks (action lines, chosen action, separator)."""
        chunks = []
        for ac_pro in self.current_step['actions_pro']:
            chunks.extend((ac_pro, '\n'))
        chunks.extend((
            "next_action = " + self.current_step['next_acts'], '\n',
            "*" * 50,
        ))
        return chunks

    def _append_case_study(self, chunks):
        """Append ``chunks`` to the case-study file, always closing it."""
        text = ''.join(chunks)
        with open(self.cast_st_save, "a") as eva_file:
            eva_file.write(text)

    def _record_case_study(self):
        """Append the current step (user, path, actions) to ``cast_st_save``."""
        # The text is assembled before the file is opened, so a missing key in
        # ``current_step`` leaves no half-written record behind.
        self._append_case_study(
            self._case_study_header() + self._case_study_footer())

    def _record_case_study_az(self):
        """Like ``_record_case_study`` but also dumps the per-hop query trace.

        For every hop below ``self.p_hop`` it reads
        ``self.current_step[step][hop]`` for each step below
        ``self.reasoning_step`` and writes the entries side by side.
        """
        chunks = self._case_study_header()
        chunks.extend(("querying result", '\n'))
        for hop in range(self.p_hop):
            chunks.extend(("*" * 50, '\n', "hop = " + str(hop), '\n'))
            per_step = [self.current_step[rn_step][hop]
                        for rn_step in range(self.reasoning_step)]
            for state_s in zip(*per_step):
                chunks.extend((' n_hop = '.join(list(state_s)), '\n'))
        chunks.extend(("*" * 50, '\n'))
        chunks.extend(self._case_study_footer())
        self._append_case_study(chunks)