def main():
    """Prepare the run directory, record the arguments and start training.

    Reads the module-level ``args`` namespace.  With an empty
    ``args.resume_dir`` a new ``log_<YYYYMMDD_HHMM>`` directory is created under
    ``args.log_root`` and every argument is dumped to ``hparams.json``; otherwise
    ``args.resume_dir`` is reused as-is and ``hparams.json`` is not rewritten.
    Finally an ``ActorCritic`` runner is built for that directory and trained.
    """
    fresh_run = args.resume_dir == ""

    if fresh_run:
        # Same text as str(datetime.now()) truncated at the seconds and with
        # '-', ':' stripped and the space turned into '_', e.g. 20240102_1345.
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
        log_dir = os.path.join(args.log_root, "log_" + stamp)
    else:
        log_dir = args.resume_dir

    hparams_file = os.path.join(log_dir, "hparams.json")
    checkpoints_dir = os.path.join(log_dir, "checkpoints")

    for directory in (log_dir, checkpoints_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # A resumed run keeps the hyper-parameter file written by its first launch.
    if fresh_run:
        with open(hparams_file, "w") as f:
            json.dump(args.__dict__, f, indent=2)

    logger = Logger(os.path.join(log_dir, "log_train.txt"))
    logger.info("The args corresponding to training process are: ")
    for key, value in vars(args).items():
        logger.info("{key:20}: {value:}".format(key=key, value=value))

    runner = ActorCritic(args, log_dir, checkpoints_dir)
    runner.train()
def main():
    """Evaluate the run stored in ``args.resume_dir`` (module-level ``args``).

    Opens a ``Logger`` on ``<resume_dir>/log_train.txt`` and calls
    ``ActorCritic(args, log_dir, checkpoints_dir).evaluation()`` with
    ``checkpoints_dir = <resume_dir>/checkpoints``.
    """
    log_dir = args.resume_dir
    checkpoints_dir = os.path.join(log_dir, "checkpoints")
    log_file = os.path.join(log_dir, "log_train.txt")
    # Never used directly below; kept bound so the Logger (presumably holding
    # the log file / handlers -- confirm in Logger) stays alive during evaluation.
    logger = Logger(log_file)
    actor_critic = ActorCritic(args, log_dir, checkpoints_dir)
    actor_critic.evaluation()
def __init__(
    self,
    lr,
    gamma,
    k_epochs,
    eps_clip,
    n_j,
    n_m,
    num_layers,
    neighbor_pooling_type,
    input_dim,
    hidden_dim,
    num_mlp_layers_feature_extract,
    num_mlp_layers_actor,
    hidden_dim_actor,
    num_mlp_layers_critic,
    hidden_dim_critic,
):
    """Build the PPO agent: current policy, frozen old policy, optimiser, LR schedule.

    Args:
        lr: Adam learning rate.
        gamma: discount factor, stored as ``self.gamma``.
        k_epochs: number of optimisation epochs, stored as ``self.k_epochs``.
        eps_clip: PPO clipping parameter, stored as ``self.eps_clip``.
        n_j, n_m, num_layers, neighbor_pooling_type, input_dim, hidden_dim,
        num_mlp_layers_feature_extract, num_mlp_layers_actor, hidden_dim_actor,
        num_mlp_layers_critic, hidden_dim_critic: forwarded unchanged to
            ``ActorCritic`` (``learn_eps`` is fixed to False).
    """
    self.lr = lr
    self.gamma = gamma
    self.eps_clip = eps_clip
    self.k_epochs = k_epochs

    self.policy = ActorCritic(
        n_j=n_j,
        n_m=n_m,
        num_layers=num_layers,
        learn_eps=False,
        neighbor_pooling_type=neighbor_pooling_type,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_mlp_layers_feature_extract=num_mlp_layers_feature_extract,
        num_mlp_layers_actor=num_mlp_layers_actor,
        hidden_dim_actor=hidden_dim_actor,
        num_mlp_layers_critic=num_mlp_layers_critic,
        hidden_dim_critic=hidden_dim_critic,
        device=device)
    # deepcopy already duplicates every parameter/buffer, so no extra
    # load_state_dict is needed to start the two policies in sync.
    self.policy_old = deepcopy(self.policy)

    self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
    # Learning rate is multiplied by configs.decay_ratio every configs.decay_step_size scheduler steps.
    self.scheduler = torch.optim.lr_scheduler.StepLR(
        self.optimizer,
        step_size=configs.decay_step_size,
        gamma=configs.decay_ratio)

    self.V_loss_2 = nn.MSELoss()
#print("Cuda: " + str(torch.cuda.is_available())) if __name__ == '__main__': os.environ['OMP_NUM_THREADS'] = '1' torch.cuda.empty_cache() args = parser.parse_args() SAVEPATH = os.getcwd( ) + '/save/scme_' + args.reward_type + '/mario_a3c_params.pkl' if not os.path.exists(os.getcwd() + '/save/scme_' + args.reward_type): os.makedirs(os.getcwd() + '/save/scme_' + args.reward_type) env = create_mario_env(args.env_name, args.reward_type) shared_model = ActorCritic(env.observation_space.shape[0], len(ACTIONS)) shared_model.share_memory() shared_scme = SCME(env.observation_space.shape[0], len(ACTIONS)) shared_scme.share_memory() if os.path.isfile(SAVEPATH): print('Loading A3C parametets & SCME parameters...') shared_model.load_state_dict(torch.load(SAVEPATH)) shared_scme.load_state_dict(torch.load(SAVEPATH[:-4] + '_scme.pkl')) torch.manual_seed(args.seed) #optimizer = torch.optim.Adam(list(shared_model.parameters()) + list(shared_scme.parameters()), lr=args.lr) optimizer = SharedAdam(list(shared_model.parameters()) + list(shared_scme.parameters()),
def train(rank, args, shared_model, counter, lock, optimizer=None, select_sample=True):
    """A3C worker loop: roll out with a local model, push gradients to the shared one.

    Each iteration syncs a local ``ActorCritic`` with ``shared_model``, plays up
    to ``args.num_steps`` steps, builds the actor-critic loss with Generalized
    Advantage Estimation, clips the local gradient norm to ``args.max_grad_norm``,
    copies the gradients into ``shared_model`` and steps ``optimizer``.

    Args:
        rank: worker index. Offsets the torch/env seeds; rank 0 renders and saves
            ``shared_model`` every ``args.save_interval`` iterations, rank 1 saves
            every ``2.5 * args.save_interval`` iterations as a backup.
        args: parsed command-line namespace.
        shared_model: model whose parameters receive the gradients.
        counter: shared value (``.value``) incremented once per environment step.
        lock: lock held while incrementing ``counter``.
        optimizer: optimiser over ``shared_model`` parameters; a new Adam with
            ``args.lr`` is created when None.
        select_sample: sample actions from the policy if True, else act greedily.

    Never returns (iterates over ``itertools.count()``).
    """
    torch.manual_seed(args.seed + rank)
    print("Process No : {} | Sampling : {}".format(rank, select_sample))

    FloatTensor = torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    env = create_mario_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if args.use_cuda:
        model.cuda()

    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = torch.from_numpy(env.reset())
    done = True
    episode_length = 0

    for num_iter in count():
        if rank == 0:
            env.render()
            if num_iter % args.save_interval == 0 and num_iter > 0:
                print("Saving model at :" + args.save_path)
                torch.save(shared_model.state_dict(), args.save_path)

        # Second saver in case the first process crashes.
        if num_iter % (args.save_interval * 2.5) == 0 and num_iter > 0 and rank == 1:
            print("Saving model for process 1 at :" + args.save_path)
            torch.save(shared_model.state_dict(), args.save_path)

        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = Variable(torch.zeros(1, 512)).type(FloatTensor)
            hx = Variable(torch.zeros(1, 512)).type(FloatTensor)
        else:
            # Re-wrap .data so backprop stops at the start of this rollout.
            cx = Variable(cx.data).type(FloatTensor)
            hx = Variable(hx.data).type(FloatTensor)

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _ in range(args.num_steps):
            episode_length += 1
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            if select_sample:
                # Current PyTorch requires num_samples explicitly.
                action = prob.multinomial(1).data
            else:
                action = prob.max(-1, keepdim=True)[1].data
            log_prob = log_prob.gather(-1, Variable(action))

            # NOTE(review): presumably ACTIONS is an array indexable by the (1, 1)
            # action tensor -- confirm against its definition.
            action_out = ACTIONS[action][0, 0]
            state, reward, done, _ = env.step(action_out)
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 50), -50)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                env.change_level(0)
                state = env.reset()
                print("Process {} has completed.".format(rank))
                env.locked_levels = [False] + [True] * 31

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            # Truncated rollout: bootstrap the return from the critic.
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, _, _ = model((state_inp, (hx, cx)))
            R = value.data

        values.append(Variable(R).type(FloatTensor))
        policy_loss = 0
        value_loss = 0
        R = Variable(R).type(FloatTensor)
        gae = torch.zeros(1, 1).type(FloatTensor)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = (policy_loss
                           - log_probs[i] * Variable(gae).type(FloatTensor)
                           - args.entropy_coef * entropies[i])

        total_loss = policy_loss + args.value_loss_coef * value_loss
        print("Process {} loss :".format(rank), total_loss.data)

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        ensure_shared_grads(model, shared_model)
        optimizer.step()
