def train(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    log_file = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                             args.seed)

    venv = ProcgenEnv(num_envs=args.num_processes,
                      env_name=args.env_name,
                      num_levels=args.num_levels,
                      start_level=args.start_level,
                      distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape

    ################################
    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)
    ################################

    rollouts = RolloutStorage(args.num_steps,
                              args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              aug_type=args.aug_type,
                              split_ratio=args.split_ratio)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    ################################
    if args.use_ucb:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)

    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size)
                    for t in list(aug_to_func.keys())]

        aug_model = AugCNN()
        aug_model.to(device)

        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)

    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={
                                 'recurrent': True,
                                 'hidden_size': args.rl2_hidden_size
                             })
        rl2_learner.to(device)

        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)

    else:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    checkpoint_path = os.path.join(args.save_dir, "agent" + log_file + ".pt")
    if os.path.exists(checkpoint_path) and args.preempt:
        checkpoint = torch.load(checkpoint_path)
        agent.actor_critic.load_state_dict(checkpoint['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch'] + 1
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file + "-e%s" % init_epoch)
    else:
        init_epoch = 0
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout'],
                         log_suffix=log_file)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)  # store the initial observation in the buffer
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    # args.num_steps: 256 (number of forward steps per rollout)
    # args.num_env_steps: 25e6 (total number of environment steps to train)
    num_updates = int(
        args.num_env_steps) // args.num_processes // args.num_steps

    # NOTE: each update is one fixed-length rollout; individual episodes may
    # terminate at different steps inside it (the masks handle the resets).
    for j in range(init_epoch, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            # TODO: check the shapes of obs, reward, done, infos
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    # info['episode']['r'] is the raw episodic return recorded
                    # by VecMonitor; `reward` is the per-step (normalized)
                    # reward returned by VecNormalize.
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            # next_value bootstraps the value of the last observation for the
            # return computation below.
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        # GAE(lambda) returns; see the illustrative sketch after this function.
        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:  # from the second update onward
            agent.update_ucb_values(rollouts)  # update the UCB statistics
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # Log every log_interval-th update
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            print(
                "\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards)))

            logger.logkv("train/nupdates", j)
            logger.logkv("train/total_num_steps", total_num_steps)

            logger.logkv("losses/dist_entropy", dist_entropy)
            logger.logkv("losses/value_loss", value_loss)
            logger.logkv("losses/action_loss", action_loss)

            logger.logkv("train/mean_episode_reward", np.mean(episode_rewards))
            logger.logkv("train/median_episode_reward",
                         np.median(episode_rewards))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args, actor_critic, device,
                                            aug_id=aug_id)

            logger.logkv("test/mean_episode_reward",
                         np.mean(eval_episode_rewards))
            logger.logkv("test/median_episode_reward",
                         np.median(eval_episode_rewards))

            logger.dumpkvs()

        # Save the model every save_interval-th update and on the last update
        if (j > 0 and j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            try:
                os.makedirs(args.save_dir)
            except OSError:
                pass

            torch.save(
                {
                    'epoch': j,
                    'model_state_dict': agent.actor_critic.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict(),
                }, os.path.join(args.save_dir, "agent" + log_file + ".pt"))
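################################
# Illustrative sketch only (not the repo's actual RolloutStorage code):
# rollouts.compute_returns(next_value, gamma, gae_lambda) above is assumed to
# compute GAE(lambda) returns roughly as follows, bootstrapping from the value
# of the last observation. The names rewards / value_preds / masks / returns
# mirror the storage fields used in train(); treat the exact implementation as
# an assumption.
import torch


def compute_gae_returns(rewards, value_preds, masks, next_value, gamma,
                        gae_lambda):
    """rewards: [T, N, 1]; value_preds and masks: [T + 1, N, 1]."""
    num_steps = rewards.size(0)
    value_preds = value_preds.clone()
    value_preds[-1] = next_value
    returns = torch.zeros_like(value_preds)
    gae = 0
    for step in reversed(range(num_steps)):
        # TD error, masked so that terminated environments do not bootstrap.
        delta = (rewards[step] +
                 gamma * value_preds[step + 1] * masks[step + 1] -
                 value_preds[step])
        gae = delta + gamma * gae_lambda * masks[step + 1] * gae
        returns[step] = gae + value_preds[step]
    return returns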
def train(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    log_file = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                             args.seed)
    logger.configure(dir=args.log_dir,
                     format_strs=['csv', 'stdout'],
                     log_suffix=log_file)

    venv = ProcgenEnv(num_envs=args.num_processes,
                      env_name=args.env_name,
                      num_levels=args.num_levels,
                      start_level=args.start_level,
                      distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape

    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)

    # NOTE: `modelbased` is assumed to be a module-level toggle that matches
    # the agent branch selected further below (the model-based branch is the
    # default `else` case there).
    if modelbased:
        rollouts = BiggerRolloutStorage(
            args.num_steps,
            args.num_processes,
            envs.observation_space.shape,
            envs.action_space,
            actor_critic.recurrent_hidden_state_size,
            aug_type=args.aug_type,
            split_ratio=args.split_ratio)
    else:
        rollouts = RolloutStorage(args.num_steps,
                                  args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  actor_critic.recurrent_hidden_state_size,
                                  aug_type=args.aug_type,
                                  split_ratio=args.split_ratio)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    if args.use_ucb:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)

    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size)
                    for t in list(aug_to_func.keys())]

        aug_model = AugCNN()
        aug_model.to(device)

        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)

    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={
                                 'recurrent': True,
                                 'hidden_size': args.rl2_hidden_size
                             })
        rl2_learner.to(device)

        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)

    elif False:  # Regular DrAC (disabled)
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    elif False:  # Model-free planning DrAC (disabled)
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        actor_critic = PlanningPolicy(obs_shape,
                                      envs.action_space.n,
                                      base_kwargs={
                                          'recurrent': False,
                                          'hidden_size': args.hidden_size
                                      })
        actor_critic.to(device)

        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name)

    else:  # Model-based DrAC
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        actor_critic = ModelBasedPolicy(obs_shape,
                                        envs.action_space.n,
                                        base_kwargs={
                                            'recurrent': False,
                                            'hidden_size': args.hidden_size
                                        })
        actor_critic.to(device)

        agent = algo.ConvDrAC(actor_critic,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_func=aug_func,
                              aug_coef=args.aug_coef,
                              env_name=args.env_name)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    if modelbased:
        rollouts.next_obs[0].copy_(obs)  # TODO: is this right?
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in trange(num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:
            # See the UCB selection sketch after this function.
            agent.update_ucb_values(rollouts)

        if isinstance(agent, algo.ConvDrAC):
            value_loss, action_loss, dist_entropy, transition_model_loss, reward_model_loss = agent.update(
                rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # Log every log_interval-th update
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            print(
                "\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards)))

            logger.logkv("train/nupdates", j)
            logger.logkv("train/total_num_steps", total_num_steps)

            logger.logkv("losses/dist_entropy", dist_entropy)
            logger.logkv("losses/value_loss", value_loss)
            logger.logkv("losses/action_loss", action_loss)
            if isinstance(agent, algo.ConvDrAC):
                logger.logkv("losses/transition_model_loss",
                             transition_model_loss)
                logger.logkv("losses/reward_model_loss", reward_model_loss)

            logger.logkv("train/mean_episode_reward", np.mean(episode_rewards))
            logger.logkv("train/median_episode_reward",
                         np.median(episode_rewards))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args, actor_critic, device,
                                            aug_id=aug_id)

            logger.logkv("test/mean_episode_reward",
                         np.mean(eval_episode_rewards))
            logger.logkv("test/median_episode_reward",
                         np.median(eval_episode_rewards))

            logger.dumpkvs()
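################################
# Illustrative sketch of the selection rule that algo.UCBDrAC is assumed to
# apply when agent.update_ucb_values(rollouts) is called above: each
# augmentation keeps a sliding window of recent mean returns, and the next
# augmentation maximizes mean return plus an exploration bonus. The class and
# field names here (UCBAugSelector, windows, counts) are hypothetical; the
# real implementation lives in the repo's algo module.
import math
from collections import deque


class UCBAugSelector:
    def __init__(self, num_aug_types, exploration_coef, window_length):
        self.c = exploration_coef
        self.windows = [deque(maxlen=window_length)
                        for _ in range(num_aug_types)]
        self.counts = [0] * num_aug_types  # times each augmentation was chosen
        self.total = 0                     # total number of selections

    def update(self, aug_index, mean_return):
        # Record the mean rollout return obtained while training with aug_index.
        self.windows[aug_index].append(mean_return)
        self.counts[aug_index] += 1
        self.total += 1

    def select(self):
        # Pick the augmentation with the highest upper confidence bound.
        scores = []
        for i, window in enumerate(self.windows):
            q = sum(window) / len(window) if window else 0.0
            bonus = self.c * math.sqrt(
                math.log(self.total + 1) / (self.counts[i] + 1))
            scores.append(q + bonus)
        return scores.index(max(scores))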
    args.num_processes = 1
    venv = ProcgenEnv(num_envs=args.num_processes,
                      env_name=args.env_name,
                      num_levels=args.num_levels,
                      start_level=args.start_level,
                      distribution_mode=args.distribution_mode,
                      render_mode="rgb_array")
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape
    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)

    aug_id = data_augs.Identity
    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)
    aug_list = [
        aug_to_func[t](batch_size=batch_size)
        for t in list(aug_to_func.keys())
    ]

    # (A minimal rollout loop over this single-process env is sketched below.)
    log_file = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                             args.seed)
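################################
# Illustrative sketch only: the fragment above builds a single-process Procgen
# env with render_mode="rgb_array", presumably for rendering or evaluation. A
# minimal rollout loop over it could look like the following; eval_episodes
# and the function name are assumptions, and the repo's own evaluate() is what
# the training scripts actually call. act() is assumed to return four values
# as in the training loops above.
import torch


def rollout_episodes(envs, actor_critic, aug_id, device, eval_episodes=10):
    episode_rewards = []
    obs = envs.reset()
    # num_processes is 1 here, so the batch dimension is 1.
    hidden = torch.zeros(1, actor_critic.recurrent_hidden_state_size,
                         device=device)
    masks = torch.ones(1, 1, device=device)
    while len(episode_rewards) < eval_episodes:
        with torch.no_grad():
            _, action, _, hidden = actor_critic.act(aug_id(obs), hidden, masks)
        obs, _, done, infos = envs.step(action)
        masks = torch.tensor([[0.0] if d else [1.0] for d in done],
                             dtype=torch.float32, device=device)
        for info in infos:
            if 'episode' in info.keys():
                # Raw episodic return recorded by VecMonitor.
                episode_rewards.append(info['episode']['r'])
    return episode_rewards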
def train(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print('Using CUDA: {}'.format(args.cuda))
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = args.log_dir
    if not log_dir.startswith('gs://'):
        log_dir = os.path.expanduser(args.log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if not args.preempt:
        utils.cleanup_log_dir(log_dir)
    try:
        gfile.makedirs(log_dir)
    except:
        pass

    log_file = '-{}-{}-reproduce-s{}'.format(args.run_name, args.env_name,
                                             args.seed)
    save_dir = os.path.join(log_dir, 'checkpoints', log_file)
    gfile.makedirs(save_dir)

    venv = ProcgenEnv(num_envs=args.num_processes,
                      env_name=args.env_name,
                      num_levels=args.num_levels,
                      start_level=args.start_level,
                      distribution_mode=args.distribution_mode)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    envs = VecPyTorchProcgen(venv, device)

    obs_shape = envs.observation_space.shape

    actor_critic = Policy(obs_shape,
                          envs.action_space.n,
                          base_kwargs={
                              'recurrent': False,
                              'hidden_size': args.hidden_size
                          })
    actor_critic.to(device)

    rollouts = RolloutStorage(args.num_steps,
                              args.num_processes,
                              envs.observation_space.shape,
                              envs.action_space,
                              actor_critic.recurrent_hidden_state_size,
                              aug_type=args.aug_type,
                              split_ratio=args.split_ratio,
                              store_policy=args.use_pse)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    if args.use_ucb:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        agent = algo.UCBDrAC(actor_critic,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             ucb_exploration_coef=args.ucb_exploration_coef,
                             ucb_window_length=args.ucb_window_length)

    elif args.use_meta_learning:
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size)
                    for t in list(aug_to_func.keys())]

        aug_model = AugCNN()
        aug_model.to(device)

        agent = algo.MetaDrAC(actor_critic,
                              aug_model,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              meta_grad_clip=args.meta_grad_clip,
                              meta_num_train_steps=args.meta_num_train_steps,
                              meta_num_test_steps=args.meta_num_test_steps,
                              lr=args.lr,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm,
                              aug_id=aug_id,
                              aug_coef=args.aug_coef)

    elif args.use_rl2:
        aug_id = data_augs.Identity
        aug_list = [
            aug_to_func[t](batch_size=batch_size)
            for t in list(aug_to_func.keys())
        ]

        rl2_obs_shape = [envs.action_space.n + 1]
        rl2_learner = Policy(rl2_obs_shape,
                             len(list(aug_to_func.keys())),
                             base_kwargs={
                                 'recurrent': True,
                                 'hidden_size': args.rl2_hidden_size
                             })
        rl2_learner.to(device)

        agent = algo.RL2DrAC(actor_critic,
                             rl2_learner,
                             args.clip_param,
                             args.ppo_epoch,
                             args.num_mini_batch,
                             args.value_loss_coef,
                             args.entropy_coef,
                             args.rl2_entropy_coef,
                             lr=args.lr,
                             eps=args.eps,
                             rl2_lr=args.rl2_lr,
                             rl2_eps=args.rl2_eps,
                             max_grad_norm=args.max_grad_norm,
                             aug_list=aug_list,
                             aug_id=aug_id,
                             aug_coef=args.aug_coef,
                             num_aug_types=len(list(aug_to_func.keys())),
                             recurrent_hidden_size=args.rl2_hidden_size,
                             num_actions=envs.action_space.n,
                             device=device)

    elif args.use_rad:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        pse_coef = args.pse_coef
        if args.use_pse:
            assert args.pse_coef > 0, "Please pass a non-zero pse_coef"
        else:
            pse_coef = 0.0

        print("Running RAD ..")
        print(
            "PSE: {}, Coef: {}, Gamma: {}, Temp: {}, Coupling Temp: {}".format(
                args.use_pse, pse_coef, args.pse_gamma, args.pse_temperature,
                args.pse_coupling_temperature))
        print('use_augmentation: {}'.format(args.use_augmentation))

        agent = algo.RAD(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            aug_id=aug_id,
            aug_func=aug_func,
            env_name=args.env_name,
            use_augmentation=args.use_augmentation,
            pse_gamma=args.pse_gamma,
            pse_coef=pse_coef,
            pse_temperature=args.pse_temperature,
            pse_coupling_temperature=args.pse_coupling_temperature)

    else:
        aug_id = data_augs.Identity
        aug_func = aug_to_func[args.aug_type](batch_size=batch_size)

        pse_coef = args.pse_coef
        if args.use_pse:
            assert args.pse_coef > 0, "Please pass a non-zero pse_coef"
        else:
            pse_coef = 0.0

        print("Running DrAC ..")
        print("PSE: {}, Coef: {}, Gamma: {}, Temp: {}".format(
            args.use_pse, pse_coef, args.pse_gamma, args.pse_temperature))

        # See the augmentation-regularizer sketch after this function.
        agent = algo.DrAC(actor_critic,
                          args.clip_param,
                          args.ppo_epoch,
                          args.num_mini_batch,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          max_grad_norm=args.max_grad_norm,
                          aug_id=aug_id,
                          aug_func=aug_func,
                          aug_coef=args.aug_coef,
                          env_name=args.env_name,
                          pse_gamma=args.pse_gamma,
                          pse_coef=pse_coef,
                          pse_temperature=args.pse_temperature)

    checkpoint_path = os.path.join(save_dir, "agent" + log_file + ".pt")
    if gfile.exists(checkpoint_path) and args.preempt:
        with gfile.GFile(checkpoint_path, 'rb') as f:
            inbuffer = io.BytesIO(f.read())
        checkpoint = torch.load(inbuffer)
        agent.actor_critic.load_state_dict(checkpoint['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        init_epoch = checkpoint['epoch'] + 1
        print('Loaded ckpt from epoch {}'.format(init_epoch - 1))
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout', 'tensorboard'],
                         log_suffix=log_file,
                         init_step=init_epoch)
    else:
        init_epoch = 0
        logger.configure(dir=args.log_dir,
                         format_strs=['csv', 'stdout', 'tensorboard'],
                         log_suffix=log_file,
                         init_step=init_epoch)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(init_epoch, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = aug_id(rollouts.obs[step])
                value, action, action_log_prob, recurrent_hidden_states, pi = actor_critic.act(
                    obs_id,
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                    policy=True)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks,
                            pi=pi)

        with torch.no_grad():
            obs_id = aug_id(rollouts.obs[-1])
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        if args.use_ucb and j > 0:
            agent.update_ucb_values(rollouts)
        value_loss, action_loss, dist_entropy, pse_loss = agent.update(
            rollouts)
        rollouts.after_update()

        # Log every log_interval-th update
        total_num_steps = (j + 1) * args.num_processes * args.num_steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            print(
                "\nUpdate {}, step {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                .format(j, total_num_steps, len(episode_rewards),
                        np.mean(episode_rewards), np.median(episode_rewards)))

            logger.logkv("train/nupdates", j)
            logger.logkv("train/total_num_steps", total_num_steps)

            logger.logkv("losses/dist_entropy", dist_entropy)
            logger.logkv("losses/value_loss", value_loss)
            logger.logkv("losses/action_loss", action_loss)
            if args.use_pse:
                logger.logkv("losses/pse_loss", pse_loss)

            logger.logkv("train/mean_episode_reward", np.mean(episode_rewards))
            logger.logkv("train/median_episode_reward",
                         np.median(episode_rewards))

            ### Eval on the Full Distribution of Levels ###
            eval_episode_rewards = evaluate(args, actor_critic, device,
                                            aug_id=aug_id)

            logger.logkv("test/mean_episode_reward",
                         np.mean(eval_episode_rewards))
            logger.logkv("test/median_episode_reward",
                         np.median(eval_episode_rewards))

            logger.dumpkvs()

        # Save the model every save_interval-th update and on the last update
        if (j > 0 and j % args.save_interval == 0
                or j == num_updates - 1) and save_dir != "":
            try:
                gfile.makedirs(save_dir)
            except OSError:
                pass

            ckpt_file = os.path.join(save_dir, "agent" + log_file + ".pt")
            outbuffer = io.BytesIO()
            torch.save(
                {
                    'epoch': j,
                    'model_state_dict': agent.actor_critic.state_dict(),
                    'optimizer_state_dict': agent.optimizer.state_dict()
                }, outbuffer)
            with gfile.GFile(ckpt_file, 'wb') as fout:
                fout.write(outbuffer.getvalue())

            save_num_steps = (j + 1) * args.num_processes * args.num_steps
            print("\nUpdate {}, step {}, Saved {}.".format(
                j, save_num_steps, ckpt_file))
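################################
# Illustrative sketch of the augmentation regularizer that algo.DrAC's
# update() is assumed to add on top of the usual PPO losses, weighted by
# aug_coef: keep the policy and the value function consistent under data
# augmentation (the DrAC formulation). The `do_augmentation` and
# `evaluate_actions` calls follow the common pytorch-a2c-ppo-style Policy
# interface; treat the exact signatures as assumptions, not the repo's
# authoritative code.
import torch


def drac_aug_losses(actor_critic, aug_func, obs_batch, rnn_hxs, masks):
    # Augmented view of the same observations.
    obs_aug = aug_func.do_augmentation(obs_batch)

    with torch.no_grad():
        # Actions and values computed on the un-augmented observations serve
        # as targets for the regularizers.
        values, actions, _, _ = actor_critic.act(obs_batch, rnn_hxs, masks)

    values_aug, action_log_probs_aug, _, _ = actor_critic.evaluate_actions(
        obs_aug, rnn_hxs, masks, actions)

    # Policy regularizer: the policy on augmented observations should assign
    # high probability to the actions chosen on the original observations.
    action_loss_aug = -action_log_probs_aug.mean()
    # Value regularizer: match value predictions across the augmentation.
    value_loss_aug = 0.5 * (values.detach() - values_aug).pow(2).mean()
    return action_loss_aug, value_loss_aug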