def sac_opp(
    global_ac,
    global_ac_targ,
    global_cpc,
    rank,
    T,
    E,
    args,
    scores,
    wins,
    buffer,
    device=None,
    tensorboard_dir=None,
):
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    # writer = GlobalSummaryWriter.getSummaryWriter()
    tensorboard_dir = os.path.join(tensorboard_dir, str(rank))
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(log_dir=tensorboard_dir)

    if args.exp_name == "test":
        env = gym.make("CartPole-v0")
    elif args.non_station:
        env = make_ftg_ram_nonstation(args.env, p2_list=args.list,
                                      total_episode=args.station_rounds,
                                      stable=args.stable)
    else:
        env = make_ftg_ram(args.env, p2=args.p2)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    print("set up child process env")

    local_ac = MLPActorCritic(obs_dim + args.c_dim, act_dim,
                              **dict(hidden_sizes=[args.hid] * args.l)).to(device)
    local_ac.load_state_dict(global_ac.state_dict())
    print("local ac load global ac")

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    # Async Version
    for p in global_ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    if args.cpc:
        replay_buffer = ReplayBufferOppo(obs_dim=obs_dim, max_size=args.replay_size,
                                         encoder=global_cpc)
    else:
        replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)

    # Entropy Tuning
    target_entropy = -torch.prod(
        torch.Tensor(env.action_space.shape).to(device)).item()  # heuristic value from the paper
    alpha = max(local_ac.log_alpha.exp().item(),
                args.min_alpha) if not args.fix_alpha else args.min_alpha

    # Set up optimizers for policy and q-function
    # Async Version
    pi_optimizer = Adam(global_ac.pi.parameters(), lr=args.lr, eps=1e-4)
    q1_optimizer = Adam(global_ac.q1.parameters(), lr=args.lr, eps=1e-4)
    q2_optimizer = Adam(global_ac.q2.parameters(), lr=args.lr, eps=1e-4)
    cpc_optimizer = Adam(global_cpc.parameters(), lr=args.lr, eps=1e-4)
    alpha_optim = Adam([global_ac.log_alpha], lr=args.lr, eps=1e-4)

    # Prepare for interaction with environment
    o, ep_ret, ep_len = env.reset(), 0, 0
    if args.cpc:
        c_hidden = global_cpc.init_hidden(1, args.c_dim, use_gpu=args.cuda)
        c1, c_hidden = global_cpc.predict(o, c_hidden)
        assert len(c1.shape) == 3
        c1 = c1.flatten().cpu().numpy()
    all_embeddings = []
    meta = []
    trajectory = list()
    p2 = env.p2
    p2_list = [str(p2)]
    discard = False
    uncertainties = []
    local_t, local_e = 0, 0
    t = T.value()
    e = E.value()
    glod_input = list()
    glod_target = list()

    # Main loop: collect experience in env and update/log each epoch
    while e <= args.episode:
        with torch.no_grad():
            # Until start_steps have elapsed, randomly sample actions
            # from a uniform distribution for better exploration. Afterwards,
            # use the learned policy.
            if t > args.start_steps:
                if args.cpc:
                    a = local_ac.get_action(np.concatenate((o, c1), axis=0), device=device)
                    a_prob = local_ac.act(
                        torch.as_tensor(np.expand_dims(np.concatenate((o, c1), axis=0), axis=0),
                                        dtype=torch.float32, device=device))
                else:
                    a = local_ac.get_action(o, greedy=True, device=device)
                    a_prob = local_ac.act(
                        torch.as_tensor(np.expand_dims(o, axis=0),
                                        dtype=torch.float32, device=device))
            else:
                a = env.action_space.sample()
                a_prob = local_ac.act(
                    torch.as_tensor(np.expand_dims(o, axis=0),
                                    dtype=torch.float32, device=device))
            uncertainty = ood_scores(a_prob).item()

        # Step the env
        o2, r, d, info = env.step(a)
        if info.get('no_data_receive', False):
            discard = True
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if (ep_len == args.max_ep_len) or discard else d
        glod_input.append(o)
        glod_target.append(a)

        if args.cpc:
            # changed the trace structure for further analysis
            c2, c_hidden = global_cpc.predict(o2, c_hidden)
            assert len(c2.shape) == 3
            c2 = c2.flatten().cpu().numpy()
            replay_buffer.store(np.concatenate((o, c1), axis=0), a, r,
                                np.concatenate((o2, c2), axis=0), d)
            trajectory.append([o, a, r, o2, d, c1, c2, ep_len])
            all_embeddings.append(c1)
            meta.append([env.p2, local_e, ep_len, r, a, uncertainty])
            c1 = c2
        else:
            replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2
        T.increment()
        t = T.value()
        local_t += 1

        # End of trajectory handling
        if d or (ep_len == args.max_ep_len) or discard:
            if args.cpc:
                replay_buffer.store(trajectory)
            E.increment()
            e = E.value()
            local_e += 1
            # logger.store(EpRet=ep_ret, EpLen=ep_len)
            if info.get('win', False):
                wins.append(1)
            else:
                wins.append(0)
            scores.append(ep_ret)
            m_score = np.mean(scores[-100:])
            win_rate = np.mean(wins[-100:])
            print(
                "Process {}, opponent:{}, # of global_episode :{}, # of global_steps :{}, "
                "round score: {}, mean score : {:.1f}, win_rate:{}, steps: {}, alpha: {}"
                .format(rank, args.p2, e, t, ep_ret, m_score, win_rate, ep_len, alpha))
            writer.add_scalar("metrics/round_score", ep_ret, e)
            writer.add_scalar("metrics/mean_score", m_score.item(), e)
            writer.add_scalar("metrics/win_rate", win_rate.item(), e)
            writer.add_scalar("metrics/round_step", ep_len, e)
            writer.add_scalar("metrics/alpha", alpha, e)

            # CPC update handling
            if local_e > args.batch_size and local_e % args.update_every == 0 and args.cpc:
                data, indexes, min_len = replay_buffer.sample_traj(args.batch_size)
                global_cpc.train()
                cpc_optimizer.zero_grad()
                c_hidden = global_cpc.init_hidden(len(data), args.c_dim, use_gpu=args.cuda)
                acc, loss, latents = global_cpc(data, c_hidden)
                replay_buffer.update_latent(indexes, min_len, latents.detach())
                loss.backward()
                # add gradient clipping
                nn.utils.clip_grad_norm_(global_cpc.parameters(), 20)
                cpc_optimizer.step()
                writer.add_scalar("training/acc", acc, e)
                writer.add_scalar("training/cpc_loss", loss.detach().item(), e)
                emb_mat = np.array(all_embeddings)
                writer.add_embedding(mat=emb_mat, metadata=meta,
                                     metadata_header=["opponent", "round", "step",
                                                      "reward", "action", "uncertainty"])
                c_hidden = global_cpc.init_hidden(1, args.c_dim, use_gpu=args.cuda)

            o, ep_ret, ep_len = env.reset(), 0, 0
            trajectory = list()
            discard = False

        # OOD update stage
        if (t >= args.ood_update_step and local_t % args.ood_update_step == 0
                or replay_buffer.is_full()) and args.ood:
            # used all the data collected from
            # the last args.ood_update_steps as the train data
            print("Conduct OOD updating")
            ood_train = (glod_input, glod_target)
            glod_model = convert_to_glod(global_ac.pi, train_loader=ood_train,
                                         hidden_dim=args.hid, act_dim=act_dim,
                                         device=device)
            glod_scores = retrieve_scores(glod_model,
                                          replay_buffer.obs_buf[:replay_buffer.size],
                                          device=device, k=args.ood_K)
            glod_scores = glod_scores.detach().cpu().numpy()
            print(len(glod_scores))
            writer.add_histogram(values=glod_scores, max_bins=300,
                                 global_step=local_t, tag="OOD")
            drop_points = np.percentile(a=glod_scores,
                                        q=[args.ood_drop_lower, args.ood_drop_upper])
            lower, upper = drop_points[0], drop_points[1]
            print(lower, upper)
            mask = np.logical_and((glod_scores >= lower), (glod_scores <= upper))
            reserved_indexes = np.argwhere(mask).flatten()
            print(len(reserved_indexes))
            if len(reserved_indexes) > 0:
                replay_buffer.ood_drop(reserved_indexes)
            glod_input = list()
            glod_target = list()

        # SAC Update handling
        if local_t >= args.update_after and local_t % args.update_every == 0:
            for j in range(args.update_every):
                batch = replay_buffer.sample_trans(batch_size=args.batch_size, device=device)

                # First run one gradient descent step for Q1 and Q2
                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                loss_q = local_ac.compute_loss_q(batch, global_ac_targ, args.gamma, alpha)
                loss_q.backward()

                # Next run one gradient descent step for pi.
                pi_optimizer.zero_grad()
                loss_pi, entropy = local_ac.compute_loss_pi(batch, alpha)
                loss_pi.backward()

                alpha_optim.zero_grad()
                alpha_loss = -(local_ac.log_alpha * (entropy + target_entropy).detach()).mean()
                alpha_loss.backward(retain_graph=False)
                alpha = max(local_ac.log_alpha.exp().item(),
                            args.min_alpha) if not args.fix_alpha else args.min_alpha

                nn.utils.clip_grad_norm_(local_ac.parameters(), 20)
                for global_param, local_param in zip(global_ac.parameters(),
                                                     local_ac.parameters()):
                    global_param._grad = local_param.grad
                pi_optimizer.step()
                q1_optimizer.step()
                q2_optimizer.step()
                alpha_optim.step()
                state_dict = global_ac.state_dict()
                local_ac.load_state_dict(state_dict)

                # Finally, update target networks by polyak averaging.
                with torch.no_grad():
                    for p, p_targ in zip(global_ac.parameters(),
                                         global_ac_targ.parameters()):
                        p_targ.data.copy_((1 - args.polyak) * p.data + args.polyak * p_targ.data)

            writer.add_scalar("training/pi_loss", loss_pi.detach().item(), t)
            writer.add_scalar("training/q_loss", loss_q.detach().item(), t)
            writer.add_scalar("training/alpha_loss", alpha_loss.detach().item(), t)
            writer.add_scalar("training/entropy", entropy.detach().mean().item(), t)

        if t % args.save_freq == 0 and t > 0:
            torch.save(global_ac.state_dict(),
                       os.path.join(args.save_dir, args.exp_name, args.model_para))
            torch.save(global_cpc.state_dict(),
                       os.path.join(args.save_dir, args.exp_name, args.cpc_para))
            state_dict_trans(global_ac.state_dict(),
                             os.path.join(args.save_dir, args.exp_name, args.numpy_para))
            torch.save((e, t, list(scores), list(wins)),
                       os.path.join(args.save_dir, args.exp_name, args.train_indicator))
            print("Saving model at episode:{}".format(t))
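
# sac_opp above is an A3C-style worker: it keeps a local copy of the actor-critic,
# computes gradients locally, and copies them onto the shared global model before the
# global optimizers step. Below is a minimal, hypothetical launcher sketch for such
# workers; the function name, `n_process`, and the shared counters T/E are assumptions
# for illustration and are not defined in this file (the repo's own entry point may differ).
def _launch_sac_opp_workers(args, global_ac, global_ac_targ, global_cpc, T, E,
                            n_process=4, device=None, tensorboard_dir=None):
    import torch.multiprocessing as mp
    # Make the global networks visible to all child processes.
    global_ac.share_memory()
    global_ac_targ.share_memory()
    global_cpc.share_memory()
    manager = mp.Manager()
    scores, wins = manager.list(), manager.list()  # shared metric histories
    workers = []
    for rank in range(n_process):
        p = mp.Process(target=sac_opp,
                       args=(global_ac, global_ac_targ, global_cpc, rank, T, E,
                             args, scores, wins, None),  # `buffer` is unused by the worker
                       kwargs=dict(device=device, tensorboard_dir=tensorboard_dir))
        p.start()
        workers.append(p)
    for p in workers:
        p.join()
    return list(scores), list(wins)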

def train(global_model, rank, T, scores):
    env = make_ftg_ram(env_name, p2=p2)
    state_shape = env.observation_space.shape[0]
    action_shape = env.action_space.n
    local_model = ActorCritic(state_shape, action_shape, hidden_size)
    local_model.load_state_dict(global_model.state_dict())
    # MODEL_STATE = "/home/byron/Repos/FTG4.50/OpenAI/ByronAI.numpy"
    # test_model = ActorCriticNumpy(MODEL_STATE)
    optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)

    while True:
        discard = False
        done = False
        s = env.reset()
        score = 0
        sum_entropy = 0
        step = 0
        while not done:
            s_lst, a_lst, r_lst = [], [], []
            for t in range(update_interval):
                prob = local_model.pi(torch.from_numpy(s).float())
                # test_prob = test_model.pi(s)
                # diff = prob.detach().numpy() - test_prob
                # print(diff.sum())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                if info.get('no_data_receive', False):
                    discard = True
                    break
                s_lst.append(s)
                a_lst.append([a])
                # r_lst.append(r/100.0)
                r_lst.append(r)
                s = s_prime
                score += r
                sum_entropy += Categorical(probs=prob.detach()).entropy()
                step += 1
                if done:
                    break
            if discard:
                break

            s_final = torch.tensor(s_prime, dtype=torch.float)
            R = 0.0 if done else local_model.v(s_final).item()
            td_target_lst = []
            for reward in r_lst[::-1]:
                R = gamma * R + reward
                td_target_lst.append([R])
            td_target_lst.reverse()

            s_batch, a_batch, td_target = torch.tensor(s_lst, dtype=torch.float), \
                torch.tensor(a_lst), torch.tensor(td_target_lst)
            advantage = td_target - local_model.v(s_batch)

            pi = local_model.pi(s_batch, softmax_dim=1)
            pi_a = pi.gather(1, a_batch)
            loss = -torch.log(pi_a) * advantage.detach() + \
                F.smooth_l1_loss(local_model.v(s_batch), td_target.detach()) - \
                (entropy_weight * -(torch.log(pi) * pi).sum())

            optimizer.zero_grad()
            loss.mean().backward()
            for global_param, local_param in zip(global_model.parameters(),
                                                 local_model.parameters()):
                global_param._grad = local_param.grad
            optimizer.step()
            local_model.load_state_dict(global_model.state_dict())

        if discard:
            continue
        T.increment()
        t = T.value()
        scores.append(score)
        m_score = np.mean(scores[-100:])
        print(
            "Process {}, # of episode :{}, round score: {}, mean score : {:.1f}, "
            "entropy: {}, steps: {}"
            .format(rank, t, score, m_score, sum_entropy / step, step))
        if t % save_interval == 0 and t > 0:
            torch.save(global_model.state_dict(), os.path.join(save_dir, "model"))
            state_dict_trans(global_model.state_dict(), os.path.join(save_dir, numpy_para))
            print("Saving model at episode:{}".format(t))
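
# The td_target computation in train() builds n-step bootstrapped returns by walking the
# collected rewards backwards: R_k = r_k + gamma * R_{k+1}, seeded with the critic's value
# of the final state (or 0 when the episode ended). A small standalone sketch of the same
# recursion; the function name and the numbers in the example are illustrative only:
#
#   rewards = [1.0, 0.0, 2.0], gamma = 0.9, bootstrap V(s_T) = 0.5
#   R_2 = 2.0 + 0.9 * 0.5   = 2.45
#   R_1 = 0.0 + 0.9 * 2.45  = 2.205
#   R_0 = 1.0 + 0.9 * 2.205 = 2.9845
def n_step_returns(rewards, bootstrap_value, gamma=0.99):
    """Backward-accumulate discounted returns, mirroring the loop in train()."""
    R = bootstrap_value
    targets = []
    for r in reversed(rewards):
        R = gamma * R + r
        targets.append(R)
    targets.reverse()
    return targets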

def sac(
    global_ac,
    global_ac_targ,
    rank,
    T,
    E,
    args,
    scores,
    wins,
    buffer_q,
    device=None,
    tensorboard_dir=None,
):
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    # writer = GlobalSummaryWriter.getSummaryWriter()
    tensorboard_dir = os.path.join(tensorboard_dir, str(rank))
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(log_dir=tensorboard_dir)

    # if args.exp_name == "test":
    #     env = gym.make("CartPole-v0")
    # elif args.non_station:
    #     env = make_ftg_ram_nonstation(args.env, p2_list=args.list,
    #                                   total_episode=args.station_rounds, stable=args.stable)
    # else:
    #     env = make_ftg_ram(args.env, p2=args.p2)
    # obs_dim = env.observation_space.shape[0]
    # act_dim = env.action_space.n
    env = Soccer()
    # env = gym.make("CartPole-v0")
    obs_dim = env.n_features
    act_dim = env.n_actions
    print("set up child process env")

    local_ac = MLPActorCritic(obs_dim, act_dim,
                              **dict(hidden_sizes=[args.hid] * args.l)).to(device)
    state_dict = global_ac.state_dict()
    local_ac.load_state_dict(state_dict)
    print("local ac load global ac")

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    # Async Version
    for p in global_ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)
    training_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)

    # Entropy Tuning
    target_entropy = -np.log((1.0 / act_dim)) * 0.5
    alpha = max(local_ac.log_alpha.exp().item(),
                args.min_alpha) if not args.fix_alpha else args.min_alpha

    # Set up optimizers for policy and q-function
    # Async Version
    pi_optimizer = Adam(global_ac.pi.parameters(), lr=args.lr, eps=1e-4)
    q1_optimizer = Adam(global_ac.q1.parameters(), lr=args.lr, eps=1e-4)
    q2_optimizer = Adam(global_ac.q2.parameters(), lr=args.lr, eps=1e-4)
    alpha_optim = Adam([global_ac.log_alpha], lr=args.lr, eps=1e-4)

    # Prepare for interaction with environment
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False
    glod_model = None
    glod_lower = None
    glod_upper = None
    last_updated = 0
    saved_e = 0
    t = T.value()
    e = E.value()
    local_t, local_e = 0, 0

    # Main loop: collect experience in env and update/log each epoch
    while e <= args.episode:
        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        with torch.no_grad():
            a = local_ac.get_action(o, device=device)
        if hasattr(env, 'p2'):
            p2 = env.p2
        else:
            p2 = None

        # Step the env
        o2, r, d, info = env.step(a, np.random.randint(act_dim))
        env.render()
        if info.get('no_data_receive', False):
            discard = True
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        # d = False if (ep_len == args.max_ep_len) or discard else d

        # Store experience to replay buffer
        if glod_model is None or not args.ood:
            replay_buffer.store(o, a, r, o2, d, str(p2))
            training_buffer.store(o, a, r, o2, d, str(p2))
        else:
            obs_glod_score = retrieve_scores(glod_model, np.expand_dims(o, axis=0),
                                             device=torch.device("cpu"), k=args.ood_K)
            if glod_lower <= obs_glod_score <= glod_upper:
                training_buffer.store(o, a, r, o2, d, str(p2))
            replay_buffer.store(o, a, r, o2, d, str(p2))

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2
        T.increment()
        t = T.value()
        local_t += 1

        # End of trajectory handling
        if d or (ep_len == args.max_ep_len) or discard:
            E.increment()
            e = E.value()
            local_e += 1
            # logger.store(EpRet=ep_ret, EpLen=ep_len)
            if info.get('win', False):
                wins.append(1)
            else:
                wins.append(0)
            scores.append(ep_ret)
            m_score = np.mean(scores[-100:])
            win_rate = np.mean(wins[-100:])
            print(
                "Process {}, opponent:{}, # of global_episode :{}, # of global_steps :{}, "
                "round score: {}, mean score : {:.1f}, win_rate:{}, steps: {}, alpha: {}"
                .format(rank, args.p2, e, t, ep_ret, m_score, win_rate, ep_len, alpha))
            writer.add_scalar("metrics/round_score", ep_ret, e)
            writer.add_scalar("metrics/mean_score", m_score.item(), e)
            writer.add_scalar("metrics/win_rate", win_rate.item(), e)
            writer.add_scalar("metrics/round_step", ep_len, e)
            writer.add_scalar("metrics/alpha", alpha, e)
            o, ep_ret, ep_len = env.reset(), 0, 0
            discard = False

        # OOD update stage, can only use CPU as the GPU memory can not hold so much data
        if (local_e >= args.ood_starts and local_e % args.ood_update_rounds == 0
                and args.ood and local_e != last_updated):
            print("OOD updating at rounds {}".format(e))
            print("Replay Buffer Size: {}, Training Buffer Size: {}".format(
                replay_buffer.size, training_buffer.size))
            glod_idxs = np.random.randint(0, training_buffer.size,
                                          size=int(training_buffer.size * args.ood_train_per))
            glod_input = training_buffer.obs_buf[glod_idxs]
            glod_target = training_buffer.act_buf[glod_idxs]
            ood_train = (glod_input, glod_target)
            glod_model = deepcopy(global_ac.pi).cpu()
            glod_model = convert_to_glod(glod_model, train_loader=ood_train,
                                         hidden_dim=args.hid, act_dim=act_dim,
                                         device=torch.device("cpu"))
            training_buffer = deepcopy(replay_buffer)
            glod_scores = retrieve_scores(glod_model,
                                          replay_buffer.obs_buf[:training_buffer.size],
                                          device=torch.device("cpu"), k=args.ood_K)
            glod_scores = glod_scores.detach().cpu().numpy()
            glod_p2 = training_buffer.p2_buf[:training_buffer.size]
            drop_points = np.percentile(a=glod_scores,
                                        q=[args.ood_drop_lower, args.ood_drop_upper])
            glod_lower, glod_upper = drop_points[0], drop_points[1]
            mask = np.logical_and((glod_scores >= glod_lower), (glod_scores <= glod_upper))
            reserved_indexes = np.argwhere(mask).flatten()
            if len(reserved_indexes) > 0:
                training_buffer.ood_drop(reserved_indexes)
            writer.add_histogram(values=glod_scores, max_bins=300,
                                 global_step=local_e, tag="OOD")
            print("Replay Buffer Size: {}, Training Buffer Size: {}".format(
                replay_buffer.size, training_buffer.size))
            torch.save((glod_scores, replay_buffer.p2_buf[:replay_buffer.size]),
                       os.path.join(args.save_dir, args.exp_name,
                                    "glod_info_{}_{}".format(rank, local_e)))
            last_updated = local_e

        # SAC Update handling
        if local_e >= args.update_after and local_t % args.update_every == 0:
            for j in range(args.update_every):
                batch = training_buffer.sample_trans(args.batch_size, device=device)

                # First run one gradient descent step for Q1 and Q2
                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                loss_q = local_ac.compute_loss_q(batch, global_ac_targ, args.gamma, alpha)
                loss_q.backward()
                nn.utils.clip_grad_norm_(global_ac.parameters(), max_norm=20, norm_type=2)
                q1_optimizer.step()
                q2_optimizer.step()

                # Next run one gradient descent step for pi.
                pi_optimizer.zero_grad()
                loss_pi, entropy = local_ac.compute_loss_pi(batch, alpha)
                loss_pi.backward()
                nn.utils.clip_grad_norm_(global_ac.parameters(), max_norm=20, norm_type=2)
                pi_optimizer.step()

                alpha_optim.zero_grad()
                alpha_loss = -(local_ac.log_alpha * (entropy + target_entropy).detach()).mean()
                alpha_loss.backward(retain_graph=False)
                alpha = max(local_ac.log_alpha.exp().item(),
                            args.min_alpha) if not args.fix_alpha else args.min_alpha
                nn.utils.clip_grad_norm_(global_ac.parameters(), max_norm=20, norm_type=2)
                alpha_optim.step()

                for global_param, local_param in zip(global_ac.parameters(),
                                                     local_ac.parameters()):
                    global_param._grad = local_param.grad
                state_dict = global_ac.state_dict()
                local_ac.load_state_dict(state_dict)

                # Finally, update target networks by polyak averaging.
                with torch.no_grad():
                    for p, p_targ in zip(global_ac.parameters(),
                                         global_ac_targ.parameters()):
                        p_targ.data.copy_((1 - args.polyak) * p.data + args.polyak * p_targ.data)

            writer.add_scalar("training/pi_loss", loss_pi.detach().item(), t)
            writer.add_scalar("training/q_loss", loss_q.detach().item(), t)
            writer.add_scalar("training/alpha_loss", alpha_loss.detach().item(), t)
            writer.add_scalar("training/entropy", entropy.detach().mean().item(), t)

        if e % args.save_freq == 0 and e > 0 and e != saved_e:
            torch.save(global_ac.state_dict(),
                       os.path.join(args.save_dir, args.exp_name, "model_torch_{}".format(e)))
            state_dict_trans(global_ac.state_dict(),
                             os.path.join(args.save_dir, args.exp_name, "model_numpy_{}".format(e)))
            torch.save((e, t, list(scores), list(wins)),
                       os.path.join(args.save_dir, args.exp_name, "model_data_{}".format(e)))
            print("Saving model at episode:{}".format(t))
            saved_e = e
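
# Both workers above filter the replay data with the same recipe: score every stored
# observation with the GLOD-converted policy, keep only the scores that fall between two
# percentiles, and drop the rest as suspected out-of-distribution transitions. A minimal
# numpy-only sketch of that keep-mask computation; the function name and the default
# percentile bounds here are illustrative, not values taken from args:
def ood_keep_mask(glod_scores, drop_lower=5.0, drop_upper=95.0):
    """Return indexes of scores inside the [drop_lower, drop_upper] percentile band."""
    lower, upper = np.percentile(glod_scores, q=[drop_lower, drop_upper])
    mask = np.logical_and(glod_scores >= lower, glod_scores <= upper)
    return np.argwhere(mask).flatten()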

def sac(env_fn, actor_critic=MLPActorCritic, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
        update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
        logger_kwargs=dict(), save_freq=1000, save_dir=None):
    """
    Soft Actor-Critic (SAC)

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided
                                           | observations and actions. (Critical:
                                           | make sure to flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly:
                                           | gradients should be able to flow back
                                           | into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.
        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.
    """
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    # act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q-networks (save this for convenience)
    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q1_optimizer = Adam(ac.q1.parameters(), lr=lr)
    q2_optimizer = Adam(ac.q2.parameters(), lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    # product action
    def get_actions_info(a_prob):
        a_dis = Categorical(a_prob)
        max_a = torch.argmax(a_prob)
        sample_a = a_dis.sample().cpu()
        z = a_prob == 0.0
        z = z.float() * 1e-8
        log_a_prob = torch.log(a_prob + z)
        return a_prob, log_a_prob, sample_a, max_a

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a_prob, log_a_prob, sample_a, max_a = get_actions_info(ac.pi(o2))

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2)
            q2_pi_targ = ac_targ.q2(o2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (a_prob * (q_pi_targ - alpha * log_a_prob)).sum(dim=1)

        # MSE loss against Bellman backup
        q1 = ac.q1(o).gather(1, a.unsqueeze(-1).long())
        q2 = ac.q2(o).gather(1, a.unsqueeze(-1).long())
        loss_q1 = F.mse_loss(q1, backup.unsqueeze(-1))
        loss_q2 = F.mse_loss(q2, backup.unsqueeze(-1))
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy())
        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        a_prob, log_a_prob, sample_a, max_a = get_actions_info(ac.pi(o))
        q1_pi = ac.q1(o)
        q2_pi = ac.q2(o)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (a_prob * (alpha * log_a_prob - q_pi)).mean()
        entropy = torch.sum(log_a_prob * a_prob, dim=1).detach()

        # Useful info for logging
        pi_info = dict(LogPi=entropy.numpy())
        return loss_pi, pi_info

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q1_optimizer.zero_grad()
        q2_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q1_optimizer.step()
        q2_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.detach().item(), **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        # for p in q_params:
        #     p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize it at next DDPG step.
        # for p in q_params:
        #     p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.copy_((1 - polyak) * p.data + polyak * p_targ.data)
                # p_targ.data.mul_(polyak)
                # p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, greedy=False):
        if len(o.shape) == 1:
            o = np.expand_dims(o, axis=0)
        a_prob = ac.act(torch.as_tensor(o, dtype=torch.float32), greedy)
        a_prob, log_a_prob, sample_a, max_a = get_actions_info(a_prob)
        action = sample_a if not greedy else max_a
        return action.item()

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                # test_env.render()
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            print(ep_ret)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    scores = []
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        if info.get('no_data_receive', False):
            discard = True
        env.render()
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len or discard else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len) or discard:
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            scores.append(ep_ret)
            print("round len:{}, round score: {}, mean score: {}".format(
                ep_len, ep_ret, np.mean(scores[-100:])))
            o, ep_ret, ep_len = env.reset(), 0, 0
            discard = False

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Save model
            if t % save_freq == 0 and t > 0:
                torch.save(ac.state_dict(), os.path.join(save_dir, "model"))
                state_dict_trans(ac.state_dict(), os.path.join(save_dir, "SAC_Toothless.numpy"))
                print("Saving model at episode:{}".format(t))
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)
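
# A minimal usage sketch for the single-process sac() trainer defined above. CartPole-v0
# is the toy environment already referenced in this file; the function name, epoch count,
# hidden sizes, and checkpoint directory below are illustrative assumptions, not project
# defaults.
def _run_sac_cartpole_demo(save_dir="./checkpoints"):
    os.makedirs(save_dir, exist_ok=True)
    sac(lambda: gym.make("CartPole-v0"),
        ac_kwargs=dict(hidden_sizes=[256, 256]),  # forwarded to the actor-critic constructor
        steps_per_epoch=4000,
        epochs=10,
        batch_size=100,
        start_steps=1000,
        update_after=1000,
        update_every=50,
        save_dir=save_dir)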