def test(rank, args, shared_model, counter):
    """Greedy evaluation loop for the A3C agent.

    Plays episodes with the arg-max action, re-loading ``shared_model`` weights
    at the start of each episode. An episode also ends after
    ``args.max_episode_length`` steps or when the last 4000 actions are all the
    same. For every episode it prints a summary and appends
    ``[episode duration (s), counter.value, total reward, episode length]`` to
    ``save/mario_curves.csv``, then sleeps 60 s. Never returns.

    Args:
        rank: offsets the torch/env seeds.
        args: parsed command-line namespace.
        shared_model: model whose weights are evaluated.
        counter: shared step counter, read for the step count and FPS.
    """
    torch.manual_seed(args.seed + rank)
    FloatTensor = torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    env = create_mario_env(args.env_name)
    # TODO: a Monitor (video) wrapper still needs to be made to work with env.change_level.
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if args.use_cuda:
        model.cuda()
    model.eval()

    state = torch.from_numpy(env.reset())
    reward_sum = 0
    done = True

    savefile = os.getcwd() + '/save/mario_curves.csv'
    title = ['Time', 'No. Steps', 'Total Reward', 'Episode Length']
    with open(savefile, 'a', newline='') as sfile:
        writer = csv.writer(sfile)
        writer.writerow(title)

    start_time = time.time()
    # Previously reset on every step, which logged a single step's duration.
    ep_start_time = start_time

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=4000)
    episode_length = 0
    while True:
        episode_length += 1
        if done:
            # Sync with the shared model
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 512).type(FloatTensor)
            hx = torch.zeros(1, 512).type(FloatTensor)

        # Pure inference: `volatile=True` is ignored by current PyTorch, so
        # no_grad is what actually keeps autograd from recording the graph.
        with torch.no_grad():
            state_inp = state.unsqueeze(0).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            action = prob.max(-1, keepdim=True)[1]

        action_out = ACTIONS[action][0, 0]
        state, reward, done, _ = env.step(action_out)
        env.render()
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # End the episode once the whole action window holds a single action.
        actions.append(int(action[0, 0]))
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print(
                "Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}"
                .format(
                    time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                    counter.value, counter.value / (time.time() - start_time),
                    reward_sum, episode_length))

            data = [time.time() - ep_start_time, counter.value, reward_sum, episode_length]
            with open(savefile, 'a', newline='') as sfile:
                writer = csv.writer(sfile)
                writer.writerows([data])

            reward_sum = 0
            episode_length = 0
            actions.clear()
            time.sleep(60)
            env.locked_levels = [False] + [True] * 31
            env.change_level(0)
            state = env.reset()
            ep_start_time = time.time()

        state = torch.from_numpy(state)
