def fit(self, paths, targvals):
    X = np.concatenate([self._preproc(p) for p in paths])
    y = np.concatenate(targvals)
    logger.record_tabular("EVBefore", explained_variance(self._predict(X), y))
    for _ in range(25):
        self.do_update(X, y)
    logger.record_tabular("EVAfter", explained_variance(self._predict(X), y))
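# --- helper sketch (not part of the original listing) ---
# explained_variance() is used throughout these training loops but never defined
# here. A minimal sketch, assuming the usual 1 - Var[y - ypred] / Var[y]
# definition from OpenAI Baselines; the actual project may import it from a
# utilities module with a slightly different signature.
import numpy as np

def explained_variance(ypred, y):
    """Fraction of the variance of y explained by ypred.
    Returns ~1 for a perfect fit, 0 when no better than predicting zero,
    and a negative value when the prediction is worse than that."""
    assert y.ndim == 1 and ypred.ndim == 1
    vary = np.var(y)
    return np.nan if vary == 0 else 1 - np.var(y - ypred) / vary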
def main():
    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name, project=args.proj_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")
    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create tensor to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)

    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n, (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)

    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")
    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size, drop_last=False)
    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(), lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: calculating gradients
            loss.backward()
            # update parameters of generator
            optimiser.step()
            # save loss per mini-batch
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })

    # keep a sample of ground-truth and reconstructed images from the last mini-batch
    indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
    measures['true_images'].append(obs[indices].detach().cpu().numpy())
    measures['recon_images'].append(recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})
    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train PPO agent with value head replaced with action-value head and
    # training on R_aux instead of the environment R
    print("Training Q_aux Model.")
    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)
    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)
    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError
    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0
    update_start_time = time.time()

    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)

    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)
        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):
            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])
            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)
            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)
            frames += args.num_processes

        # --- UPDATE ---
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()
        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)
        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)
        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time
            # check whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))
            wandb.log({
                'q_aux_misc/timesteps': frames,
                'q_aux_misc/fps': fps,
                'q_aux_misc/explained_variance': float(ev),
                'q_aux_losses/total_loss': total_loss,
                'q_aux_losses/value_loss': value_loss,
                'q_aux_losses/action_loss': action_loss,
                'q_aux_losses/dist_entropy': dist_entropy,
                'q_aux_train/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
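# --- helper sketch (not part of the original listing) ---
# cb_vae_loss() is called during CB-VAE training above but not defined here.
# A minimal sketch, assuming a standard VAE objective: a pixel-wise binary
# cross-entropy reconstruction term plus a KL term against a unit Gaussian
# prior. A true continuous-Bernoulli VAE would add the CB log-normalising
# constant to the reconstruction term; the project's actual loss may differ.
import torch
import torch.nn.functional as F

def cb_vae_loss(recon_x, x, mu, log_var):
    # reconstruction term: sum over pixels, averaged over the batch
    # (assumes inputs are scaled to [0, 1])
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum') / x.size(0)
    # KL( q(z|x) || N(0, I) ), averaged over the batch
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
    return recon + kl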
def fit(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5,
        ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear',
        epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
    set_global_seeds(seed)

    model = A2C(policy=policy, observation_space=env.observation_space,
                action_space=env.action_space, nenvs=env.num_envs, nsteps=nsteps,
                ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
                lr=lr, alpha=alpha, epsilon=epsilon,
                total_timesteps=total_timesteps, lrschedule=lrschedule)
    session = model.init_session()
    tf.global_variables_initializer().run(session=session)

    env_runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = env.num_envs * nsteps
    tstart = time.time()
    writer = tf.summary.FileWriter('output', session.graph)

    for update in range(1, total_timesteps // nbatch + 1):
        tf.reset_default_graph()
        obs, states, rewards, masks, actions, values = env_runner.run(session)
        policy_loss, value_loss, policy_entropy = model.predict(
            observations=obs, states=states, rewards=rewards, masks=masks,
            actions=actions, values=values, session=session)
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()

    env.close()
    writer.close()
    session.close()
def fit(policy, env, nsteps, total_timesteps, ent_coef, lr, vf_coef=0.5,
        max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10,
        nminibatches=4, noptepochs=4, cliprange=0.2, save_interval=0,
        load_path=None):
    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs  # nenvs = 8
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    model = PPO2(policy=policy, observation_space=ob_space, action_space=ac_space,
                 nbatch_act=nenvs, nbatch_train=nbatch_train, nsteps=nsteps,
                 ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm)
    Agent().init_vars()
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()
    # if load_path is not None:
    #     model.load(load_path)

    runner = Environment(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps // nbatch
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in
                              (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in
                              (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) \
                and logger.get_dir():
            checkdir = os.path.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = os.path.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    env.close()
    return model
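# --- helper sketch (not part of the original listing) ---
# constfn() and safemean() are used by the PPO fit() above but not shown here.
# A minimal sketch, assuming they match the usual Baselines-style definitions:
import numpy as np

def constfn(val):
    """Wrap a constant so it can be called like a schedule f(frac) -> value."""
    def f(_):
        return val
    return f

def safemean(xs):
    """Mean that returns nan (instead of raising) for an empty buffer."""
    return np.nan if len(xs) == 0 else np.mean(xs)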
def rollouts(self):
    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = self.traj_segment_generator(self.pi, self.env,
                                          self.timesteps_per_actorbatch,
                                          stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

    assert sum([self.max_iters > 0, self.max_timesteps > 0,
                self.max_episodes > 0, self.max_seconds > 0]) == 1, \
        "Only one time constraint permitted"

    while True:
        if self.callback:
            self.callback(locals(), globals())
        if self.max_timesteps and timesteps_so_far >= self.max_timesteps:
            break
        elif self.max_episodes and episodes_so_far >= self.max_episodes:
            break
        elif self.max_iters and iters_so_far >= self.max_iters:
            break
        elif self.max_seconds and time.time() - tstart >= self.max_seconds:
            break

        if self.schedule == 'constant':
            cur_lrmult = 1.0
        elif self.schedule == 'linear':
            cur_lrmult = max(1.0 - float(timesteps_so_far) / self.max_timesteps, 0)
        else:
            raise NotImplementedError

        logger.log("********** Iteration %i ************" % iters_so_far)

        seg = seg_gen.__next__()
        self.add_vtarg_and_adv(seg, self.gamma, self.lam)

        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate
        d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                    shuffle=not self.pi.recurrent)
        optim_batchsize = self.optim_batchsize or ob.shape[0]

        if hasattr(self.pi, "ob_rms"):
            self.pi.ob_rms.update(ob)  # update running mean/std for policy

        self.assign_old_eq_new()  # set old parameter values to new parameter values
        logger.log("Optimizing...")
        logger.log(fmt_row(13, self.loss_names))
        # Here we do a bunch of optimization epochs over the data
        for _ in range(self.optim_epochs):
            losses = []  # list of tuples, each of which gives the loss for a minibatch
            for batch in d.iterate_once(optim_batchsize):
                *newlosses, g = self.lossandgrad(batch["ob"], batch["ac"],
                                                 batch["atarg"], batch["vtarg"],
                                                 cur_lrmult)
                self.adam.update(g, self.optim_stepsize * cur_lrmult)
                losses.append(newlosses)
            logger.log(fmt_row(13, np.mean(losses, axis=0)))

        logger.log("Evaluating losses...")
        losses = []
        for batch in d.iterate_once(optim_batchsize):
            newlosses = self.loss(batch["ob"], batch["ac"], batch["atarg"],
                                  batch["vtarg"], cur_lrmult)
            losses.append(newlosses)
        meanlosses, _, _ = mpi_moments(losses, axis=0)
        logger.log(fmt_row(13, meanlosses))
        for (lossval, name) in zipsame(meanlosses, self.loss_names):
            logger.record_tabular("loss_" + name, lossval)
        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(self.flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.dump_tabular()

    return self.pi
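# --- helper sketch (not part of the original listing) ---
# add_vtarg_and_adv(seg, gamma, lam) above fills in seg["adv"] and
# seg["tdlamret"]. A minimal standalone sketch of that GAE(lambda) computation,
# assuming the Baselines-style segment layout with keys "rew", "vpred",
# "new" (episode-start flags) and "nextvpred"; the project's method may differ
# in details.
import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    new = np.append(seg["new"], 0)  # last element only used for bootstrapping
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, dtype=np.float32)
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    # TD(lambda) return target for the value function
    seg["tdlamret"] = seg["adv"] + seg["vpred"]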
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1.'
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---
    print("Setting up wandb logging.")
    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)
    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError
    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [QModel(obs_shape=train_envs.observation_space.shape,
                        action_space=train_envs.action_space,
                        hidden_size=args.hidden_size).to(device)
                 for _ in range(args.num_q_aux)]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get max number of q_aux functions to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define AUP coefficient
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) / args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)
        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):
            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(rollouts.obs[step])
            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(dim=1),
                            dim=1)
                        # calculate the penalty
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the scaled AUP penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward
                # log the intrinsic reward from the first env
                aup_measures['intrinsic_reward'].append(aup_coef * intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)
            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)
            frames += args.num_processes

        # increase aup coefficient after every update
        # (geometric schedule from aup_coef_start to aup_coef_end)
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()
        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)
        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)
        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time
            # check whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log({
                        'aup/intrinsic_reward': aup_measures['intrinsic_reward'][i],
                        'aup/episode_complete': aup_measures['episode_complete'][i]
                    }, step=step)
                    step += args.num_processes

            wandb.log({
                'misc/timesteps': frames,
                'misc/fps': fps,
                'misc/explained_variance': float(ev),
                'losses/total_loss': total_loss,
                'losses/value_loss': value_loss,
                'losses/action_loss': action_loss,
                'losses/dist_entropy': dist_entropy,
                'train/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'train/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf]),
                'eval/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in eval_episode_info_buf]),
                'eval/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in eval_episode_info_buf])
            }, step=frames)

        # --- SAVE MODEL ---
        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args, actor_critic=actor_critic,
                                        device=device)
        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
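# --- helper sketch (not part of the original listing) ---
# utl.sf01() flattens the (time, env) leading axes of the rollout tensors
# before explained variance is computed above. A minimal sketch, assuming the
# usual "swap and flatten axes 0 and 1" utility operating on torch tensors;
# the project's actual helper may work on numpy arrays instead.
def sf01(arr):
    """Swap axes 0 and 1 and merge them into a single batch axis."""
    s = arr.shape
    return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:])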
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = args.num_processes - args.num_val_envs \
        if args.num_val_envs > 0 else args.num_processes

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1.'
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')
    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --args.num_val_envs>0 then you must also have'
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs>0 and --use_dist_matching=False then you must also have'
            '--dist_matching_coef=0.')
    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have'
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have'
                         '--use_bottleneck=True.')

    # --- TRAINING ---
    print("Setting up wandb logging.")
    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels * args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name, start_level=0, num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = \
            make_rep_analysis_envs(args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)
    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all")  # to log gradients of actor-critic network

    update_start_time = time.time()
    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(), val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
    # initialise buffers for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch
    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)
        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):
            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(rollouts.obs[step])
            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)
            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)
            frames += args.num_processes

        # --- UPDATE ---
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()
        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)
        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = \
            policy.update(rollouts)
        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)
            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(args=args,
                                                   train1_envs=analyse_rep_train1_envs,
                                                   train2_envs=analyse_rep_train2_envs,
                                                   val_envs=analyse_rep_val_envs,
                                                   test_envs=analyse_rep_test_envs,
                                                   actor_critic=actor_critic,
                                                   device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time
            # check whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'misc/timesteps': frames,
                'misc/fps': fps,
                'misc/explained_variance': float(ev),
                'losses/total_loss': total_loss,
                'losses/value_loss': value_loss,
                'losses/action_loss': action_loss,
                'losses/dist_entropy': dist_entropy,
                'train/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in train_episode_info_buf]),
                'train/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in train_episode_info_buf]),
                'eval/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in eval_episode_info_buf]),
                'eval/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in eval_episode_info_buf])
            }, step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log({
                    'losses/dist_matching_loss': dist_matching_loss,
                    'val/mean_episodic_return': utl_math.safe_mean(
                        [episode_info['r'] for episode_info in val_episode_info_buf]),
                    'val/mean_episodic_length': utl_math.safe_mean(
                        [episode_info['l'] for episode_info in val_episode_info_buf])
                }, step=iter_idx)
            if args.analyse_rep:
                wandb.log({"analysis/" + key: val for key, val in rep_measures.items()},
                          step=iter_idx)

        # --- SAVE MODEL ---
        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)
        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
def fit(model, env, timesteps_per_batch,  # what to train on
        max_kl, cg_iters, gamma, lam,  # advantage estimation
        entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3,
        max_timesteps=0, max_episodes=0, max_iters=0,  # time constraint
        callback=None):
    # Setup losses and stuff
    # ----------------------------------------
    # nworkers = MPI.COMM_WORLD.Get_size()
    # rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    th_init = model.get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    model.set_from_flat(th_init)
    model.vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = model.traj_segment_generator(model.pi, env, timesteps_per_batch,
                                           stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with model.timed("sampling"):
            seg = seg_gen.__next__()
        model.add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(model.pi, "ret_rms"):
            model.pi.ret_rms.update(tdlamret)
        if hasattr(model.pi, "ob_rms"):
            model.pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            return model.allmean(model.compute_fvp(p, *fvpargs)) + cg_damping * p

        model.assign_old_eq_new()  # set old parameter values to new parameter values

        with model.timed("computegrad"):
            *lossbefore, g = model.compute_lossandgrad(*args)
        lossbefore = model.allmean(np.array(lossbefore))
        g = model.allmean(g)

        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with model.timed("cg"):
                stepdir = conjugate_gradient(fisher_vector_product, g,
                                             cg_iters=cg_iters,
                                             verbose=model.rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = model.get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                model.set_from_flat(thnew)
                meanlosses = surr, kl, *_ = model.allmean(
                    np.array(model.compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                model.set_from_flat(thbefore)
            if model.nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(), model.vfadam.getflat().sum()))  # list of tuples
                assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(model.loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with model.timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                        (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False, batch_size=64):
                    g = model.allmean(model.compute_vflossandgrad(mbob, mbret))
                    model.vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(model.flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)
        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1
        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)
        if model.rank == 0:
            logger.dump_tabular()
def fit(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1,
        nprocs=32, nsteps=20, ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0,
        lr=0.25, max_grad_norm=0.5, kfac_clip=0.001, save_interval=None,
        lrschedule='linear'):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    model = AcktrDiscrete(policy, ob_space, ac_space, nenvs, total_timesteps,
                          nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_fisher_coef,
                          lr=lr, max_grad_norm=max_grad_norm, kfac_clip=kfac_clip,
                          lrschedule=lrschedule)
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()

    runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)

    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(
            obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()
        if save_interval and (update % save_interval == 0 or update == 1) \
                and logger.get_dir():
            savepath = os.path.join(logger.get_dir(), 'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)

    coord.request_stop()
    coord.join(enqueue_threads)
    env.close()