def initialise_policy(self):

    # initialise policy network
    policy_net = Policy(
        args=self.args,
        #
        pass_state_to_policy=self.args.pass_state_to_policy,
        pass_latent_to_policy=self.args.pass_latent_to_policy,
        pass_belief_to_policy=self.args.pass_belief_to_policy,
        pass_task_to_policy=self.args.pass_task_to_policy,
        dim_state=self.args.state_dim,
        dim_latent=self.args.latent_dim * 2,
        dim_belief=self.args.belief_dim,
        dim_task=self.args.task_dim,
        #
        hidden_layers=self.args.policy_layers,
        activation_function=self.args.policy_activation_function,
        policy_initialisation=self.args.policy_initialisation,
        #
        action_space=self.envs.action_space,
        init_std=self.args.policy_init_std,
    ).to(device)

    # initialise policy trainer
    if self.args.policy == 'a2c':
        policy = A2C(
            self.args,
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            policy_optimiser=self.args.policy_optimiser,
            policy_anneal_lr=self.args.policy_anneal_lr,
            train_steps=self.num_updates,
            optimiser_vae=self.vae.optimiser_vae,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
        )
    elif self.args.policy == 'ppo':
        policy = PPO(
            self.args,
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            policy_optimiser=self.args.policy_optimiser,
            policy_anneal_lr=self.args.policy_anneal_lr,
            train_steps=self.num_updates,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
            ppo_epoch=self.args.ppo_num_epochs,
            num_mini_batch=self.args.ppo_num_minibatch,
            use_huber_loss=self.args.ppo_use_huberloss,
            use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
            clip_param=self.args.ppo_clip_param,
            optimiser_vae=self.vae.optimiser_vae,
        )
    else:
        raise NotImplementedError

    return policy
def initialise_policy(self):

    if hasattr(self.envs.action_space, 'low'):
        action_low = self.envs.action_space.low
        action_high = self.envs.action_space.high
    else:
        action_low = action_high = None

    # initialise policy network
    policy_net = Policy(
        args=self.args,
        #
        pass_state_to_policy=self.args.pass_state_to_policy,
        pass_latent_to_policy=False,  # use metalearner.py if you want to use the VAE
        pass_belief_to_policy=self.args.pass_belief_to_policy,
        pass_task_to_policy=self.args.pass_task_to_policy,
        dim_state=self.args.state_dim,
        dim_latent=0,
        dim_belief=self.args.belief_dim,
        dim_task=self.args.task_dim,
        #
        hidden_layers=self.args.policy_layers,
        activation_function=self.args.policy_activation_function,
        policy_initialisation=self.args.policy_initialisation,
        #
        action_space=self.envs.action_space,
        init_std=self.args.policy_init_std,
        norm_actions_of_policy=self.args.norm_actions_of_policy,
        action_low=action_low,
        action_high=action_high,
    ).to(device)

    # initialise policy trainer
    if self.args.policy == 'a2c':
        policy = A2C(
            self.args,
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            policy_optimiser=self.args.policy_optimiser,
            policy_anneal_lr=self.args.policy_anneal_lr,
            train_steps=self.num_updates,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
        )
    elif self.args.policy == 'ppo':
        policy = PPO(
            self.args,
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            policy_optimiser=self.args.policy_optimiser,
            policy_anneal_lr=self.args.policy_anneal_lr,
            train_steps=self.num_updates,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
            ppo_epoch=self.args.ppo_num_epochs,
            num_mini_batch=self.args.ppo_num_minibatch,
            use_huber_loss=self.args.ppo_use_huberloss,
            use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
            clip_param=self.args.ppo_clip_param,
        )
    else:
        raise NotImplementedError

    return policy
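# The A2C/PPO trainer classes constructed above belong to this codebase and are not shown in
# this file. As a reference, a PPO update of this kind typically optimises the clipped
# surrogate objective sketched below; this is a minimal standalone sketch, and names such as
# `new_log_probs`, `advantages` and `clip_param` are illustrative, not the repository's API.
import torch

def ppo_clipped_policy_loss(new_log_probs, old_log_probs, advantages, clip_param=0.2):
    # probability ratio r_t = pi_new(a|s) / pi_old(a|s)
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # maximise the clipped surrogate, i.e. minimise its negative
    return -torch.min(surr1, surr2).mean()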
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 18 16:35:31 2019

@author: clytie
"""

if __name__ == "__main__":
    import numpy as np
    import time
    from tqdm import tqdm

    from env.dist_env import BreakoutEnv
    from algorithms.ppo import PPO

    ppo = PPO(4, (84, 84, 4), temperature=0.1, save_path="./ppo_log")
    env = BreakoutEnv(4999, num_envs=1, mode="test")
    env_ids, states, _, _ = env.start()
    for _ in tqdm(range(10000)):
        time.sleep(0.1)
        actions = ppo.get_action(np.asarray(states))
        env_ids, states, _, _ = env.step(env_ids, actions)
    env.close()
    x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu)
    x = tf.contrib.layers.flatten(x)
    x = tf.layers.dense(x, 512, activation=tf.nn.relu)
    logit_action_probability = tf.layers.dense(
        x, action_space,
        kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01))
    state_value = tf.squeeze(
        tf.layers.dense(
            x, 1, kernel_initializer=tf.truncated_normal_initializer()))
    return logit_action_probability, state_value


ppo = PPO(action_space, obs_fn, model_fn,
          train_epoch=5, batch_size=64,
          save_path='./ppo_log_oneplayer_stack')
env = Raiden2(6666, num_envs=8, with_stack=True)
env_ids, states, rewards, dones = env.start()

nth_trajectory = 0
while True:
    nth_trajectory += 1
    for _ in tqdm(range(explore_steps)):
        actions = ppo.get_action(np.asarray(states))
        actions = [(action, 4) for action in actions]
        env_ids, states, rewards, dones = env.step(env_ids, actions)

    s_batchs, a_batchs, r_batchs, d_batchs = env.get_episodes()
def run(config):
    model_path = (Path('./models') / config.env_id /
                  ('run%i' % config.run_num))
    if config.incremental is not None:
        model_path = model_path / 'incremental' / ('model_ep%i.pt' %
                                                   config.incremental)
    else:
        model_path = model_path / 'model.pt'

    if config.save_gifs:
        gif_path = model_path.parent / 'gifs'
        gif_path.mkdir(exist_ok=True)

    ppo = PPO.init_from_save(model_path)
    env = make_env(config.env_id)
    ppo.prep_rollouts(device='cpu')
    ifi = 1 / config.fps  # inter-frame interval

    for ep_i in range(config.n_episodes):
        print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
        obs = env.reset()
        nagents = len(obs)
        dones = [False] * nagents
        for agent in env.agents:
            agent.trajectory = []
        if config.save_gifs:
            frames = []
            frames.append(env.render('rgb_array')[0])
        act_hidden = [[torch.zeros(1, 128), torch.zeros(1, 128)]
                      for i in range(nagents)]
        crt_hidden = [[torch.zeros(1, 128), torch.zeros(1, 128)]
                      for i in range(nagents)]
        rews = None
        env.render('human')

        for t_i in range(config.episode_length):
            print(f'{t_i} / {config.episode_length}')
            calc_start = time.time()
            nagents = len(obs)
            torch_obs = [Variable(torch.Tensor(obs[i]).view(1, -1),
                                  requires_grad=False)
                         for i in range(nagents)]
            _, _, _, mean_list = ppo.step(torch_obs, act_hidden, crt_hidden)
            agent_actions_list = [a.data.cpu().numpy() for a in mean_list]
            clipped_action_list = [np.clip(a, -1, 1) for a in agent_actions_list]
            actions = [ac.flatten() for ac in clipped_action_list]
            for i in range(len(dones)):
                if dones[i]:
                    env.agents[i].movable = False
                else:
                    env.agents[i].trajectory.append(np.copy(env.agents[i].state.p_pos))
            obs, rewards, dones, infos = env.step(actions)
            if rews is None:
                rews = np.zeros(len(rewards))
            rews += np.array(rewards)
            if config.save_gifs:
                frames.append(env.render('rgb_array')[0])
            calc_end = time.time()
            elapsed = calc_end - calc_start
            if elapsed < ifi:
                time.sleep(ifi - elapsed)
            env.render('human')

        if config.save_gifs:
            gif_num = 0
            while (gif_path / ('%i_%i.gif' % (gif_num, ep_i))).exists():
                gif_num += 1
            imageio.mimsave(str(gif_path / ('%i_%i.gif' % (gif_num, ep_i))),
                            frames, duration=ifi)

    env.close()
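# The function above reads these attributes from `config`: env_id, run_num, incremental,
# save_gifs, fps, n_episodes and episode_length. The surrounding script presumably builds
# them with argparse; a minimal hand-rolled invocation (all values are illustrative) could
# look like this:
from types import SimpleNamespace

config = SimpleNamespace(
    env_id="simple_spread",   # illustrative environment name
    run_num=1,
    incremental=None,         # load models/<env_id>/run1/model.pt
    save_gifs=False,
    fps=30,
    n_episodes=3,
    episode_length=50,
)
run(config)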
import numpy as np  # used below via np.asarray
from tqdm import tqdm
import logging

from algorithms.ppo import PPO
from env.dist_env import BreakoutEnv

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s|%(levelname)s|%(message)s')

explore_steps = 512
total_updates = 2000
save_model_freq = 100

env = BreakoutEnv(50002, num_envs=20)
env_ids, states, rewards, dones = env.start()
ppo = PPO(env.action_space, env.state_space,
          train_epoch=5, clip_schedule=lambda x: 0.2)

nth_trajectory = 0
while True:
    nth_trajectory += 1
    for _ in tqdm(range(explore_steps)):
        actions = ppo.get_action(np.asarray(states))
        env_ids, states, rewards, dones = env.step(env_ids, actions)

    s_batch, a_batch, r_batch, d_batch = env.get_episodes()
    logging.info(f'>>>>{env.mean_reward}, nth_trajectory{nth_trajectory}')
    ppo.update(s_batch, a_batch, r_batch, d_batch,
               min(0.9, nth_trajectory / total_updates))
    ppo.sw.add_scalar('epreward_mean',
def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )
    args.num_parallel *= args.frame_skip
    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(args.num_frames / args.num_parallel /
                           args.num_steps_per_rollout)

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = (args.num_parallel * args.num_steps_per_rollout //
                           args.mini_batch_size)

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)

    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )

    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir,
                            "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):

        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(agent.optimizer, update, args.num_updates,
                                   args.lr, args.final_lr)
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(agent.optimizer, update, 0.99,
                                        args.lr, args.final_lr)

        for step in range(args.num_steps_per_rollout):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.observations[step])

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(
                    done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(obs, action, action_log_prob, value, reward,
                            masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
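# The loop above bootstraps `next_value` and then calls
# rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda); the storage class itself
# is not shown in this file. As a reference, a Generalised Advantage Estimation (GAE) return
# computation of that kind typically looks like the standalone sketch below (tensor names and
# shapes are illustrative assumptions, not this repository's API).
import torch

def compute_gae_returns(rewards, values, masks, next_value, gamma=0.99, gae_lambda=0.95):
    # rewards, values, masks: (T, N); next_value: (N,)
    T = rewards.shape[0]
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        # one-step TD error, masked at episode boundaries
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * gae_lambda * masks[t] * gae
        returns[t] = gae + values[t]
    return returns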
def main(_seed, _config, _run): args = init(_seed, _config, _run) env_name = args.env_name dummy_env = make_env(env_name, render=False) cleanup_log_dir(args.log_dir) cleanup_log_dir(args.log_dir + "_test") try: os.makedirs(args.save_dir) except OSError: pass torch.set_num_threads(1) envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir) envs.set_mirror(args.use_phase_mirror) test_envs = make_vec_envs(env_name, args.seed, args.num_tests, args.log_dir + "_test") test_envs.set_mirror(args.use_phase_mirror) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) if args.use_curriculum: curriculum = 0 print("curriculum", curriculum) envs.update_curriculum(curriculum) if args.use_specialist: specialist = 0 print("specialist", specialist) envs.update_specialist(specialist) if args.use_threshold_sampling: sampling_threshold = 200 first_sampling = False uniform_sampling = True uniform_every = 500000 uniform_counter = 1 evaluate_envs = make_env(env_name, render=False) evaluate_envs.set_mirror(args.use_phase_mirror) evaluate_envs.update_curriculum(0) prob_filter = np.zeros((11, 11)) prob_filter[5, 5] = 1 if args.use_adaptive_sampling: evaluate_envs = make_env(env_name, render=False) evaluate_envs.set_mirror(args.use_phase_mirror) evaluate_envs.update_curriculum(0) if args.plot_prob: import matplotlib.pyplot as plt fig = plt.figure() plt.show(block=False) ax1 = fig.add_subplot(121) ax2 = fig.add_subplot(122) obs_shape = envs.observation_space.shape obs_shape = (obs_shape[0], *obs_shape[1:]) if args.load_saved_controller: best_model = "{}_base.pt".format(env_name) model_path = os.path.join(current_dir, "models", best_model) print("Loading model {}".format(best_model)) actor_critic = torch.load(model_path) actor_critic.reset_dist() else: controller = SoftsignActor(dummy_env) actor_critic = Policy(controller, num_ensembles=args.num_ensembles) mirror_function = None if args.use_mirror: indices = dummy_env.unwrapped.get_mirror_indices() mirror_function = get_mirror_function(indices) device = "cuda:0" if args.cuda else "cpu" if args.cuda: actor_critic.cuda() agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params) rollouts = RolloutStorage( args.num_steps, args.num_processes, obs_shape, envs.action_space.shape[0], actor_critic.state_size, ) current_obs = torch.zeros(args.num_processes, *obs_shape) def update_current_obs(obs): shape_dim0 = envs.observation_space.shape[0] obs = torch.from_numpy(obs).float() current_obs[:, -shape_dim0:] = obs obs = envs.reset() update_current_obs(obs) rollouts.observations[0].copy_(current_obs) if args.cuda: current_obs = current_obs.cuda() rollouts.cuda() episode_rewards = deque(maxlen=args.num_processes) test_episode_rewards = deque(maxlen=args.num_tests) num_updates = int(args.num_frames) // args.num_steps // args.num_processes start = time.time() next_checkpoint = args.save_every max_ep_reward = float("-inf") logger = ConsoleCSVLogger(log_dir=args.experiment_dir, console_log_interval=args.log_interval) update_values = False if args.save_sampling_prob: sampling_prob_list = [] for j in range(num_updates): if args.lr_decay_type == "linear": scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0) elif args.lr_decay_type == "exponential": scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5) else: scheduled_lr = args.lr set_optimizer_lr(agent.optimizer, scheduled_lr) ac_state_dict = copy.deepcopy(actor_critic).cpu().state_dict() if update_values and args.use_threshold_sampling: envs.update_curriculum(5) 
elif (not update_values ) and args.use_threshold_sampling and first_sampling: envs.update_specialist(0) if args.use_threshold_sampling and not uniform_sampling: obs = evaluate_envs.reset() yaw_size = dummy_env.yaw_samples.shape[0] pitch_size = dummy_env.pitch_samples.shape[0] total_metric = torch.zeros(1, yaw_size * pitch_size).to(device) evaluate_counter = 0 while True: obs = torch.from_numpy(obs).float().unsqueeze(0).to(device) with torch.no_grad(): _, action, _, _ = actor_critic.act(obs, None, None, deterministic=True) cpu_actions = action.squeeze().cpu().numpy() obs, reward, done, info = evaluate_envs.step(cpu_actions) if done: obs = evaluate_envs.reset() if evaluate_envs.update_terrain: evaluate_counter += 1 temp_states = evaluate_envs.create_temp_states() with torch.no_grad(): temp_states = torch.from_numpy(temp_states).float().to( device) value_samples = actor_critic.get_ensemble_values( temp_states, None, None) #yaw_size = dummy_env.yaw_samples.shape[0] mean = value_samples.mean(dim=-1) #mean = value_samples.min(dim=-1)[0] metric = mean.clone() metric = metric.view(yaw_size, pitch_size) #metric = metric / (metric.abs().max()) metric = metric.view(1, yaw_size * pitch_size) total_metric += metric if evaluate_counter >= 5: total_metric /= (total_metric.abs().max()) #total_metric[total_metric < 0.7] = 0 print("metric", total_metric) sampling_probs = ( -10 * (total_metric - args.curriculum_threshold).abs() ).softmax(dim=1).view( yaw_size, pitch_size ) #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap #threshold 4: 20, 0.85, l1, yaw 10 if args.save_sampling_prob: sampling_prob_list.append(sampling_probs.cpu().numpy()) sample_probs = np.zeros( (args.num_processes, yaw_size, pitch_size)) #print("prob", sampling_probs) for i in range(args.num_processes): sample_probs[i, :, :] = np.copy( sampling_probs.cpu().numpy().astype(np.float64)) envs.update_sample_prob(sample_probs) break elif args.use_threshold_sampling and uniform_sampling: envs.update_curriculum(5) # if args.use_threshold_sampling and not uniform_sampling: # obs = evaluate_envs.reset() # yaw_size = dummy_env.yaw_samples.shape[0] # pitch_size = dummy_env.pitch_samples.shape[0] # r_size = dummy_env.r_samples.shape[0] # total_metric = torch.zeros(1, yaw_size * pitch_size * r_size).to(device) # evaluate_counter = 0 # while True: # obs = torch.from_numpy(obs).float().unsqueeze(0).to(device) # with torch.no_grad(): # _, action, _, _ = actor_critic.act( # obs, None, None, deterministic=True # ) # cpu_actions = action.squeeze().cpu().numpy() # obs, reward, done, info = evaluate_envs.step(cpu_actions) # if done: # obs = evaluate_envs.reset() # if evaluate_envs.update_terrain: # evaluate_counter += 1 # temp_states = evaluate_envs.create_temp_states() # with torch.no_grad(): # temp_states = torch.from_numpy(temp_states).float().to(device) # value_samples = actor_critic.get_ensemble_values(temp_states, None, None) # mean = value_samples.mean(dim=-1) # #mean = value_samples.min(dim=-1)[0] # metric = mean.clone() # metric = metric.view(yaw_size, pitch_size, r_size) # #metric = metric / (metric.abs().max()) # metric = metric.view(1, yaw_size*pitch_size*r_size) # total_metric += metric # if evaluate_counter >= 5: # total_metric /= (total_metric.abs().max()) # #total_metric[total_metric < 0.7] = 0 # #print("metric", total_metric) # sampling_probs = (-10*(total_metric-0.85).abs()).softmax(dim=1).view(yaw_size, pitch_size, r_size) #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap # 
#threshold 4: 3d grid, 10, 0.85, l1 # sample_probs = np.zeros((args.num_processes, yaw_size, pitch_size, r_size)) # #print("prob", sampling_probs) # for i in range(args.num_processes): # sample_probs[i, :, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64)) # envs.update_sample_prob(sample_probs) # break # elif args.use_threshold_sampling and uniform_sampling: # envs.update_curriculum(5) if args.use_adaptive_sampling: obs = evaluate_envs.reset() yaw_size = dummy_env.yaw_samples.shape[0] pitch_size = dummy_env.pitch_samples.shape[0] total_metric = torch.zeros(1, yaw_size * pitch_size).to(device) evaluate_counter = 0 while True: obs = torch.from_numpy(obs).float().unsqueeze(0).to(device) with torch.no_grad(): _, action, _, _ = actor_critic.act(obs, None, None, deterministic=True) cpu_actions = action.squeeze().cpu().numpy() obs, reward, done, info = evaluate_envs.step(cpu_actions) if done: obs = evaluate_envs.reset() if evaluate_envs.update_terrain: evaluate_counter += 1 temp_states = evaluate_envs.create_temp_states() with torch.no_grad(): temp_states = torch.from_numpy(temp_states).float().to( device) value_samples = actor_critic.get_ensemble_values( temp_states, None, None) mean = value_samples.mean(dim=-1) metric = mean.clone() metric = metric.view(yaw_size, pitch_size) #metric = metric / metric.abs().max() metric = metric.view(1, yaw_size * pitch_size) total_metric += metric # sampling_probs = (-30*metric).softmax(dim=1).view(size, size) # sample_probs = np.zeros((args.num_processes, size, size)) # for i in range(args.num_processes): # sample_probs[i, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64)) # envs.update_sample_prob(sample_probs) if evaluate_counter >= 5: total_metric /= (total_metric.abs().max()) print("metric", total_metric) sampling_probs = (-10 * total_metric).softmax(dim=1).view( yaw_size, pitch_size) sample_probs = np.zeros( (args.num_processes, yaw_size, pitch_size)) for i in range(args.num_processes): sample_probs[i, :, :] = np.copy( sampling_probs.cpu().numpy().astype(np.float64)) envs.update_sample_prob(sample_probs) break for step in range(args.num_steps): # Sample actions with torch.no_grad(): value, action, action_log_prob, states = actor_critic.act( rollouts.observations[step], rollouts.states[step], rollouts.masks[step], deterministic=update_values) cpu_actions = action.squeeze(1).cpu().numpy() obs, reward, done, infos = envs.step(cpu_actions) reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float() if args.plot_prob and step == 0: temp_states = envs.create_temp_states() with torch.no_grad(): temp_states = torch.from_numpy(temp_states).float().to( device) value_samples = actor_critic.get_value( temp_states, None, None) size = dummy_env.yaw_samples.shape[0] v = value_samples.mean(dim=0).view(size, size).cpu().numpy() vs = value_samples.var(dim=0).view(size, size).cpu().numpy() ax1.pcolormesh(v) ax2.pcolormesh(vs) print(np.round(v, 2)) fig.canvas.draw() # if args.use_adaptive_sampling: # temp_states = envs.create_temp_states() # with torch.no_grad(): # temp_states = torch.from_numpy(temp_states).float().to(device) # value_samples = actor_critic.get_value(temp_states, None, None) # size = dummy_env.yaw_samples.shape[0] # sample_probs = (-value_samples / 5).softmax(dim=1).view(args.num_processes, size, size) # envs.update_sample_prob(sample_probs.cpu().numpy()) # if args.use_threshold_sampling and not uniform_sampling: # temp_states = envs.create_temp_states() # with torch.no_grad(): # temp_states = 
torch.from_numpy(temp_states).float().to(device) # value_samples = actor_critic.get_ensemble_values(temp_states, None, None) # size = dummy_env.yaw_samples.shape[0] # mean = value_samples.mean(dim=-1) # std = value_samples.std(dim=-1) #using std # metric = std.clone() # metric = metric.view(args.num_processes, size, size) # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5 # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0 # metric = metric / metric.max() + value_filter # metric = metric.view(args.num_processes, size*size) # sample_probs = (30*metric).softmax(dim=1).view(args.num_processes, size, size) #using value estimate # metric = mean.clone() # metric = metric.view(args.num_processes, size, size) # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5 # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0 # metric = metric / metric.abs().max() - value_filter # metric = metric.view(args.num_processes, size*size) # sample_probs = (-30*metric).softmax(dim=1).view(args.num_processes, size, size) # if args.plot_prob and step == 0: # #print(sample_probs.cpu().numpy()[0, :, :]) # ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :]) # print(np.round(sample_probs.cpu().numpy()[0, :, :], 4)) # fig.canvas.draw() # envs.update_sample_prob(sample_probs.cpu().numpy()) #using value threshold # metric = mean.clone() # metric = metric.view(args.num_processes, size, size) # metric = metric / metric.abs().max()# - value_filter # metric = metric.view(args.num_processes, size*size) # sample_probs = (-30*(metric-0.8)**2).softmax(dim=1).view(args.num_processes, size, size) # if args.plot_prob and step == 0: # ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :]) # print(np.round(sample_probs.cpu().numpy()[0, :, :], 4)) # fig.canvas.draw() # envs.update_sample_prob(sample_probs.cpu().numpy()) bad_masks = np.ones((args.num_processes, 1)) for p_index, info in enumerate(infos): keys = info.keys() # This information is added by algorithms.utils.TimeLimitMask if "bad_transition" in keys: bad_masks[p_index] = 0.0 # This information is added by baselines.bench.Monitor if "episode" in keys: episode_rewards.append(info["episode"]["r"]) masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]) bad_masks = torch.from_numpy(bad_masks) update_current_obs(obs) rollouts.insert( current_obs, states, action, action_log_prob, value, reward, masks, bad_masks, ) obs = test_envs.reset() if args.use_threshold_sampling: if uniform_counter % uniform_every == 0: uniform_sampling = True uniform_counter = 0 else: uniform_sampling = False uniform_counter += 1 if uniform_sampling: envs.update_curriculum(5) print("uniform") #print("max_step", dummy_env._max_episode_steps) for step in range(dummy_env._max_episode_steps): # Sample actions with torch.no_grad(): obs = torch.from_numpy(obs).float().to(device) _, action, _, _ = actor_critic.act(obs, None, None, deterministic=True) cpu_actions = action.squeeze(1).cpu().numpy() obs, reward, done, infos = test_envs.step(cpu_actions) reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float() for p_index, info in enumerate(infos): keys = info.keys() # This information is added by baselines.bench.Monitor if "episode" in keys: #print(info["episode"]["r"]) test_episode_rewards.append(info["episode"]["r"]) if args.use_curriculum and np.mean( episode_rewards) > 1000 and curriculum <= 4: curriculum += 1 print("curriculum", curriculum) 
envs.update_curriculum(curriculum) with torch.no_grad(): next_value = actor_critic.get_value(rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1]).detach() rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda) if update_values: value_loss = agent.update_values(rollouts) else: value_loss, action_loss, dist_entropy = agent.update(rollouts) #update_values = (not update_values) rollouts.after_update() frame_count = (j + 1) * args.num_steps * args.num_processes if (frame_count >= next_checkpoint or j == num_updates - 1) and args.save_dir != "": model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint)) next_checkpoint += args.save_every else: model_name = "{}_latest.pt".format(env_name) if args.save_sampling_prob: import pickle with open('{}_sampling_prob85.pkl'.format(env_name), 'wb') as fp: pickle.dump(sampling_prob_list, fp) # A really ugly way to save a model to CPU save_model = actor_critic if args.cuda: save_model = copy.deepcopy(actor_critic).cpu() if args.use_specialist and np.mean( episode_rewards) > 1000 and specialist <= 4: specialist_name = "{}_specialist_{:d}.pt".format( env_name, int(specialist)) specialist_model = actor_critic if args.cuda: specialist_model = copy.deepcopy(actor_critic).cpu() torch.save(specialist_model, os.path.join(args.save_dir, specialist_name)) specialist += 1 envs.update_specialist(specialist) # if args.use_threshold_sampling and np.mean(episode_rewards) > 1000 and curriculum <= 4: # first_sampling = False # curriculum += 1 # print("curriculum", curriculum) # envs.update_curriculum(curriculum) # prob_filter[5-curriculum:5+curriculum+1, 5-curriculum:5+curriculum+1] = 1 torch.save(save_model, os.path.join(args.save_dir, model_name)) if len(episode_rewards) > 1 and np.mean( episode_rewards) > max_ep_reward: model_name = "{}_best.pt".format(env_name) max_ep_reward = np.mean(episode_rewards) torch.save(save_model, os.path.join(args.save_dir, model_name)) if len(episode_rewards) > 1: end = time.time() total_num_steps = (j + 1) * args.num_processes * args.num_steps logger.log_epoch({ "iter": j + 1, "total_num_steps": total_num_steps, "fps": int(total_num_steps / (end - start)), "entropy": dist_entropy, "value_loss": value_loss, "action_loss": action_loss, "stats": { "rew": episode_rewards }, "test_stats": { "rew": test_episode_rewards }, })
def initialise_policy(self):

    # variables for task encoder (used for oracle)
    state_dim = self.envs.observation_space.shape[0]

    # TODO: this isn't ideal, find a nicer way to get the task dimension!
    if 'BeliefOracle' in self.args.env_name:
        task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                   gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0]
        latent_dim = self.args.latent_dim
        state_embedding_size = self.args.state_embedding_size
        use_task_encoder = True
    elif 'Oracle' in self.args.env_name:
        task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                   gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0]
        latent_dim = self.args.latent_dim
        state_embedding_size = self.args.state_embedding_size
        use_task_encoder = True
    else:
        task_dim = latent_dim = state_embedding_size = 0
        use_task_encoder = False

    # initialise rollout storage for the policy
    self.policy_storage = OnlineStorage(
        self.args,
        self.args.policy_num_steps,
        self.args.num_processes,
        self.args.obs_dim,
        self.args.act_space,
        hidden_size=0,
        latent_dim=self.args.latent_dim,
        normalise_observations=self.args.norm_obs_for_policy,
        normalise_rewards=self.args.norm_rew_for_policy,
    )

    if hasattr(self.envs.action_space, 'low'):
        action_low = self.envs.action_space.low
        action_high = self.envs.action_space.high
    else:
        action_low = action_high = None

    # initialise policy network
    policy_net = Policy(
        # general
        state_dim=int(self.args.condition_policy_on_state) * state_dim,
        action_space=self.envs.action_space,
        init_std=self.args.policy_init_std,
        hidden_layers=self.args.policy_layers,
        activation_function=self.args.policy_activation_function,
        use_task_encoder=use_task_encoder,
        # task encoding things (for oracle)
        task_dim=task_dim,
        latent_dim=latent_dim,
        state_embed_dim=state_embedding_size,
        #
        normalise_actions=self.args.normalise_actions,
        action_low=action_low,
        action_high=action_high,
    ).to(device)

    # initialise policy
    if self.args.policy == 'a2c':
        # initialise policy trainer (A2C)
        self.policy = A2C(
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
            alpha=self.args.a2c_alpha,
        )
    elif self.args.policy == 'ppo':
        # initialise policy trainer (PPO)
        self.policy = PPO(
            policy_net,
            self.args.policy_value_loss_coef,
            self.args.policy_entropy_coef,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
            ppo_epoch=self.args.ppo_num_epochs,
            num_mini_batch=self.args.ppo_num_minibatch,
            use_huber_loss=self.args.ppo_use_huberloss,
            use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
            clip_param=self.args.ppo_clip_param,
        )
    else:
        raise NotImplementedError
class Learner: """ Learner (no meta-learning), can be used to train Oracle policies. """ def __init__(self, args): self.args = args # make sure everything has the same seed utl.seed(self.args.seed, self.args.deterministic_execution) # initialise tensorboard logger self.logger = TBLogger(self.args, self.args.exp_label) # initialise environments self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # calculate what the maximum length of the trajectories is args.max_trajectory_len = self.envs._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task # calculate number of meta updates self.args.num_updates = int( args.num_frames) // args.policy_num_steps // args.num_processes # get action / observation dimensions if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.envs.action_space.shape[0] self.args.obs_dim = self.envs.observation_space.shape[0] self.args.num_states = self.envs.num_states if str.startswith( self.args.env_name, 'Grid') else None self.args.act_space = self.envs.action_space self.initialise_policy() # count number of frames and updates self.frames = 0 self.iter_idx = 0 def initialise_policy(self): # variables for task encoder (used for oracle) state_dim = self.envs.observation_space.shape[0] # TODO: this isn't ideal, find a nicer way to get the task dimension! if 'BeliefOracle' in self.args.env_name: task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \ gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0] latent_dim = self.args.latent_dim state_embedding_size = self.args.state_embedding_size use_task_encoder = True elif 'Oracle' in self.args.env_name: task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \ gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0] latent_dim = self.args.latent_dim state_embedding_size = self.args.state_embedding_size use_task_encoder = True else: task_dim = latent_dim = state_embedding_size = 0 use_task_encoder = False # initialise rollout storage for the policy self.policy_storage = OnlineStorage( self.args, self.args.policy_num_steps, self.args.num_processes, self.args.obs_dim, self.args.act_space, hidden_size=0, latent_dim=self.args.latent_dim, normalise_observations=self.args.norm_obs_for_policy, normalise_rewards=self.args.norm_rew_for_policy, ) if hasattr(self.envs.action_space, 'low'): action_low = self.envs.action_space.low action_high = self.envs.action_space.high else: action_low = action_high = None # initialise policy network policy_net = Policy( # general state_dim=int(self.args.condition_policy_on_state) * state_dim, action_space=self.envs.action_space, init_std=self.args.policy_init_std, hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, use_task_encoder=use_task_encoder, # task encoding things (for oracle) task_dim=task_dim, latent_dim=latent_dim, state_embed_dim=state_embedding_size, # normalise_actions=self.args.normalise_actions, action_low=action_low, action_high=action_high, ).to(device) # initialise policy if self.args.policy == 'a2c': # initialise policy trainer (A2C) self.policy = A2C( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, lr=self.args.lr_policy, 
eps=self.args.policy_eps, alpha=self.args.a2c_alpha, ) elif self.args.policy == 'ppo': # initialise policy network self.policy = PPO( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, ) else: raise NotImplementedError def train(self): """ Given some stream of environments and a logger (tensorboard), (meta-)trains the policy. """ start_time = time.time() # reset environments (prev_obs_raw, prev_obs_normalised) = self.envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw) self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised) self.policy_storage.to(device) for self.iter_idx in range(self.args.num_updates): # check if we flushed the policy storage assert len(self.policy_storage.latent_mean) == 0 # rollouts policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = utl.select_action( policy=self.policy, args=self.args, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, deterministic=False) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action) action = action.float() # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) # add the obs before reset to the policy storage self.policy_storage.next_obs_raw[step] = next_obs_raw.clone() self.policy_storage.next_obs_normalised[ step] = next_obs_normalised.clone() # reset environments that are done done_indices = np.argwhere(done.flatten()).flatten() if len(done_indices) == self.args.num_processes: [next_obs_raw, next_obs_normalised] = self.envs.reset() if not self.args.sample_embeddings: latent_sample = latent_sample else: for i in done_indices: [next_obs_raw[i], next_obs_normalised[i]] = self.envs.reset(index=i) if not self.args.sample_embeddings: latent_sample[i] = latent_sample[i] # add experience to policy buffer self.policy_storage.insert( obs_raw=next_obs_raw.clone(), obs_normalised=next_obs_normalised.clone(), actions=action.clone(), action_log_probs=action_log_prob.clone(), rewards_raw=rew_raw.clone(), rewards_normalised=rew_normalised.clone(), value_preds=value.clone(), masks=masks_done.clone(), bad_masks=bad_masks.clone(), done=torch.from_numpy(np.array( done, dtype=float)).unsqueeze(1).clone(), ) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw self.frames += self.args.num_processes # --- UPDATE --- train_stats = self.update(prev_obs_normalised if self.args. 
norm_obs_for_policy else prev_obs_raw) # log run_stats = [action, action_log_prob, value] if train_stats is not None: self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() def get_value(self, obs): obs = utl.get_augmented_obs(args=self.args, obs=obs) return self.policy.actor_critic.get_value(obs).detach() def update(self, obs): """ Meta-update. Here the policy is updated for good average performance across tasks. :return: policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch """ # bootstrap next value prediction with torch.no_grad(): next_value = self.get_value(obs) # compute returns for current rollouts self.policy_storage.compute_returns( next_value, self.args.policy_use_gae, self.args.policy_gamma, self.args.policy_tau, use_proper_time_limits=self.args.use_proper_time_limits) policy_train_stats = self.policy.update( args=self.args, policy_storage=self.policy_storage) return policy_train_stats, None def log(self, run_stats, train_stats, start): """ Evaluate policy, save model, write to tensorboard logger. """ train_stats, meta_train_stats = train_stats # --- visualise behaviour of policy --- if self.iter_idx % self.args.vis_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None utl_eval.visualise_behaviour( args=self.args, policy=self.policy, image_folder=self.logger.full_output_folder, iter_idx=self.iter_idx, obs_rms=obs_rms, ret_rms=ret_rms, ) # --- evaluate policy ---- if self.iter_idx % self.args.eval_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None returns_per_episode = utl_eval.evaluate(args=self.args, policy=self.policy, obs_rms=obs_rms, ret_rms=ret_rms, iter_idx=self.iter_idx) # log the average return across tasks (=processes) returns_avg = returns_per_episode.mean(dim=0) returns_std = returns_per_episode.std(dim=0) for k in range(len(returns_avg)): self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx) self.logger.add( 'return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames) self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx) self.logger.add( 'return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames) print( "Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n" .format(self.iter_idx, self.frames, int(self.frames / (time.time() - start)), returns_avg[-1].item())) # save model if self.iter_idx % self.args.save_interval == 0: save_path = os.path.join(self.logger.full_output_folder, 'models') if not os.path.exists(save_path): os.mkdir(save_path) torch.save( self.policy.actor_critic, os.path.join(save_path, "policy{0}.pt".format(self.iter_idx))) # save normalisation params of envs if self.args.norm_rew_for_policy: # save rolling mean and std rew_rms = self.envs.venv.ret_rms utl.save_obj(rew_rms, save_path, "env_rew_rms{0}.pkl".format(self.iter_idx)) if self.args.norm_obs_for_policy: obs_rms = self.envs.venv.obs_rms utl.save_obj(obs_rms, save_path, "env_obs_rms{0}.pkl".format(self.iter_idx)) # --- log some other things --- if self.iter_idx % self.args.log_interval == 0: self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx) self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx) 
self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx) self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx) # writer.add_scalar('policy/action', action.mean(), j) self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx) if hasattr(self.policy.actor_critic, 'logstd'): self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx) self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx) self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx) param_list = list(self.policy.actor_critic.parameters()) param_mean = np.mean( [param_list[i].data.mean() for i in range(len(param_list))]) param_grad_mean = np.mean( [param_list[i].grad.mean() for i in range(len(param_list))]) self.logger.add('weights/policy', param_mean, self.iter_idx) self.logger.add('weights/policy_std', param_list[0].data.mean(), self.iter_idx) self.logger.add('gradients/policy', param_grad_mean, self.iter_idx) self.logger.add('gradients/policy_std', param_list[0].grad.mean(), self.iter_idx)
def main(): parser = argparse.ArgumentParser() parser.add_argument('--env_name', type=str, default='coinrun', help='name of the environment to train on.') parser.add_argument('--model', type=str, default='ppo', help='the model to use for training. {ppo, ppo_aup}') args, rest_args = parser.parse_known_args() env_name = args.env_name model = args.model # --- ARGUMENTS --- if model == 'ppo': args = args_ppo.get_args(rest_args) elif model == 'ppo_aup': args = args_ppo_aup.get_args(rest_args) else: raise NotImplementedError # place other args back into argparse.Namespace args.env_name = env_name args.model = model # warnings if args.deterministic_execution: print('Envoking deterministic code execution.') if torch.backends.cudnn.enabled: warnings.warn('Running with deterministic CUDNN.') if args.num_processes > 1: raise RuntimeError( 'If you want fully deterministic code, run it with num_processes=1.' 'Warning: This will slow things down and might break A2C if ' 'policy_num_steps < env._max_episode_steps.') # --- TRAINING --- print("Setting up wandb logging.") # Weights & Biases logger if args.run_name is None: # make run name as {env_name}_{TIME} now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S') args.run_name = args.env_name + '_' + args.algo + now # initialise wandb wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name, config=args, monitor_gym=False) # save wandb dir path args.run_dir = wandb.run.dir # make directory for saving models save_dir = os.path.join(wandb.run.dir, 'models') if not os.path.exists(save_dir): os.makedirs(save_dir) # set random seed of random, torch and numpy utl.set_global_seed(args.seed, args.deterministic_execution) print("Setting up Environments.") # initialise environments for training train_envs = make_vec_envs(env_name=args.env_name, start_level=args.train_start_level, num_levels=args.train_num_levels, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) # initialise environments for evaluation eval_envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) _ = eval_envs.reset() print("Setting up Actor-Critic model and Training algorithm.") # initialise policy network actor_critic = ACModel(obs_shape=train_envs.observation_space.shape, action_space=train_envs.action_space, hidden_size=args.hidden_size).to(device) # initialise policy training algorithm if args.algo == 'ppo': policy = PPO(actor_critic=actor_critic, ppo_epoch=args.policy_ppo_epoch, num_mini_batch=args.policy_num_mini_batch, clip_param=args.policy_clip_param, value_loss_coef=args.policy_value_loss_coef, entropy_coef=args.policy_entropy_coef, max_grad_norm=args.policy_max_grad_norm, lr=args.policy_lr, eps=args.policy_eps) else: raise NotImplementedError # initialise rollout storage for the policy training algorithm rollouts = RolloutStorage(num_steps=args.policy_num_steps, num_processes=args.num_processes, obs_shape=train_envs.observation_space.shape, action_space=train_envs.action_space) # initialise Q_aux function(s) for AUP if args.use_aup: print("Initialising Q_aux models.") q_aux = [ QModel(obs_shape=train_envs.observation_space.shape, action_space=train_envs.action_space, hidden_size=args.hidden_size).to(device) for _ in range(args.num_q_aux) ] if args.num_q_aux == 1: # load 
weights to model path = args.q_aux_dir + "0.pt" q_aux[0].load_state_dict(torch.load(path)) q_aux[0].eval() else: # get max number of q_aux functions to choose from args.max_num_q_aux = os.listdir(args.q_aux_dir) q_aux_models = random.sample(list(range(0, args.max_num_q_aux)), args.num_q_aux) # load weights to models for i, model in enumerate(q_aux): path = args.q_aux_dir + str(q_aux_models[i]) + ".pt" model.load_state_dict(torch.load(path)) model.eval() # count number of frames and updates frames = 0 iter_idx = 0 # update wandb args wandb.config.update(args) update_start_time = time.time() # reset environments obs = train_envs.reset() # obs.shape = (num_processes,C,H,W) # insert initial observation to rollout storage rollouts.obs[0].copy_(obs) rollouts.to(device) # initialise buffer for calculating mean episodic returns episode_info_buf = deque(maxlen=10) # calculate number of updates # number of frames ÷ number of policy steps before update ÷ number of processes args.num_batch = args.num_processes * args.policy_num_steps args.num_updates = int(args.num_frames) // args.num_batch # define AUP coefficient if args.use_aup: aup_coef = args.aup_coef_start aup_linear_increase_val = math.exp( math.log(args.aup_coef_end / args.aup_coef_start) / args.num_updates) print("Training beginning.") print("Number of updates: ", args.num_updates) for iter_idx in range(args.num_updates): print("Iter: ", iter_idx) # put actor-critic into train mode actor_critic.train() if args.use_aup: aup_measures = defaultdict(list) # rollout policy to collect num_batch of experience and place in storage for step in range(args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = actor_critic.act( rollouts.obs[step]) # observe rewards and next obs obs, reward, done, infos = train_envs.step(action) # calculate AUP reward if args.use_aup: intrinsic_reward = torch.zeros_like(reward) with torch.no_grad(): for model in q_aux: # get action-values action_values = model.get_action_value( rollouts.obs[step]) # get action-value for action taken action_value = torch.sum( action_values * torch.nn.functional.one_hot( action, num_classes=train_envs.action_space.n).squeeze( dim=1), dim=1) # calculate the penalty intrinsic_reward += torch.abs( action_value.unsqueeze(dim=1) - action_values[:, 4].unsqueeze(dim=1)) intrinsic_reward /= args.num_q_aux # add intrinsic reward to the extrinsic reward reward -= aup_coef * intrinsic_reward # log the intrinsic reward from the first env. 
aup_measures['intrinsic_reward'].append(aup_coef * intrinsic_reward[0, 0]) if done[0] and infos[0]['prev_level_complete'] == 1: aup_measures['episode_complete'].append(2) elif done[0] and infos[0]['prev_level_complete'] == 0: aup_measures['episode_complete'].append(1) else: aup_measures['episode_complete'].append(0) # log episode info if episode finished for info in infos: if 'episode' in info.keys(): episode_info_buf.append(info['episode']) # create mask for episode ends masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # add experience to storage rollouts.insert(obs, reward, action, value, action_log_prob, masks) frames += args.num_processes # linearly increase aup coefficient after every update if args.use_aup: aup_coef *= aup_linear_increase_val # --- UPDATE --- # bootstrap next value prediction with torch.no_grad(): next_value = actor_critic.get_value(rollouts.obs[-1]).detach() # compute returns for current rollouts rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda) # update actor-critic using policy training algorithm total_loss, value_loss, action_loss, dist_entropy = policy.update( rollouts) # clean up storage after update rollouts.after_update() # --- LOGGING --- if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1: # --- EVALUATION --- eval_episode_info_buf = utl_eval.evaluate( eval_envs=eval_envs, actor_critic=actor_critic, device=device) # get stats for run update_end_time = time.time() num_interval_updates = 1 if iter_idx == 0 else args.log_interval fps = num_interval_updates * ( args.num_processes * args.policy_num_steps) / (update_end_time - update_start_time) update_start_time = update_end_time # calculates whether the value function is a good predicator of the returns (ev > 1) # or if it's just worse than predicting nothing (ev =< 0) ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds), utl.sf01(rollouts.returns)) if args.use_aup: step = frames - args.num_processes * args.policy_num_steps for i in range(args.policy_num_steps): wandb.log( { 'aup/intrinsic_reward': aup_measures['intrinsic_reward'][i], 'aup/episode_complete': aup_measures['episode_complete'][i] }, step=step) step += args.num_processes wandb.log( { 'misc/timesteps': frames, 'misc/fps': fps, 'misc/explained_variance': float(ev), 'losses/total_loss': total_loss, 'losses/value_loss': value_loss, 'losses/action_loss': action_loss, 'losses/dist_entropy': dist_entropy, 'train/mean_episodic_return': utl_math.safe_mean([ episode_info['r'] for episode_info in episode_info_buf ]), 'train/mean_episodic_length': utl_math.safe_mean([ episode_info['l'] for episode_info in episode_info_buf ]), 'eval/mean_episodic_return': utl_math.safe_mean([ episode_info['r'] for episode_info in eval_episode_info_buf ]), 'eval/mean_episodic_length': utl_math.safe_mean([ episode_info['l'] for episode_info in eval_episode_info_buf ]) }, step=frames) # --- SAVE MODEL --- # save for every interval-th episode or for the last epoch if iter_idx != 0 and (iter_idx % args.save_interval == 0 or iter_idx == args.num_updates - 1): print("Saving Actor-Critic Model.") torch.save(actor_critic.state_dict(), os.path.join(save_dir, "policy{0}.pt".format(iter_idx))) # close envs train_envs.close() eval_envs.close() # --- TEST --- if args.test: print("Testing beginning.") episodic_return = utl_test.test(args=args, actor_critic=actor_critic, device=device) # save returns from train and test levels to analyse using interactive mode train_levels = torch.arange( 
args.train_start_level, args.train_start_level + args.train_num_levels) for i, level in enumerate(train_levels): wandb.log({ 'test/train_levels': level, 'test/train_returns': episodic_return[0][i] }) test_levels = torch.arange( args.test_start_level, args.test_start_level + args.test_num_levels) for i, level in enumerate(test_levels): wandb.log({ 'test/test_levels': level, 'test/test_returns': episodic_return[1][i] }) # log returns from test envs wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean( episodic_return[0]) wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean( episodic_return[1])
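# The AUP run above anneals `aup_coef` by multiplying it once per update by
# exp(log(aup_coef_end / aup_coef_start) / num_updates); despite the variable name
# `aup_linear_increase_val`, this is a geometric (exponential) schedule. A small standalone
# check of that arithmetic, with illustrative start/end values:
import math

aup_coef_start, aup_coef_end, num_updates = 1e-3, 1e-1, 1000
growth = math.exp(math.log(aup_coef_end / aup_coef_start) / num_updates)

aup_coef = aup_coef_start
for _ in range(num_updates):
    aup_coef *= growth

# after num_updates multiplications the coefficient has moved from start to end
assert abs(aup_coef - aup_coef_end) < 1e-9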
    x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu)
    x = tf.contrib.layers.flatten(x)
    x = tf.layers.dense(x, 512, activation=tf.nn.relu)
    logit_action_probability = tf.layers.dense(
        x, action_space,
        kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01))
    state_value = tf.squeeze(
        tf.layers.dense(
            x, 1, kernel_initializer=tf.truncated_normal_initializer()))
    return logit_action_probability, state_value


ppo = PPO(action_space, obs_fn, model_fn,
          train_epoch=5, batch_size=64,
          save_path='./raiden2_model')
env = Raiden2(6666, num_envs=1, with_stack=True)
env_ids, states, rewards, dones = env.start()

nth_trajectory = 0
while True:
    nth_trajectory += 1
    for _ in tqdm(range(explore_steps)):
        actions = ppo.get_action(np.asarray(states))
        actions = [(action, 4) for action in actions]
        env_ids, states, rewards, dones = env.step(env_ids, actions)

    logging.info(f'>>>>{env.mean_reward}, nth_trajectory{nth_trajectory}')
def main(): parser = argparse.ArgumentParser() parser.add_argument('--env_name', type=str, default='coinrun', help='name of the environment to train on.') parser.add_argument( '--model', type=str, default='ppo', help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}' ) args, rest_args = parser.parse_known_args() env_name = args.env_name model = args.model # --- ARGUMENTS --- if model == 'ppo': args = args_ppo.get_args(rest_args) elif model == 'ibac': args = args_ibac.get_args(rest_args) elif model == 'ibac_sni': args = args_ibac_sni.get_args(rest_args) elif model == 'dist_match': args = args_dist_match.get_args(rest_args) else: raise NotImplementedError # place other args back into argparse.Namespace args.env_name = env_name args.model = model args.num_train_envs = args.num_processes - args.num_val_envs if args.num_val_envs > 0 else args.num_processes # warnings if args.deterministic_execution: print('Envoking deterministic code execution.') if torch.backends.cudnn.enabled: warnings.warn('Running with deterministic CUDNN.') if args.num_processes > 1: raise RuntimeError( 'If you want fully deterministic code, run it with num_processes=1.' 'Warning: This will slow things down and might break A2C if ' 'policy_num_steps < env._max_episode_steps.') elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes or not args.percentage_levels_train < 1.0): raise ValueError( 'If --args.num_val_envs>0 then you must also have' '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.' ) elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0: raise ValueError( 'If --num_val_envs>0 and --use_dist_matching=False then you must also have' '--dist_matching_coef=0.') elif args.use_dist_matching and not args.num_val_envs > 0: raise ValueError( 'If --use_dist_matching=True then you must also have' '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.' 
) elif args.analyse_rep and not args.use_bottleneck: raise ValueError('If --analyse_rep=True then you must also have' '--use_bottleneck=True.') # --- TRAINING --- print("Setting up wandb logging.") # Weights & Biases logger if args.run_name is None: # make run name as {env_name}_{TIME} now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S') args.run_name = args.env_name + '_' + args.algo + now # initialise wandb wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name, config=args, monitor_gym=False) # save wandb dir path args.run_dir = wandb.run.dir # make directory for saving models save_dir = os.path.join(wandb.run.dir, 'models') if not os.path.exists(save_dir): os.makedirs(save_dir) # set random seed of random, torch and numpy utl.set_global_seed(args.seed, args.deterministic_execution) # initialise environments for training print("Setting up Environments.") if args.num_val_envs > 0: train_num_levels = int(args.train_num_levels * args.percentage_levels_train) val_start_level = args.train_start_level + train_num_levels val_num_levels = args.train_num_levels - train_num_levels train_envs = make_vec_envs(env_name=args.env_name, start_level=args.train_start_level, num_levels=train_num_levels, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_train_envs, num_frame_stack=args.num_frame_stack, device=device) val_envs = make_vec_envs(env_name=args.env_name, start_level=val_start_level, num_levels=val_num_levels, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_val_envs, num_frame_stack=args.num_frame_stack, device=device) else: train_envs = make_vec_envs(env_name=args.env_name, start_level=args.train_start_level, num_levels=args.train_num_levels, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) # initialise environments for evaluation eval_envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) _ = eval_envs.reset() # initialise environments for analysing the representation if args.analyse_rep: analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = make_rep_analysis_envs( args, device) print("Setting up Actor-Critic model and Training algorithm.") # initialise policy network actor_critic = ACModel(obs_shape=train_envs.observation_space.shape, action_space=train_envs.action_space, hidden_size=args.hidden_size, use_bottleneck=args.use_bottleneck, sni_type=args.sni_type).to(device) # initialise policy training algorithm if args.algo == 'ppo': policy = PPO(actor_critic=actor_critic, ppo_epoch=args.policy_ppo_epoch, num_mini_batch=args.policy_num_mini_batch, clip_param=args.policy_clip_param, value_loss_coef=args.policy_value_loss_coef, entropy_coef=args.policy_entropy_coef, max_grad_norm=args.policy_max_grad_norm, lr=args.policy_lr, eps=args.policy_eps, vib_coef=args.vib_coef, sni_coef=args.sni_coef, use_dist_matching=args.use_dist_matching, dist_matching_loss=args.dist_matching_loss, dist_matching_coef=args.dist_matching_coef, num_train_envs=args.num_train_envs, num_val_envs=args.num_val_envs) else: raise NotImplementedError # initialise rollout storage for the policy training algorithm rollouts = 
RolloutStorage(num_steps=args.policy_num_steps, num_processes=args.num_processes, obs_shape=train_envs.observation_space.shape, action_space=train_envs.action_space) # count number of frames and updates frames = 0 iter_idx = 0 # update wandb args wandb.config.update(args) # wandb.watch(actor_critic, log="all") # to log gradients of actor-critic network update_start_time = time.time() # reset environments if args.num_val_envs > 0: obs = torch.cat([train_envs.reset(), val_envs.reset()]) # obs.shape = (n_envs,C,H,W) else: obs = train_envs.reset() # obs.shape = (n_envs,C,H,W) # insert initial observation to rollout storage rollouts.obs[0].copy_(obs) rollouts.to(device) # initialise buffer for calculating mean episodic returns train_episode_info_buf = deque(maxlen=10) val_episode_info_buf = deque(maxlen=10) # calculate number of updates # number of frames ÷ number of policy steps before update ÷ number of processes args.num_batch = args.num_processes * args.policy_num_steps args.num_updates = int(args.num_frames) // args.num_batch print("Training beginning.") print("Number of updates: ", args.num_updates) for iter_idx in range(args.num_updates): print("Iter: ", iter_idx) # put actor-critic into train mode actor_critic.train() # rollout policy to collect num_batch of experience and place in storage for step in range(args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob, _ = actor_critic.act( rollouts.obs[step]) # observe rewards and next obs if args.num_val_envs > 0: obs, reward, done, infos = train_envs.step( action[:args.num_train_envs, :]) val_obs, val_reward, val_done, val_infos = val_envs.step( action[args.num_train_envs:, :]) obs = torch.cat([obs, val_obs]) reward = torch.cat([reward, val_reward]) done, val_done = list(done), list(val_done) done.extend(val_done) infos.extend(val_infos) else: obs, reward, done, infos = train_envs.step(action) # log episode info if episode finished for i, info in enumerate(infos): if i < args.num_train_envs and 'episode' in info.keys(): train_episode_info_buf.append(info['episode']) elif i >= args.num_train_envs and 'episode' in info.keys(): val_episode_info_buf.append(info['episode']) # create mask for episode ends masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # add experience to storage rollouts.insert(obs, reward, action, value, action_log_prob, masks) frames += args.num_processes # --- UPDATE --- # bootstrap next value prediction with torch.no_grad(): next_value = actor_critic.get_value(rollouts.obs[-1]).detach() # compute returns for current rollouts rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda) # update actor-critic using policy gradient algo total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update( rollouts) # clean up storage after update rollouts.after_update() # --- LOGGING --- if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1: # --- EVALUATION --- eval_episode_info_buf = utl_eval.evaluate( eval_envs=eval_envs, actor_critic=actor_critic, device=device) # --- ANALYSE REPRESENTATION --- if args.analyse_rep: rep_measures = utl_rep.analyse_rep( args=args, train1_envs=analyse_rep_train1_envs, train2_envs=analyse_rep_train2_envs, val_envs=analyse_rep_val_envs, test_envs=analyse_rep_test_envs, actor_critic=actor_critic, device=device) # get stats for run update_end_time = time.time() num_interval_updates = 1 if iter_idx == 0 else args.log_interval fps = num_interval_updates 
* ( args.num_processes * args.policy_num_steps) / (update_end_time - update_start_time) update_start_time = update_end_time # Calculates if value function is a good predicator of the returns (ev > 1) # or if it's just worse than predicting nothing (ev =< 0) ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds), utl.sf01(rollouts.returns)) wandb.log( { 'misc/timesteps': frames, 'misc/fps': fps, 'misc/explained_variance': float(ev), 'losses/total_loss': total_loss, 'losses/value_loss': value_loss, 'losses/action_loss': action_loss, 'losses/dist_entropy': dist_entropy, 'train/mean_episodic_return': utl_math.safe_mean([ episode_info['r'] for episode_info in train_episode_info_buf ]), 'train/mean_episodic_length': utl_math.safe_mean([ episode_info['l'] for episode_info in train_episode_info_buf ]), 'eval/mean_episodic_return': utl_math.safe_mean([ episode_info['r'] for episode_info in eval_episode_info_buf ]), 'eval/mean_episodic_length': utl_math.safe_mean([ episode_info['l'] for episode_info in eval_episode_info_buf ]) }, step=iter_idx) if args.use_bottleneck: wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx) if args.num_val_envs > 0: wandb.log( { 'losses/dist_matching_loss': dist_matching_loss, 'val/mean_episodic_return': utl_math.safe_mean([ episode_info['r'] for episode_info in val_episode_info_buf ]), 'val/mean_episodic_length': utl_math.safe_mean([ episode_info['l'] for episode_info in val_episode_info_buf ]) }, step=iter_idx) if args.analyse_rep: wandb.log( { "analysis/" + key: val for key, val in rep_measures.items() }, step=iter_idx) # --- SAVE MODEL --- # save for every interval-th episode or for the last epoch if iter_idx != 0 and (iter_idx % args.save_interval == 0 or iter_idx == args.num_updates - 1): print("Saving Actor-Critic Model.") torch.save(actor_critic.state_dict(), os.path.join(save_dir, "policy{0}.pt".format(iter_idx))) # close envs train_envs.close() eval_envs.close() # --- TEST --- if args.test: print("Testing beginning.") episodic_return, latents_z = utl_test.test(args=args, actor_critic=actor_critic, device=device) # save returns from train and test levels to analyse using interactive mode train_levels = torch.arange( args.train_start_level, args.train_start_level + args.train_num_levels) for i, level in enumerate(train_levels): wandb.log({ 'test/train_levels': level, 'test/train_returns': episodic_return[0][i] }) test_levels = torch.arange( args.test_start_level, args.test_start_level + args.test_num_levels) for i, level in enumerate(test_levels): wandb.log({ 'test/test_levels': level, 'test/test_returns': episodic_return[1][i] }) # log returns from test envs wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean( episodic_return[0]) wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean( episodic_return[1]) # plot latent representation if args.plot_pca: print("Plotting PCA of Latent Representation.") utl_rep.pca(args, latents_z)
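# The logging block above reports utl_math.explained_variance(utl.sf01(...), ...).
# Those helpers live in the project's utility modules and are not shown in this
# file; the snippet below is a minimal, self-contained sketch of what they are
# assumed to compute, following the usual OpenAI-baselines definition (1 means a
# perfect value function, <= 0 means no better than predicting a constant).
import numpy as np


def sf01(arr):
    """Swap the (num_steps, num_processes) axes and flatten the result."""
    arr = np.asarray(arr)
    shape = arr.shape
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])


def explained_variance(y_pred, y_true):
    """Return 1 - Var[y_true - y_pred] / Var[y_true]."""
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1.0 - np.var(y_true - y_pred) / var_y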
def train(): model_dir = Path('./models') / ENV_ID if not model_dir.exists(): curr_run = 'run1' else: exst_run_nums = [ int(str(folder.name).split('run')[1]) for folder in model_dir.iterdir() if str(folder.name).startswith('run') ] if len(exst_run_nums) == 0: curr_run = 'run1' else: curr_run = 'run%i' % (max(exst_run_nums) + 1) run_dir = model_dir / curr_run log_dir = run_dir / 'logs' os.makedirs(str(log_dir)) logger = SummaryWriter(str(log_dir), max_queue=5, flush_secs=30) torch.manual_seed(RANDOM_SEED) np.random.seed(RANDOM_SEED) env = make_parallel_env(ENV_ID, N_ROLLOUT_THREADS, RANDOM_SEED) ppo = PPO.init_from_env(env, gamma=GAMMA, lam=LAMDA, lr=LR, coeff_entropy=COEFF_ENTROPY, batch_size=BATCH_SIZE) # save_dict = torch.load(model_dir / 'run1/model.pt') # TODO init from save? # save_dict = save_dict['model_params']['policy'] # ppo.policy.load_state_dict(save_dict) buff = [] t = 0 for ep_i in range(0, N_EPISODES, N_ROLLOUT_THREADS): print("Episodes %i-%i of %i" % (ep_i + 1, ep_i + 1 + N_ROLLOUT_THREADS, N_EPISODES)) obs = env.reset() nagents = obs.shape[1] ppo.prep_rollouts(device='cpu') ep_rew = 0 act_hidden = [[ torch.zeros(N_ROLLOUT_THREADS, 128), torch.zeros(N_ROLLOUT_THREADS, 128) ] for i in range(nagents)] crt_hidden = [[ torch.zeros(N_ROLLOUT_THREADS, 128), torch.zeros(N_ROLLOUT_THREADS, 128) ] for i in range(nagents)] for et_i in range(EPISODE_LEN): """ generate actions """ torch_obs = [ Variable(torch.Tensor(np.vstack(obs[:, i])), requires_grad=False) for i in range(nagents) ] prev_act_hidden = [[h.data.cpu().numpy(), c.data.cpu().numpy()] for h, c in act_hidden] prev_crt_hidden = [[h.data.cpu().numpy(), c.data.cpu().numpy()] for h, c in crt_hidden] v_list, agent_actions_list, logprob_list, mean_list = ppo.step( torch_obs, act_hidden, crt_hidden) v_list = [a.data.cpu().numpy() for a in v_list] agent_actions_list = [ a.data.cpu().numpy() for a in agent_actions_list ] logprob_list = [a.data.cpu().numpy() for a in logprob_list] clipped_action_list = [ np.clip(a, -1, 1) for a in agent_actions_list ] actions = [[ac[i] for ac in clipped_action_list] for i in range(N_ROLLOUT_THREADS)] """ step env """ next_obs, rewards, dones, infos = env.step(actions) ep_rew += np.mean(rewards) buff.append( (obs, prev_act_hidden, prev_crt_hidden, agent_actions_list, rewards, dones, logprob_list, v_list)) obs = next_obs t += N_ROLLOUT_THREADS for i, done in enumerate(dones[0]): if done: act_hidden[i] = [ torch.zeros(N_ROLLOUT_THREADS, 128), torch.zeros(N_ROLLOUT_THREADS, 128) ] crt_hidden[i] = [ torch.zeros(N_ROLLOUT_THREADS, 128), torch.zeros(N_ROLLOUT_THREADS, 128) ] env.envs[0].agents[i].terminate = False # if dones.any(): # break print('mean reward:', ep_rew) """ train """ next_obs = [ Variable(torch.Tensor(np.vstack(obs[:, i])), requires_grad=False) for i in range(nagents) ] v_list, _, _, _ = ppo.step(next_obs, act_hidden, crt_hidden) v_list = [a.data.cpu().numpy() for a in v_list] print('updating params...') if USE_CUDA: ppo.prep_training(device='gpu') else: ppo.prep_training(device='cpu') ppo.update(buff=buff, last_v=v_list, to_gpu=USE_CUDA) ppo.prep_rollouts(device='cpu') buff = [] logger.add_scalar('mean_episode_rewards', ep_rew, ep_i) if ep_i % SAVE_INTERVAL < N_ROLLOUT_THREADS: print('saving incremental...') os.makedirs(str(run_dir / 'incremental'), exist_ok=True) ppo.save( str(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1)))) ppo.save(str(run_dir / 'model.pt')) print('saving model...') ppo.save(str(run_dir / 'model.pt')) env.close() logger.export_scalars_to_json(str(log_dir / 
'summary.json')) logger.close()
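# ppo.update(buff=buff, last_v=v_list, ...) above consumes the raw rollout buffer
# together with a bootstrap value for the final observation. The PPO class itself
# is defined elsewhere in the repo; the function below is only a sketch of the
# generalised-advantage-estimation step such an update would typically start with,
# using illustrative names and shapes (T steps, n_envs parallel environments).
import numpy as np


def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """rewards/values/dones: float arrays of shape (T, n_envs); last_value: (n_envs,)."""
    T = len(rewards)
    advantages = np.zeros_like(rewards, dtype=np.float32)
    gae = np.zeros_like(last_value, dtype=np.float32)
    next_value = last_value
    for t in reversed(range(T)):
        non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        gae = delta + gamma * lam * non_terminal * gae
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values
    return advantages, returns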
class MetaLearner: """ Meta-Learner class with the main training loop for variBAD. """ def __init__(self, args): self.args = args utl.seed(self.args.seed, self.args.deterministic_execution) # count number of frames and number of meta-iterations self.frames = 0 self.iter_idx = 0 # initialise tensorboard logger self.logger = TBLogger(self.args, self.args.exp_label) # initialise environments self.envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=args.num_processes, gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # calculate what the maximum length of the trajectories is args.max_trajectory_len = self.envs._max_episode_steps args.max_trajectory_len *= self.args.max_rollouts_per_task # calculate number of meta updates self.args.num_updates = int( args.num_frames) // args.policy_num_steps // args.num_processes # get action / observation dimensions if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete): self.args.action_dim = 1 else: self.args.action_dim = self.envs.action_space.shape[0] self.args.obs_dim = self.envs.observation_space.shape[0] self.args.num_states = self.envs.num_states if str.startswith( self.args.env_name, 'Grid') else None self.args.act_space = self.envs.action_space self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx) self.initialise_policy() def initialise_policy(self): # initialise rollout storage for the policy self.policy_storage = OnlineStorage( self.args, self.args.policy_num_steps, self.args.num_processes, self.args.obs_dim, self.args.act_space, hidden_size=self.args.aggregator_hidden_size, latent_dim=self.args.latent_dim, normalise_observations=self.args.norm_obs_for_policy, normalise_rewards=self.args.norm_rew_for_policy, ) # initialise policy network input_dim = self.args.obs_dim * int( self.args.condition_policy_on_state) input_dim += ( 1 + int(not self.args.sample_embeddings)) * self.args.latent_dim if hasattr(self.envs.action_space, 'low'): action_low = self.envs.action_space.low action_high = self.envs.action_space.high else: action_low = action_high = None policy_net = Policy( state_dim=input_dim, action_space=self.args.act_space, init_std=self.args.policy_init_std, hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, normalise_actions=self.args.normalise_actions, action_low=action_low, action_high=action_high, ).to(device) # initialise policy trainer if self.args.policy == 'a2c': self.policy = A2C( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, alpha=self.args.a2c_alpha, ) elif self.args.policy == 'ppo': self.policy = PPO( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, ) else: raise NotImplementedError def train(self): """ Given some stream of environments and a logger (tensorboard), (meta-)trains the policy. 
""" start_time = time.time() # reset environments (prev_obs_raw, prev_obs_normalised) = self.envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw) self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised) self.policy_storage.to(device) vae_is_pretrained = False for self.iter_idx in range(self.args.num_updates): # First, re-compute the hidden states given the current rollouts (since the VAE might've changed) # compute latent embedding (will return prior if current trajectory is empty) with torch.no_grad(): latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory( ) # check if we flushed the policy storage assert len(self.policy_storage.latent_mean) == 0 # add this initial hidden state to the policy storage self.policy_storage.hidden_states[0].copy_(hidden_state) self.policy_storage.latent_samples.append(latent_sample.clone()) self.policy_storage.latent_mean.append(latent_mean.clone()) self.policy_storage.latent_logvar.append(latent_logvar.clone()) # rollout policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = utl.select_action( args=self.args, policy=self.policy, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, deterministic=False, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action) tasks = torch.FloatTensor([info['task'] for info in infos]).to(device) done = torch.from_numpy(np.array( done, dtype=int)).to(device).float().view((-1, 1)) # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) # compute next embedding (for next loop and/or value prediction bootstrap) latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_obs_raw, action=action, reward=rew_raw, done=done, hidden_state=hidden_state) # before resetting, update the embedding and add to vae buffer # (last state might include useful task info) if not (self.args.disable_decoder and self.args.disable_stochasticity_in_latent): self.vae.rollout_storage.insert(prev_obs_raw.clone(), action.detach().clone(), next_obs_raw.clone(), rew_raw.clone(), done.clone(), tasks.clone()) # add the obs before reset to the policy storage # (only used to recompute embeddings if rlloss is backpropagated through encoder) self.policy_storage.next_obs_raw[step] = next_obs_raw.clone() self.policy_storage.next_obs_normalised[ step] = next_obs_normalised.clone() # reset environments that are done done_indices = np.argwhere( done.cpu().detach().flatten()).flatten() if len(done_indices) == self.args.num_processes: [next_obs_raw, next_obs_normalised] = self.envs.reset() if not self.args.sample_embeddings: latent_sample = latent_sample else: for i in done_indices: [next_obs_raw[i], next_obs_normalised[i]] = self.envs.reset(index=i) if not self.args.sample_embeddings: latent_sample[i] = latent_sample[i] # # add experience to policy buffer 
self.policy_storage.insert( obs_raw=next_obs_raw, obs_normalised=next_obs_normalised, actions=action, action_log_probs=action_log_prob, rewards_raw=rew_raw, rewards_normalised=rew_normalised, value_preds=value, masks=masks_done, bad_masks=bad_masks, done=done, hidden_states=hidden_state.squeeze(0).detach(), latent_sample=latent_sample.detach(), latent_mean=latent_mean.detach(), latent_logvar=latent_logvar.detach(), ) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw self.frames += self.args.num_processes # --- UPDATE --- if self.args.precollect_len <= self.frames: # check if we are pre-training the VAE if self.args.pretrain_len > 0 and not vae_is_pretrained: for _ in range(self.args.pretrain_len): self.vae.compute_vae_loss(update=True) vae_is_pretrained = True # otherwise do the normal update (policy + vae) else: train_stats = self.update( obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # log run_stats = [action, action_log_prob, value] if train_stats is not None: self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() def encode_running_trajectory(self): """ (Re-)Encodes (for each process) the entire current trajectory. Returns sample/mean/logvar and hidden state (if applicable) for the current timestep. :return: """ # for each process, get the current batch (zero-padded obs/act/rew + length indicators) prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch( ) # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior! all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder( actions=act, states=next_obs, rewards=rew, hidden_state=None, return_prior=True) # get the embedding / hidden state of the current time step (need to do this since we zero-padded) latent_sample = (torch.stack([ all_latent_samples[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) latent_mean = (torch.stack([ all_latent_means[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) latent_logvar = (torch.stack([ all_latent_logvars[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) hidden_state = (torch.stack([ all_hidden_states[lens[i]][i] for i in range(len(lens)) ])).detach().to(device) return latent_sample, latent_mean, latent_logvar, hidden_state def get_value(self, obs, latent_sample, latent_mean, latent_logvar): obs = utl.get_augmented_obs(self.args, obs, latent_sample, latent_mean, latent_logvar) return self.policy.actor_critic.get_value(obs).detach() def update(self, obs, latent_sample, latent_mean, latent_logvar): """ Meta-update. Here the policy is updated for good average performance across tasks. :return: """ # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0) if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0: # bootstrap next value prediction with torch.no_grad(): next_value = self.get_value(obs=obs, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # compute returns for current rollouts self.policy_storage.compute_returns( next_value, self.args.policy_use_gae, self.args.policy_gamma, self.args.policy_tau, use_proper_time_limits=self.args.use_proper_time_limits) # update agent (this will also call the VAE update!) 
policy_train_stats = self.policy.update( args=self.args, policy_storage=self.policy_storage, encoder=self.vae.encoder, rlloss_through_encoder=self.args.rlloss_through_encoder, compute_vae_loss=self.vae.compute_vae_loss) else: policy_train_stats = 0, 0, 0, 0 # pre-train the VAE if self.iter_idx < self.args.pretrain_len: self.vae.compute_vae_loss(update=True) return policy_train_stats, None def log(self, run_stats, train_stats, start_time): train_stats, meta_train_stats = train_stats # --- visualise behaviour of policy --- if self.iter_idx % self.args.vis_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None utl_eval.visualise_behaviour( args=self.args, policy=self.policy, image_folder=self.logger.full_output_folder, iter_idx=self.iter_idx, obs_rms=obs_rms, ret_rms=ret_rms, encoder=self.vae.encoder, reward_decoder=self.vae.reward_decoder, state_decoder=self.vae.state_decoder, task_decoder=self.vae.task_decoder, compute_rew_reconstruction_loss=self.vae. compute_rew_reconstruction_loss, compute_state_reconstruction_loss=self.vae. compute_state_reconstruction_loss, compute_task_reconstruction_loss=self.vae. compute_task_reconstruction_loss, compute_kl_loss=self.vae.compute_kl_loss, ) # --- evaluate policy ---- if self.iter_idx % self.args.eval_interval == 0: obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None returns_per_episode = utl_eval.evaluate(args=self.args, policy=self.policy, obs_rms=obs_rms, ret_rms=ret_rms, encoder=self.vae.encoder, iter_idx=self.iter_idx) # log the return avg/std across tasks (=processes) returns_avg = returns_per_episode.mean(dim=0) returns_std = returns_per_episode.std(dim=0) for k in range(len(returns_avg)): self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1), returns_avg[k], self.iter_idx) self.logger.add( 'return_avg_per_frame/episode_{}'.format(k + 1), returns_avg[k], self.frames) self.logger.add('return_std_per_iter/episode_{}'.format(k + 1), returns_std[k], self.iter_idx) self.logger.add( 'return_std_per_frame/episode_{}'.format(k + 1), returns_std[k], self.frames) print( "Updates {}, num timesteps {}, FPS {}, {} \n Mean return (train): {:.5f} \n" .format(self.iter_idx, self.frames, int(self.frames / (time.time() - start_time)), self.vae.rollout_storage.prev_obs.shape, returns_avg[-1].item())) # --- save models --- if self.iter_idx % self.args.save_interval == 0: save_path = os.path.join(self.logger.full_output_folder, 'models') if not os.path.exists(save_path): os.mkdir(save_path) torch.save( self.policy.actor_critic, os.path.join(save_path, "policy{0}.pt".format(self.iter_idx))) torch.save( self.vae.encoder, os.path.join(save_path, "encoder{0}.pt".format(self.iter_idx))) if self.vae.state_decoder is not None: torch.save( self.vae.state_decoder, os.path.join(save_path, "state_decoder{0}.pt".format(self.iter_idx))) if self.vae.reward_decoder is not None: torch.save( self.vae.reward_decoder, os.path.join(save_path, "reward_decoder{0}.pt".format(self.iter_idx))) if self.vae.task_decoder is not None: torch.save( self.vae.task_decoder, os.path.join(save_path, "task_decoder{0}.pt".format(self.iter_idx))) # save normalisation params of envs if self.args.norm_rew_for_policy: # save rolling mean and std rew_rms = self.envs.venv.ret_rms utl.save_obj(rew_rms, save_path, "env_rew_rms{0}.pkl".format(self.iter_idx)) if self.args.norm_obs_for_policy: 
obs_rms = self.envs.venv.obs_rms utl.save_obj(obs_rms, save_path, "env_obs_rms{0}.pkl".format(self.iter_idx)) # --- log some other things --- if self.iter_idx % self.args.log_interval == 0: self.logger.add('policy_losses/value_loss', train_stats[0], self.iter_idx) self.logger.add('policy_losses/action_loss', train_stats[1], self.iter_idx) self.logger.add('policy_losses/dist_entropy', train_stats[2], self.iter_idx) self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx) self.logger.add('policy/action', run_stats[0][0].float().mean(), self.iter_idx) if hasattr(self.policy.actor_critic, 'logstd'): self.logger.add('policy/action_logstd', self.policy.actor_critic.dist.logstd.mean(), self.iter_idx) self.logger.add('policy/action_logprob', run_stats[1].mean(), self.iter_idx) self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx) self.logger.add('encoder/latent_mean', torch.cat(self.policy_storage.latent_mean).mean(), self.iter_idx) self.logger.add( 'encoder/latent_logvar', torch.cat(self.policy_storage.latent_logvar).mean(), self.iter_idx) # log the average weights and gradients of all models (where applicable) for [model, name ] in [[self.policy.actor_critic, 'policy'], [self.vae.encoder, 'encoder'], [self.vae.reward_decoder, 'reward_decoder'], [self.vae.state_decoder, 'state_transition_decoder'], [self.vae.task_decoder, 'task_decoder']]: if model is not None: param_list = list(model.parameters()) param_mean = np.mean([ param_list[i].data.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('weights/{}'.format(name), param_mean, self.iter_idx) if name == 'policy': self.logger.add('weights/policy_std', param_list[0].data.mean(), self.iter_idx) if param_list[0].grad is not None: param_grad_mean = np.mean([ param_list[i].grad.cpu().numpy().mean() for i in range(len(param_list)) ]) self.logger.add('gradients/{}'.format(name), param_grad_mean, self.iter_idx) def load_and_render(self, load_iter): #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahJoint-v0/varibad_73__15:05_17:14:07', 'models') #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/hfield', 'models') save_path = os.path.join( '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25', 'models') self.policy.actor_critic = torch.load( os.path.join(save_path, "policy{0}.pt".format(load_iter))) self.vae.encoder = torch.load( os.path.join(save_path, "encoder{0}.pt").format(load_iter)) args = self.args device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") num_processes = 1 num_episodes = 100 num_steps = 1999 #import pdb; pdb.set_trace() # initialise environments envs = make_vec_envs( env_name=args.env_name, seed=args.seed, num_processes=num_processes, # 1 gamma=args.policy_gamma, log_dir=args.agent_log_dir, device=device, allow_early_resets=False, episodes_per_task=self.args.max_rollouts_per_task, obs_rms=None, ret_rms=None, ) # reset latent state to prior latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior( num_processes) for episode_idx in range(num_episodes): (prev_obs_raw, prev_obs_normalised) = envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) for step_idx in range(num_steps): with torch.no_grad(): _, action, _ = utl.select_action( args=self.args, policy=self.policy, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, 
deterministic=True) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step(envs, action) # render envs.venv.venv.envs[0].env.env.env.env.render() # update the hidden state latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_obs_raw, action=action, reward=rew_raw, done=None, hidden_state=hidden_state) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw if done[0]: break
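# utl.update_encoding() used in the training and rendering loops above is imported
# from the project's utilities and not reproduced here. The sketch below is an
# assumption about its behaviour (reset the hidden state where an episode ended,
# then push the newest transition through the recurrent encoder); the encoder call
# signature mirrors the one used in encode_running_trajectory(), everything else
# is illustrative only.
import torch


def update_encoding(encoder, next_obs, action, reward, done, hidden_state):
    if done is not None:
        # zero the hidden state of processes whose episode just ended
        # (assumes hidden_state has shape (1, num_processes, hidden_dim))
        hidden_state = hidden_state * (1 - done).view(1, -1, 1)
    with torch.no_grad():
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder(
            actions=action.float(),
            states=next_obs,
            rewards=reward,
            hidden_state=hidden_state,
            return_prior=False,
        )
    return latent_sample, latent_mean, latent_logvar, hidden_state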
def main(): args = get_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True log_dir = os.path.expanduser(args.log_dir) eval_log_dir = log_dir + "_eval" utils.cleanup_log_dir(log_dir) utils.cleanup_log_dir(eval_log_dir) torch.set_num_threads(1) device = torch.device("cuda:0" if args.cuda else "cpu") envs = make_vec_envs(args.env_name, args.seed, args.num_processes, args.gamma, args.log_dir, device, False, no_obs_norm=args.no_obs_norm) if args.multi_action_head: head_infos = envs.get_attr("head_infos")[0] autoregressive_maps = envs.get_attr("autoregressive_maps")[0] action_type_masks = torch.tensor(envs.get_attr("action_type_masks")[0], dtype=torch.float32, device=device) action_heads = MultiActionHeads(head_infos, autoregressive_maps, action_type_masks, input_dim=args.hidden_size) actor_critic = MultiHeadPolicy(envs.observation_space.shape, action_heads, use_action_masks=args.use_action_masks, base_kwargs={ 'recurrent': args.recurrent_policy, 'recurrent_type': args.recurrent_type, 'hidden_size': args.hidden_size }) else: actor_critic = Policy(envs.observation_space.shape, envs.action_space, use_action_masks=args.use_action_masks, base_kwargs={ 'recurrent': args.recurrent_policy, 'recurrent_type': args.recurrent_type, 'hidden_size': args.hidden_size }) actor_critic.to(device) agent = PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch, args.value_loss_coef, args.entropy_coef, lr=args.lr, eps=args.eps, max_grad_norm=args.max_grad_norm, recompute_returns=args.recompute_returns, use_gae=args.use_gae, gamma=args.gamma, gae_lambda=args.gae_lambda) if args.multi_action_head: action_head_info = envs.get_attr("head_infos")[0] else: action_head_info = None rollouts = RolloutStorage( args.num_steps, args.num_processes, envs.observation_space.shape, action_head_info=action_head_info, action_space=envs.action_space, recurrent_hidden_state_size=actor_critic.recurrent_hidden_state_size, multi_action_head=args.multi_action_head) obs = envs.reset() if actor_critic.use_action_masks: action_masks = envs.env_method( "get_available_actions" ) #build in zip so it returns [head_1(all_envs), head_2(all_envs), ...] 
if args.multi_action_head: action_masks = list(zip(*action_masks)) for i in range(len(rollouts.actions)): rollouts.action_masks[i][0].copy_(torch.tensor( action_masks[i])) else: rollouts.action_masks[0].copy_( torch.tensor(action_masks, dtype=torch.float32, device=device)) rollouts.obs[0].copy_(obs) rollouts.to(device) episode_rewards = deque(maxlen=10) start = time.time() num_updates = int( args.num_env_steps) // args.num_steps // args.num_processes for j in range(num_updates): if args.use_linear_lr_decay: utils.update_linear_schedule(agent.optimizer, j, num_updates, args.lr) for step in range(args.num_steps): with torch.no_grad(): if actor_critic.is_recurrent and actor_critic.base.recurrent_type == "LSTM": recurrent_hidden_state_in = ( rollouts.recurrent_hidden_states[step], rollouts.recurrent_cell_states[step]) else: recurrent_hidden_state_in = rollouts.recurrent_hidden_states[ step] if args.multi_action_head: action_masks = [ rollouts.action_masks[i][step] for i in range(len(rollouts.actions)) ] else: action_masks = rollouts.action_masks[step] value, action, action_log_prob, recurrent_hidden_states = actor_critic.act( rollouts.obs[step], recurrent_hidden_state_in, rollouts.masks[step], action_masks=action_masks) obs, reward, done, infos = envs.step(action) action_masks_info = [] for info in infos: if 'episode' in info.keys(): episode_rewards.append(info['episode']['r']) if actor_critic.use_action_masks: action_masks_info.append(info["available_actions"]) if actor_critic.use_action_masks: if args.multi_action_head: action_masks = list(zip(*action_masks_info)) for i in range(len(action_masks)): action_masks[i] = torch.tensor(action_masks[i], dtype=torch.float32, device=device) else: action_masks = torch.tensor(action_masks_info, dtype=torch.float32, device=device) else: action_masks = None masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]) rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks, action_masks=action_masks) value_loss, action_loss, dist_entropy = agent.update(rollouts) rollouts.after_update() # save for every interval-th episode or for the last epoch if (j % args.save_interval == 0 or j == num_updates - 1) and args.save_dir != "": save_path = os.path.join(args.save_dir, args.algo) try: os.makedirs(save_path) except OSError: pass torch.save([ actor_critic, getattr(utils.get_vec_normalize(envs), 'obs_rms', None) ], os.path.join(save_path, args.env_name + args.extra_id + ".pt")) if j % args.log_interval == 0 and len(episode_rewards) > 1: total_num_steps = (j + 1) * args.num_processes * args.num_steps end = time.time() print( "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n" .format(j, total_num_steps, int(total_num_steps / (end - start)), len(episode_rewards), np.mean(episode_rewards), np.median(episode_rewards), np.min(episode_rewards), np.max(episode_rewards), dist_entropy, value_loss, action_loss)) # x = tuple(actor_critic.dist.logstd._bias.squeeze().detach().cpu().numpy()) # print(("action std's: ["+', '.join(['%.2f']*len(x))+"]") % tuple([np.exp(a) for a in x])) if (args.eval_interval is not None and len(episode_rewards) > 1 and j % args.eval_interval == 0): if args.no_obs_norm == False: obs_rms = utils.get_vec_normalize(envs).obs_rms else: obs_rms = None
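# actor_critic.act() above receives per-step action masks. The Policy and
# MultiHeadPolicy classes are defined elsewhere in the repo; the snippet below is
# only a minimal illustration of the conventional way such masks are applied,
# namely by pushing the logits of unavailable actions to -inf before building the
# categorical distribution.
import torch
from torch.distributions import Categorical


def masked_categorical(logits, action_mask):
    """logits: (batch, n_actions); action_mask: (batch, n_actions), 1 = available."""
    masked_logits = logits.masked_fill(action_mask == 0, float('-inf'))
    return Categorical(logits=masked_logits)


# example: only actions 0 and 2 are available, so the sample is never 1 or 3
dist = masked_categorical(torch.zeros(1, 4), torch.tensor([[1.0, 0.0, 1.0, 0.0]]))
action = dist.sample()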
def initialise_policy(self): # initialise rollout storage for the policy self.policy_storage = OnlineStorage( self.args, self.args.policy_num_steps, self.args.num_processes, self.args.obs_dim, self.args.act_space, hidden_size=self.args.aggregator_hidden_size, latent_dim=self.args.latent_dim, normalise_observations=self.args.norm_obs_for_policy, normalise_rewards=self.args.norm_rew_for_policy, ) # initialise policy network input_dim = self.args.obs_dim * int( self.args.condition_policy_on_state) input_dim += ( 1 + int(not self.args.sample_embeddings)) * self.args.latent_dim if hasattr(self.envs.action_space, 'low'): action_low = self.envs.action_space.low action_high = self.envs.action_space.high else: action_low = action_high = None policy_net = Policy( state_dim=input_dim, action_space=self.args.act_space, init_std=self.args.policy_init_std, hidden_layers=self.args.policy_layers, activation_function=self.args.policy_activation_function, normalise_actions=self.args.normalise_actions, action_low=action_low, action_high=action_high, ).to(device) # initialise policy trainer if self.args.policy == 'a2c': self.policy = A2C( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, alpha=self.args.a2c_alpha, ) elif self.args.policy == 'ppo': self.policy = PPO( policy_net, self.args.policy_value_loss_coef, self.args.policy_entropy_coef, optimiser_vae=self.vae.optimiser_vae, lr=self.args.lr_policy, eps=self.args.policy_eps, ppo_epoch=self.args.ppo_num_epochs, num_mini_batch=self.args.ppo_num_minibatch, use_huber_loss=self.args.ppo_use_huberloss, use_clipped_value_loss=self.args.ppo_use_clipped_value_loss, clip_param=self.args.ppo_clip_param, ) else: raise NotImplementedError
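# The input_dim arithmetic above implies the policy sees the (optional) raw state
# concatenated with either the sampled latent (latent_dim inputs) or the latent
# mean and logvar (2 * latent_dim inputs). The project's utl.get_augmented_obs()
# performs that concatenation; the function below is only a sketch of the implied
# behaviour, not the project's exact implementation.
import torch


def get_augmented_obs(obs, latent_sample, latent_mean, latent_logvar,
                      condition_policy_on_state=True, sample_embeddings=False):
    parts = [obs] if condition_policy_on_state else []
    if sample_embeddings:
        parts.append(latent_sample)              # + latent_dim inputs
    else:
        parts.append(latent_mean)                # + latent_dim inputs
        parts.append(latent_logvar)              # + latent_dim inputs
    return torch.cat(parts, dim=-1)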
def main(_seed, _config, _run): args = init(_seed, _config, _run, post_config=post_config) env_name = args.env_name dummy_env = make_env(env_name, render=False) cleanup_log_dir(args.log_dir) try: os.makedirs(args.save_dir) except OSError: pass torch.set_num_threads(1) envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir) obs_shape = envs.observation_space.shape obs_shape = (obs_shape[0], *obs_shape[1:]) if args.load_saved_controller: best_model = "{}_best.pt".format(env_name) model_path = os.path.join(current_dir, "models", best_model) print("Loading model {}".format(best_model)) actor_critic = torch.load(model_path) else: if args.mirror_method == MirrorMethods.net2: controller = SymmetricNetV2( *dummy_env.unwrapped.mirror_sizes, num_layers=6, hidden_size=256, tanh_finish=True ) else: controller = SoftsignActor(dummy_env) if args.mirror_method == MirrorMethods.net: controller = SymmetricNet(controller, *dummy_env.unwrapped.sym_act_inds) actor_critic = Policy(controller) if args.sym_value_net: actor_critic.critic = SymmetricVNet( actor_critic.critic, controller.state_dim ) mirror_function = None if ( args.mirror_method == MirrorMethods.traj or args.mirror_method == MirrorMethods.loss ): indices = dummy_env.unwrapped.get_mirror_indices() mirror_function = get_mirror_function(indices) if args.cuda: actor_critic.cuda() agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params) rollouts = RolloutStorage( args.num_steps, args.num_processes, obs_shape, envs.action_space.shape[0], actor_critic.state_size, ) current_obs = torch.zeros(args.num_processes, *obs_shape) def update_current_obs(obs): shape_dim0 = envs.observation_space.shape[0] obs = torch.from_numpy(obs).float() current_obs[:, -shape_dim0:] = obs obs = envs.reset() update_current_obs(obs) rollouts.observations[0].copy_(current_obs) if args.cuda: current_obs = current_obs.cuda() rollouts.cuda() episode_rewards = deque(maxlen=args.num_processes) num_updates = int(args.num_frames) // args.num_steps // args.num_processes start = time.time() next_checkpoint = args.save_every max_ep_reward = float("-inf") logger = ConsoleCSVLogger( log_dir=args.experiment_dir, console_log_interval=args.log_interval ) for j in range(num_updates): if args.lr_decay_type == "linear": scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0) elif args.lr_decay_type == "exponential": scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5) else: scheduled_lr = args.lr set_optimizer_lr(agent.optimizer, scheduled_lr) for step in range(args.num_steps): # Sample actions with torch.no_grad(): value, action, action_log_prob, states = actor_critic.act( rollouts.observations[step], rollouts.states[step], rollouts.masks[step], ) cpu_actions = action.squeeze(1).cpu().numpy() obs, reward, done, infos = envs.step(cpu_actions) reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float() bad_masks = np.ones((args.num_processes, 1)) for p_index, info in enumerate(infos): keys = info.keys() # This information is added by algorithms.utils.TimeLimitMask if "bad_transition" in keys: bad_masks[p_index] = 0.0 # This information is added by baselines.bench.Monitor if "episode" in keys: episode_rewards.append(info["episode"]["r"]) masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]) bad_masks = torch.from_numpy(bad_masks) update_current_obs(obs) rollouts.insert( current_obs, states, action, action_log_prob, value, reward, masks, bad_masks, ) with torch.no_grad(): next_value = 
actor_critic.get_value( rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1] ).detach() rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda) value_loss, action_loss, dist_entropy = agent.update(rollouts) rollouts.after_update() frame_count = (j + 1) * args.num_steps * args.num_processes if ( frame_count >= next_checkpoint or j == num_updates - 1 ) and args.save_dir != "": model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint)) next_checkpoint += args.save_every else: model_name = "{}_latest.pt".format(env_name) # A really ugly way to save a model to CPU save_model = actor_critic if args.cuda: save_model = copy.deepcopy(actor_critic).cpu() drive=1 if drive: #print("save") torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name)) torch.save(save_model, os.path.join(args.save_dir, model_name)) if len(episode_rewards) > 1 and np.mean(episode_rewards) > max_ep_reward: model_name = "{}_best.pt".format(env_name) max_ep_reward = np.mean(episode_rewards) drive=1 if drive: #print("max_ep_reward",max_ep_reward) torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name)) torch.save(save_model, os.path.join(args.save_dir, model_name)) if len(episode_rewards) > 1: end = time.time() total_num_steps = (j + 1) * args.num_processes * args.num_steps logger.log_epoch( { "iter": j + 1, "total_num_steps": total_num_steps, "fps": int(total_num_steps / (end - start)), "entropy": dist_entropy, "value_loss": value_loss, "action_loss": action_loss, "stats": {"rew": episode_rewards}, } )
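# linear_decay, exponential_decay and set_optimizer_lr are imported from the
# project's helpers and not shown in this file; the definitions below are a minimal
# sketch consistent with how they are called in the loop above (iteration index,
# schedule parameters, an optional floor value), not the exact originals.
def linear_decay(iteration, total_iterations, initial_value, final_value=0.0):
    frac = min(iteration / total_iterations, 1.0)
    return max(initial_value + frac * (final_value - initial_value), final_value)


def exponential_decay(iteration, decay_rate, initial_value, final_value=0.0):
    return max(initial_value * decay_rate ** iteration, final_value)


def set_optimizer_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr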
def main(): # --- ARGUMENTS --- parser = argparse.ArgumentParser() parser.add_argument('--env_name', type=str, default='coinrun', help='name of the environment to train on.') parser.add_argument('--model', type=str, default='ppo', help='the model to use for training.') args, rest_args = parser.parse_known_args() env_name = args.env_name model = args.model # get arguments args = args_pretrain_aup.get_args(rest_args) # place other args back into argparse.Namespace args.env_name = env_name args.model = model # Weights & Biases logger if args.run_name is None: # make run name as {env_name}_{TIME} now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S') args.run_name = args.env_name + '_' + args.algo + now # initialise wandb wandb.init(name=args.run_name, project=args.proj_name, group=args.group_name, config=args, monitor_gym=False) # save wandb dir path args.run_dir = wandb.run.dir wandb.config.update(args) # set random seed of random, torch and numpy utl.set_global_seed(args.seed, args.deterministic_execution) # --- OBTAIN DATA FOR TRAINING R_aux --- print("Gathering data for R_aux Model.") # gather observations for pretraining the auxiliary reward function (CB-VAE) envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) # number of frames ÷ number of policy steps before update ÷ number of cpu processes num_batch = args.num_processes * args.policy_num_steps num_updates = int(args.num_frames_r_aux) // num_batch # create list to store env observations obs_data = torch.zeros(num_updates * args.policy_num_steps + 1, args.num_processes, *envs.observation_space.shape) # reset environments obs = envs.reset() # obs.shape = (n_env,C,H,W) obs_data[0].copy_(obs) obs = obs.to(device) for iter_idx in range(num_updates): # rollout policy to collect num_batch of experience and store in storage for step in range(args.policy_num_steps): # sample actions from random agent action = torch.randint(0, envs.action_space.n, (args.num_processes, 1)) # observe rewards and next obs obs, reward, done, infos = envs.step(action) # store obs obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs) # close envs envs.close() # --- TRAIN R_aux (CB-VAE) --- # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux print("Training R_aux Model.") # create dataloader for observations gathered obs_data = obs_data.reshape(-1, *envs.observation_space.shape) sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))), args.cb_vae_batch_size, drop_last=False) # initialise CB-VAE cb_vae = CBVAE(obs_shape=envs.observation_space.shape, latent_dim=args.cb_vae_latent_dim).to(device) # optimiser optimiser = torch.optim.Adam(cb_vae.parameters(), lr=args.cb_vae_learning_rate) # put CB-VAE into train mode cb_vae.train() measures = defaultdict(list) for epoch in range(args.cb_vae_epochs): print("Epoch: ", epoch) start_time = time.time() batch_loss = 0 for indices in sampler: obs = obs_data[indices].to(device) # zero accumulated gradients cb_vae.zero_grad() # forward pass through CB-VAE recon_batch, mu, log_var = cb_vae(obs) # calculate loss loss = cb_vae_loss(recon_batch, obs, mu, log_var) # backpropogation: calculating gradients loss.backward() # update parameters of generator optimiser.step() # save loss per mini-batch batch_loss += loss.item() * obs.size(0) # log losses per epoch wandb.log({ 'cb_vae/loss': batch_loss 
/ obs_data.size(0), 'cb_vae/time_taken': time.time() - start_time, 'cb_vae/epoch': epoch }) indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2) measures['true_images'].append(obs[indices].detach().cpu().numpy()) measures['recon_images'].append( recon_batch[indices].detach().cpu().numpy()) # plot ground truth images plt.rcParams.update({'font.size': 10}) fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples, figsize=(20, 20)) for i, img in enumerate(measures['true_images'][0]): axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow( img.transpose(1, 2, 0)) axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off') wandb.log({"Ground Truth Images": wandb.Image(plt)}) # plot reconstructed images fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples, figsize=(20, 20)) for i, img in enumerate(measures['recon_images'][0]): axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow( img.transpose(1, 2, 0)) axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off') wandb.log({"Reconstructed Images": wandb.Image(plt)}) # --- TRAIN Q_aux -- # train PPO agent with value head replaced with action-value head and training on R_aux instead of the environment R print("Training Q_aux Model.") # initialise environments for training Q_aux envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0, distribution_mode=args.distribution_mode, paint_vel_info=args.paint_vel_info, num_processes=args.num_processes, num_frame_stack=args.num_frame_stack, device=device) # initialise policy network actor_critic = QModel(obs_shape=envs.observation_space.shape, action_space=envs.action_space, hidden_size=args.hidden_size).to(device) # initialise policy trainer if args.algo == 'ppo': policy = PPO(actor_critic=actor_critic, ppo_epoch=args.policy_ppo_epoch, num_mini_batch=args.policy_num_mini_batch, clip_param=args.policy_clip_param, value_loss_coef=args.policy_value_loss_coef, entropy_coef=args.policy_entropy_coef, max_grad_norm=args.policy_max_grad_norm, lr=args.policy_lr, eps=args.policy_eps) else: raise NotImplementedError # initialise rollout storage for the policy rollouts = RolloutStorage(num_steps=args.policy_num_steps, num_processes=args.num_processes, obs_shape=envs.observation_space.shape, action_space=envs.action_space) # count number of frames and updates frames = 0 iter_idx = 0 update_start_time = time.time() # reset environments obs = envs.reset() # obs.shape = (n_envs,C,H,W) # insert initial observation to rollout storage rollouts.obs[0].copy_(obs) rollouts.to(device) # initialise buffer for calculating mean episodic returns episode_info_buf = deque(maxlen=10) # calculate number of updates # number of frames ÷ number of policy steps before update ÷ number of cpu processes args.num_batch = args.num_processes * args.policy_num_steps args.num_updates = int(args.num_frames_q_aux) // args.num_batch print("Number of updates: ", args.num_updates) for iter_idx in range(args.num_updates): print("Iter: ", iter_idx) # put actor-critic into train mode actor_critic.train() # rollout policy to collect num_batch of experience and store in storage for step in range(args.policy_num_steps): with torch.no_grad(): # sample actions from policy value, action, action_log_prob = actor_critic.act( rollouts.obs[step]) # obtain reward R_aux from encoder of CB-VAE r_aux, _, _ = cb_vae.encode(rollouts.obs[step]) # observe rewards and next obs obs, _, done, infos = envs.step(action) # log episode info if 
episode finished for i, info in enumerate(infos): if 'episode' in info.keys(): episode_info_buf.append(info['episode']) # create mask for episode ends masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # add experience to policy buffer rollouts.insert(obs, r_aux, action, value, action_log_prob, masks) frames += args.num_processes # --- UPDATE --- # bootstrap next value prediction with torch.no_grad(): next_value = actor_critic.get_value(rollouts.obs[-1]).detach() # compute returns for current rollouts rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda) # update actor-critic using policy gradient algo total_loss, value_loss, action_loss, dist_entropy = policy.update( rollouts) # clean up after update rollouts.after_update() # --- LOGGING --- if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1: # get stats for run update_end_time = time.time() num_interval_updates = 1 if iter_idx == 0 else args.log_interval fps = num_interval_updates * ( args.num_processes * args.policy_num_steps) / (update_end_time - update_start_time) update_start_time = update_end_time # Calculates if value function is a good predicator of the returns (ev > 1) # or if it's just worse than predicting nothing (ev =< 0) ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds), utl.sf01(rollouts.returns)) wandb.log({ 'q_aux_misc/timesteps': frames, 'q_aux_misc/fps': fps, 'q_aux_misc/explained_variance': float(ev), 'q_aux_losses/total_loss': total_loss, 'q_aux_losses/value_loss': value_loss, 'q_aux_losses/action_loss': action_loss, 'q_aux_losses/dist_entropy': dist_entropy, 'q_aux_train/mean_episodic_return': utl_math.safe_mean( [episode_info['r'] for episode_info in episode_info_buf]), 'q_aux_train/mean_episodic_length': utl_math.safe_mean( [episode_info['l'] for episode_info in episode_info_buf]) }) # close envs envs.close() # --- SAVE MODEL --- print("Saving Q_aux Model.") torch.save(actor_critic.state_dict(), args.q_aux_path)
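# cb_vae_loss() used when training the CB-VAE above is not defined in this file.
# The function below is an assumption about that objective: a continuous-Bernoulli
# reconstruction log-likelihood plus a KL term against a unit Gaussian prior, which
# matches the call signature (recon_batch, obs, mu, log_var) but is a sketch rather
# than the project's exact loss.
import torch
from torch.distributions import ContinuousBernoulli


def cb_vae_loss(recon_batch, obs, mu, log_var):
    """recon_batch: decoder outputs in (0, 1); obs: targets assumed scaled to [0, 1]."""
    recon_ll = ContinuousBernoulli(probs=recon_batch).log_prob(obs).sum(dim=(1, 2, 3))
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return (kl - recon_ll).mean()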
def main(): # make the environments if args.num_envs == 1: env = [gym.make(args.env_name)] else: env = [gym.make(args.env_name) for i in range(args.num_envs)] env = MultiGym(env, render=args.render) n_states = env.observation_space.shape n_actions = env.action_space.n print('state shape:', n_states, 'actions:', n_actions) policy = ConvPolicy(n_actions).to(device) optimizer = optim.RMSprop(policy.parameters(), lr=args.lr) if args.algo == 'ppo': sys.path.append('../') from algorithms.ppo import PPO update_algo = PPO(policy=policy, optimizer=optimizer, num_steps=args.num_steps, num_envs=args.num_envs, state_size=(4, 105, 80), entropy_coef=args.entropy, gamma=args.gamma, device=device, epochs=args.ppo_epochs) else: sys.path.append('../') from algorithms.a2c import A2C update_algo = A2C(policy=policy, optimizer=optimizer, num_steps=args.num_steps, num_envs=args.num_envs, state_size=(4, 105, 80), entropy_coef=args.entropy, gamma=args.gamma, device=device) end_rewards = [] try: print('starting episodes') idx = 0 d = False reward_sum = np.zeros((args.num_envs)) restart = True frame = env.reset() mask = torch.ones(args.num_envs) all_start = time.time() for update_idx in range(args.num_updates): update_algo.policy.train() # stack the frames s = train_state_proc.proc_state(frame, mask=mask) # insert state before getting actions update_algo.states[0].copy_(s) start = time.time() for step in range(args.num_steps): with torch.no_grad(): # get probability dist and values p, v = update_algo.policy(update_algo.states[step]) a = Categorical(p).sample() # take action get response frame, r, d = env.step( a.cpu().numpy() if args.num_envs > 1 else [a.item()]) s = train_state_proc.proc_state(frame, mask) update_algo.insert_experience(step=step, s=s, a=a, v=v, r=r, d=d) mask = torch.tensor(1. - d).float() reward_sum = (reward_sum + r) # if any episode finished append episode reward to list if d.any(): end_rewards.extend(reward_sum[d]) # reset any rewards that finished reward_sum = reward_sum * mask.numpy() idx += 1 with torch.no_grad(): _, next_val = update_algo.policy(update_algo.states[-1]) update_algo.update(next_val.view(1, args.num_envs).to(device), next_mask=mask.to(device)) if args.lr_decay: for params in update_algo.optimizer.param_groups: params['lr'] = ( lr_min + 0.5 * (args.lr - lr_min) * (1 + np.cos(np.pi * idx / args.num_updates))) # update every so often by displaying results in term if (update_idx % args.log_interval == 0) and (len(end_rewards) > 0): total_steps = (idx + 1) * args.num_envs * args.num_steps end = time.time() print(end_rewards[-10:]) print('Updates {}\t Time: {:.4f} \t FPS: {}'.format( update_idx, end - start, int(total_steps / (end - all_start)))) print( 'Mean Episode Rewards: {:.2f} \t Min/Max Current Rewards: {}/{}' .format(np.mean(end_rewards[-10:]), reward_sum.min(), reward_sum.max())) except KeyboardInterrupt: pass torch.save( update_algo.policy.state_dict(), '../model_weights/{}_{}_conv.pth'.format(args.env_name, args.algo)) import pandas as pd out_dict = {'avg_end_rewards': end_rewards} out_log = pd.DataFrame(out_dict) out_log.to_csv('../logs/{}_{}_rewards.csv'.format(args.env_name, args.algo), index=False) out_dict = { 'actor losses': update_algo.actor_losses, 'critic losses': update_algo.critic_losses, 'entropy': update_algo.entropy_logs } out_log = pd.DataFrame(out_dict) out_log.to_csv('../logs/{}_{}_training_behavior.csv'.format( args.env_name, args.algo), index=False) plt.plot(end_rewards) plt.show()
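# train_state_proc.proc_state() referenced in the rollout loop above comes from the
# project's state-processing utilities and is not shown here. The class below is a
# rough sketch with the same call pattern: raw frames in, a (num_envs, 4, 105, 80)
# float tensor out, with the stack zeroed wherever the episode mask is 0; names and
# shapes are assumptions for illustration.
import numpy as np
import torch


class StateProcessor:
    def __init__(self, num_envs, num_stack=4, height=105, width=80):
        self.frames = torch.zeros(num_envs, num_stack, height, width)

    def proc_state(self, frame, mask=None):
        """frame: array of shape (num_envs, H, W); mask: 0 where an episode ended."""
        if mask is not None:
            self.frames *= mask.view(-1, 1, 1, 1)
        new = torch.from_numpy(np.asarray(frame, dtype=np.float32)) / 255.0
        self.frames = torch.cat([self.frames[:, 1:], new.unsqueeze(1)], dim=1)
        return self.frames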
def model_fn(obs): x = tf.layers.conv2d(obs, 32, 8, 4, activation=tf.nn.relu) x = tf.layers.conv2d(x, 64, 4, 2, activation=tf.nn.relu) x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu) x = tf.contrib.layers.flatten(x) x = tf.layers.dense(x, 512, activation=tf.nn.relu) logit_action_probability = tf.layers.dense( x, action_space, kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01)) state_value = tf.squeeze(tf.layers.dense( x, 1, kernel_initializer=tf.truncated_normal_initializer())) return logit_action_probability, state_value ppo = PPO(action_space, obs_fn, model_fn, train_epoch=5, batch_size=32) env = Raiden2(6666, num_envs=8, with_stack=False) env_ids, states, rewards, dones = env.start() env_states = defaultdict(partial(deque, maxlen=frame_stack)) nth_trajectory = 0 while True: nth_trajectory += 1 for _ in tqdm(range(explore_steps)): sts = [] for env_id, state in zip(env_ids, states): st = np.zeros((size, size, frame_stack), dtype=np.float32) im = cv2.resize(rgb2gray(state), (size, size)) env_states[env_id].append(im) while len(env_states[env_id]) < frame_stack:
n_actions = env.action_space.n print('states:', n_states, 'actions:', n_actions) policy = GRUPolicy(n_states[0], n_actions, args.hid_size, args.num_steps, args.num_envs).to(device) optimizer = optim.RMSprop(policy.parameters(), lr=args.lr, eps=1e-5) if args.algo == 'ppo': sys.path.append('../') from algorithms.ppo import PPO update_algo = PPO(policy=policy, optimizer=optimizer, num_steps=args.num_steps, num_envs=args.num_envs, state_size=n_states, entropy_coef=args.entropy, gamma=args.gamma, device=device, recurrent=True, rnn_size=args.hid_size, epochs=args.ppo_epochs, batch_size=args.batch_size) else: sys.path.append('../') from algorithms.a2c import A2C update_algo = A2C(policy=policy, optimizer=optimizer, num_steps=args.num_steps, num_envs=args.num_envs, state_size=n_states, entropy_coef=args.entropy, gamma=args.gamma,