def train(rank, args, shared_model, shared_scme, counter, lock, optimizer=None, select_sample=True):
    """A3C + SCME worker: roll out, shape rewards with intrinsic bonuses, update both models.

    Each iteration syncs local copies of ``shared_model`` (ActorCritic) and
    ``shared_scme`` (SCME), plays up to ``args.num_steps`` steps and, per step,
    runs SCME on ``(s_t, s_{t+1}, one_hot(a_t))``. The environment reward is
    augmented with ``args.alpha * curiosity loss`` and ``args.beta * MI loss``
    and clipped to [-5, 50]. The update minimises
    ``args.lambd * (policy_loss + args.value_loss_coef * value_loss)`` plus the
    SCME loss summed over the rollout
    (``0.01 * vae_loss + cur_loss - mi_loss`` per step), clips both gradient
    norms to ``args.max_grad_norm``, copies gradients into the shared models
    and steps ``optimizer``.

    Args:
        rank: worker index. Offsets the torch seed; rank 0 saves both shared
            models every ``args.save_interval`` iterations, rank 1 every
            ``2.5 * args.save_interval`` as a backup.
        args: parsed command-line namespace.
        shared_model, shared_scme: models whose parameters receive the gradients.
        counter: shared value (``.value``) incremented once per environment step.
        lock: lock held while incrementing ``counter``.
        optimizer: optimiser over both shared models; a new Adam with ``args.lr``
            is created when None.
        select_sample: sample actions from the policy if True, else act greedily.

    At each episode end appends ``[cumulative env reward, x_pos / x_norm]`` to
    ``save/scmemi_<reward_type>/train_reward_<rank>.csv``. Never returns.
    """
    torch.manual_seed(args.seed + rank)
    print("Process No : {} | Sampling : {}".format(rank, select_sample))

    FloatTensor = torch.FloatTensor  # torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    savefile = os.getcwd() + '/save/scmemi_' + args.reward_type + '/train_reward.csv'
    saveweights = os.getcwd() + '/save/scmemi_' + args.reward_type + '/mario_a3c_params.pkl'
    scme_weights = saveweights[:-4] + '_scme.pkl'
    reward_log = savefile[:-4] + '_{}.csv'.format(rank)

    env = create_mario_env(args.env_name, args.reward_type)
    # env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if optimizer is None:
        optimizer = optim.Adam(list(shared_model.parameters()) + list(shared_scme.parameters()), lr=args.lr)
    scme_model = SCME(env.observation_space.shape[0], len(ACTIONS))

    model.train()
    scme_model.train()

    state = torch.from_numpy(env.reset())
    cum_rew = 0
    done = True
    episode_length = 0

    for num_iter in count():
        if rank == 0 and num_iter % args.save_interval == 0 and num_iter > 0:
            print("Saving model at :" + saveweights)
            torch.save(shared_model.state_dict(), saveweights)
            torch.save(shared_scme.state_dict(), scme_weights)

        # Second saver in case the first process crashes.
        if num_iter % (args.save_interval * 2.5) == 0 and num_iter > 0 and rank == 1:
            print("Saving model for process 1 at :" + saveweights)
            torch.save(shared_model.state_dict(), saveweights)
            torch.save(shared_scme.state_dict(), scme_weights)

        # Sync with the shared models
        model.load_state_dict(shared_model.state_dict())
        scme_model.load_state_dict(shared_scme.state_dict())
        if done:
            cx = Variable(torch.zeros(1, 512)).type(FloatTensor)
            hx = Variable(torch.zeros(1, 512)).type(FloatTensor)
        else:
            # Re-wrap .data so backprop stops at the start of this rollout.
            cx = Variable(cx.data).type(FloatTensor)
            hx = Variable(hx.data).type(FloatTensor)

        values = []
        log_probs = []
        rewards = []
        entropies = []
        vae_losses = []
        cur_losses = []
        mi_losses = []

        for _ in range(args.num_steps):
            episode_length += 1
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            if select_sample:
                action = prob.multinomial(1).data
            else:
                action = prob.max(-1, keepdim=True)[1].data
            log_prob = log_prob.gather(-1, Variable(action))

            action_out = int(action[0, 0].data.numpy())
            state, reward, done, info = env.step(action_out)
            cum_rew = cum_rew + reward  # raw environment reward, before shaping

            action_one_hot = torch.eye(len(ACTIONS))[action_out].view(1, -1)
            next_state_inp = Variable(torch.from_numpy(state).unsqueeze(0)).type(FloatTensor)
            pred_z, mi, mi1, actual_z, xt1_hat, xt1, xt1_mu, xt1_logvar = scme_model(
                (state_inp, next_state_inp, action_one_hot))
            vae_loss = loss_function(xt1_hat, xt1, xt1_mu, xt1_logvar)
            cur_loss = ((pred_z - actual_z).pow(2)).sum(-1, keepdim=True) / 2 / 50
            mi_loss = mutual(mi, mi1).sum(-1, keepdim=True) / 10

            done = done or episode_length >= args.max_episode_length

            # NOTE(review): cur_reward is a scalar ([0, 0]) but mi_reward keeps
            # mi_loss's array shape -- assumed to be a single element; confirm.
            cur_reward = (args.alpha * cur_loss).data.numpy()[0, 0]
            mi_reward = (args.beta * mi_loss).data.numpy()
            reward = cur_reward + reward + mi_reward
            reward = max(min(reward, 50), -5)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()
                with open(reward_log, 'a', newline='') as sfile:
                    writer = csv.writer(sfile)
                    writer.writerows([[cum_rew, info['x_pos'] / x_norm]])
                cum_rew = 0

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)
            vae_losses.append(vae_loss)
            cur_losses.append(cur_loss)
            mi_losses.append(mi_loss)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            # Truncated rollout: bootstrap the return from the critic.
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, _, _ = model((state_inp, (hx, cx)))
            R = value.data

        values.append(Variable(R).type(FloatTensor))
        policy_loss = 0
        value_loss = 0
        scme_loss = 0
        R = Variable(R).type(FloatTensor)
        gae = torch.zeros(1, 1).type(FloatTensor)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - log_probs[i] * Variable(gae).type(FloatTensor) - args.entropy_coef * entropies[i]
            # Accumulate over the whole rollout (was overwritten, so only step 0 trained SCME).
            scme_loss = scme_loss + 0.01 * vae_losses[i] + cur_losses[i] - mi_losses[i]

        total_loss = args.lambd * (policy_loss + args.value_loss_coef * value_loss)

        optimizer.zero_grad()
        (total_loss + scme_loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        torch.nn.utils.clip_grad_norm_(scme_model.parameters(), args.max_grad_norm)
        ensure_shared_grads(model, shared_model)
        ensure_shared_grads(scme_model, shared_scme)
        optimizer.step()
def test(rank, args, shared_model, counter):
    """Greedy evaluation loop for the SCME-trained agent.

    Plays episodes with the arg-max action, re-loading ``shared_model`` weights
    at the start of each episode. An episode also ends after
    ``args.max_episode_length`` steps, when the last 400 actions are all the
    same, or (if ``args.pos_stuck``) when at least 200 recent ``x_pos`` values
    all lie strictly within 20 of the latest one. For every episode it prints a
    summary and appends
    ``[episode duration (s), counter.value, total reward, x_pos / x_norm, episode length]``
    to ``save/scmemi_<reward_type>/mario_curves.csv``, then sleeps 60 s.
    Never returns.

    Args:
        rank: offsets the torch seed.
        args: parsed command-line namespace.
        shared_model: model whose weights are evaluated.
        counter: shared step counter, read for the step count and FPS.
    """
    torch.manual_seed(args.seed + rank)
    FloatTensor = torch.FloatTensor  # torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    env = create_mario_env(args.env_name, args.reward_type)
    # TODO: a Monitor (video) wrapper still needs to be made to work with env.change_level.
    # env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    model.eval()

    state = torch.from_numpy(env.reset())
    reward_sum = 0
    done = True

    savefile = os.getcwd() + '/save/scmemi_' + args.reward_type + '/mario_curves.csv'
    title = ['Time', 'No. Steps', 'Total Reward', 'final_position', 'Episode Length']
    with open(savefile, 'a', newline='') as sfile:
        writer = csv.writer(sfile)
        writer.writerow(title)

    start_time = time.time()
    # Previously reset on every step, which logged a single step's duration.
    ep_start_time = start_time

    # a quick hack to prevent the agent from getting stuck
    actions = deque(maxlen=400)
    positions = deque(maxlen=400)
    episode_length = 0
    while True:
        episode_length += 1
        if done:
            # Sync with the shared model
            model.load_state_dict(shared_model.state_dict())
            cx = torch.zeros(1, 512).type(FloatTensor)
            hx = torch.zeros(1, 512).type(FloatTensor)

        # Pure inference: the whole forward pass runs without autograd
        # (the old `with torch.no_grad(): cx = cx` statements were no-ops).
        with torch.no_grad():
            state_inp = state.unsqueeze(0).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            action = prob.max(-1, keepdim=True)[1]

        action_out = int(action[0, 0])
        state, reward, done, info = env.step(action_out)
        # env.render()
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # End the episode once the whole action window holds a single action.
        actions.append(action_out)
        if actions.count(actions[0]) == actions.maxlen:
            done = True
            print('action')

        if args.pos_stuck:
            positions.append(info['x_pos'])
            pos_ar = np.array(positions)
            if (len(positions) >= 200) and (pos_ar < pos_ar[-1] + 20).all() and (pos_ar > pos_ar[-1] - 20).all():
                done = True

        if done:
            print("Time {}, num steps {}, FPS {:.0f}, episode reward {:.3f}, distance covered {:.3f}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, info['x_pos'] / x_norm, episode_length))

            data = [time.time() - ep_start_time, counter.value, reward_sum, info['x_pos'] / x_norm, episode_length]
            with open(savefile, 'a', newline='') as sfile:
                writer = csv.writer(sfile)
                writer.writerows([data])

            reward_sum = 0
            episode_length = 0
            actions.clear()
            positions.clear()
            time.sleep(60)
            state = env.reset()
            ep_start_time = time.time()

        state = torch.from_numpy(state)
help='model save interval (default: {})'.format(SAVEPATH)) parser.add_argument('--non-sample', type=int,default=1, help='number of non sampling processes (default: 1)') mp = _mp.get_context('spawn') print("Cuda: " + str(torch.cuda.is_available())) if __name__ == '__main__': os.environ['OMP_NUM_THREADS'] = '1' args = parser.parse_args() env = create_mario_env(args.env_name) shared_model = ActorCritic( env.observation_space.shape[0], len(COMPLEX_MOVEMENT)) if args.use_cuda: shared_model.cuda() shared_model.share_memory() if os.path.isfile(args.save_path): print('Loading A3C parametets ...') shared_model.load_state_dict(torch.load(args.save_path, map_location='cpu')) torch.manual_seed(args.seed) optimizer = SharedAdam(shared_model.parameters(), lr=args.lr) optimizer.share_memory() print (color.BLUE + "No of available cores : {}".format(mp.cpu_count()) + color.END)
parser.add_argument('--non-sample', type=int, default=2, help='number of non sampling processes (default: 2)') mp = _mp.get_context('spawn') print("Cuda: " + str(torch.cuda.is_available())) if __name__ == '__main__': os.environ['OMP_NUM_THREADS'] = '1' args = parser.parse_args() env = create_mario_env(args.env_name) shared_model = ActorCritic(env.observation_space.shape[0], len(ACTIONS)) if args.use_cuda: shared_model.cuda() shared_model.share_memory() if os.path.isfile(args.save_path): print('Loading A3C parametets ...') shared_model.load_state_dict(torch.load(args.save_path)) torch.manual_seed(args.seed) optimizer = SharedAdam(shared_model.parameters(), lr=args.lr) optimizer.share_memory() print(color.BLUE + "No of available cores : {}".format(mp.cpu_count()) +
class PPO:
    """Proximal Policy Optimisation agent over an ``ActorCritic`` policy.

    Keeps the trainable ``policy`` plus a ``policy_old`` copy that is re-synced
    after each ``update``; optimises with Adam and a StepLR schedule driven by
    ``configs``.
    """

    def __init__(
        self,
        lr,
        gamma,
        k_epochs,
        eps_clip,
        n_j,
        n_m,
        num_layers,
        neighbor_pooling_type,
        input_dim,
        hidden_dim,
        num_mlp_layers_feature_extract,
        num_mlp_layers_actor,
        hidden_dim_actor,
        num_mlp_layers_critic,
        hidden_dim_critic,
    ):
        """Build policies, optimiser and LR scheduler.

        Args:
            lr: Adam learning rate.
            gamma: discount factor used for the returns in ``update``.
            k_epochs: optimisation epochs per ``update`` call.
            eps_clip: probability-ratio clipping range in ``update``.
            n_j, n_m, num_layers, neighbor_pooling_type, input_dim, hidden_dim,
            num_mlp_layers_feature_extract, num_mlp_layers_actor, hidden_dim_actor,
            num_mlp_layers_critic, hidden_dim_critic: forwarded unchanged to
                ``ActorCritic`` (``learn_eps`` is fixed to False).
        """
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs

        self.policy = ActorCritic(
            n_j=n_j,
            n_m=n_m,
            num_layers=num_layers,
            learn_eps=False,
            neighbor_pooling_type=neighbor_pooling_type,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_mlp_layers_feature_extract=num_mlp_layers_feature_extract,
            num_mlp_layers_actor=num_mlp_layers_actor,
            hidden_dim_actor=hidden_dim_actor,
            num_mlp_layers_critic=num_mlp_layers_critic,
            hidden_dim_critic=hidden_dim_critic,
            device=device)
        # deepcopy already duplicates every parameter/buffer, so the two
        # policies start in sync without an extra load_state_dict.
        self.policy_old = deepcopy(self.policy)

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=configs.decay_step_size,
            gamma=configs.decay_ratio)

        self.V_loss_2 = nn.MSELoss()

    def update(self, memories, n_tasks, g_pool):
        """Run ``k_epochs`` clipped-PPO epochs over the rollouts of all environments.

        Returns are discounted with ``self.gamma`` (reset at terminal steps) and
        normalised per environment. Each epoch sums the per-environment losses
        ``vloss_coef * v_loss + ploss_coef * p_loss + entloss_coef * ent_loss``
        and takes one optimiser step on their mean. Afterwards ``policy_old`` is
        synced and, if ``configs.decayflag``, the LR scheduler is stepped.

        Args:
            memories: one rollout buffer per environment exposing the lists
                ``r_mb``, ``done_mb``, ``adj_mb``, ``fea_mb``, ``candidate_mb``,
                ``mask_mb``, ``a_mb`` and ``logprobs``.
            n_tasks: forwarded to ``aggr_obs`` and ``g_pool_cal``.
            g_pool: forwarded to ``g_pool_cal``.

        Returns:
            ``(mean total loss, mean value loss)`` of the last epoch, as floats.
        """
        vloss_coef = configs.vloss_coef
        ploss_coef = configs.ploss_coef
        entloss_coef = configs.entloss_coef

        rewards_all_env = []
        adj_mb_t_all_env = []
        fea_mb_t_all_env = []
        candidate_mb_t_all_env = []
        mask_mb_t_all_env = []
        a_mb_t_all_env = []
        old_logprobs_mb_t_all_env = []

        # store data for all env
        for memory in memories:
            # Discounted returns, walking backwards; build reversed and flip once
            # instead of the quadratic list.insert(0, ...).
            returns = []
            discounted_reward = 0
            for reward, is_terminal in zip(reversed(memory.r_mb), reversed(memory.done_mb)):
                if is_terminal:
                    discounted_reward = 0
                discounted_reward = reward + (self.gamma * discounted_reward)
                returns.append(discounted_reward)
            returns.reverse()
            returns = torch.tensor(returns, dtype=torch.float).to(device)
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)
            rewards_all_env.append(returns)

            # process each env data
            adj_mb_t_all_env.append(aggr_obs(torch.stack(memory.adj_mb).to(device), n_tasks))
            fea_mb_t = torch.stack(memory.fea_mb).to(device)
            fea_mb_t_all_env.append(fea_mb_t.reshape(-1, fea_mb_t.size(-1)))
            candidate_mb_t_all_env.append(torch.stack(memory.candidate_mb).to(device).squeeze())
            mask_mb_t_all_env.append(torch.stack(memory.mask_mb).to(device).squeeze())
            a_mb_t_all_env.append(torch.stack(memory.a_mb).to(device).squeeze())
            old_logprobs_mb_t_all_env.append(torch.stack(memory.logprobs).to(device).squeeze().detach())

        # get batch argument for net forwarding: mb_g_pool is same for all env.
        # Only the shape is needed, so no device transfer.
        mb_g_pool = g_pool_cal(g_pool, torch.stack(memories[0].adj_mb).shape, n_tasks, device)

        # Optimize policy for K epochs:
        for _ in range(self.k_epochs):
            loss_sum = 0
            vloss_sum = 0
            for i in range(len(memories)):
                pis, vals = self.policy(x=fea_mb_t_all_env[i],
                                        graph_pool=mb_g_pool,
                                        adj=adj_mb_t_all_env[i],
                                        candidate=candidate_mb_t_all_env[i],
                                        mask=mask_mb_t_all_env[i],
                                        padded_nei=None)
                # Flatten the critic output to (batch,) so it lines up with the
                # returns; a trailing singleton dim would broadcast the advantage
                # into a (batch, batch) matrix.
                vals = vals.reshape(-1)
                logprobs, ent_loss = eval_actions(pis.squeeze(), a_mb_t_all_env[i])
                ratios = torch.exp(logprobs - old_logprobs_mb_t_all_env[i].detach())
                advantages = rewards_all_env[i] - vals.detach()
                surr1 = ratios * advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
                v_loss = self.V_loss_2(vals, rewards_all_env[i])
                p_loss = -torch.min(surr1, surr2)
                ent_loss = -ent_loss.clone()
                loss = vloss_coef * v_loss + ploss_coef * p_loss + entloss_coef * ent_loss
                loss_sum += loss
                vloss_sum += v_loss
            self.optimizer.zero_grad()
            loss_sum.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
        if configs.decayflag:
            self.scheduler.step()
        return loss_sum.mean().item(), vloss_sum.mean().item()