def train(args, seeds):
    global last_checkpoint_time

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if 'cuda' in device.type:
        print('Using CUDA\n')

    torch.set_num_threads(1)
    utils.seed(args.seed)

    # Configure logging
    if args.xpid is None:
        args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))
    plogger = FileWriter(
        xpid=args.xpid,
        xp_args=args.__dict__,
        rootdir=log_dir,
        seeds=seeds,
    )
    stdout_logger = HumanOutputFormat(sys.stdout)

    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar"))
    )

    current_update_count = 0
    initial_update_count = 0
    last_logged_update_count_at_restart = -1

    # Configure actor envs
    start_level = 0
    num_levels = 1
    level_sampler_args = dict(
        num_actors=args.num_processes,
        strategy=args.level_replay_strategy,
        max_score_coef=args.level_replay_max_score_coef,
        replay_schedule=args.level_replay_schedule,
        score_transform=args.level_replay_score_transform,
        temperature=args.level_replay_temperature,
        eps=args.level_replay_eps,
        rho=args.level_replay_rho,
        replay_prob=args.level_replay_prob,
        alpha=args.level_replay_alpha,
        staleness_coef=args.staleness_coef,
        staleness_transform=args.staleness_transform,
        staleness_temperature=args.staleness_temperature,
        sample_full_distribution=args.train_full_distribution,
        seed_buffer_size=args.level_replay_seed_buffer_size,
        seed_buffer_priority=args.level_replay_seed_buffer_priority,
        tscl_window_size=args.tscl_window_size)

    # Optional secondary sampler: inherits the primary settings, then
    # overrides the strategy-specific fields.
    level_sampler_secondary_args = {}
    if args.level_replay_secondary_strategy:
        level_sampler_secondary_args = dict(
            strategy=args.level_replay_secondary_strategy,
            score_transform=args.level_replay_secondary_score_transform,
            temperature=args.level_replay_secondary_temperature,
            eps=args.level_replay_secondary_eps,
            staleness_coef=args.secondary_staleness_coef,
            staleness_transform=args.secondary_staleness_transform,
            staleness_temperature=args.secondary_staleness_temperature)
        args_tmp = level_sampler_args.copy()
        args_tmp.update(level_sampler_secondary_args)
        level_sampler_secondary_args = args_tmp

    envs, level_sampler, secondary_level_sampler = make_lr_venv(
        num_envs=args.num_processes, env_name=args.env_name,
        seeds=seeds, device=device,
        num_levels=num_levels, start_level=start_level,
        no_ret_normalization=args.no_ret_normalization,
        distribution_mode=args.distribution_mode,
        paint_vel_info=args.paint_vel_info,
        level_sampler_args=level_sampler_args,
        level_sampler_secondary_args=level_sampler_secondary_args,
        level_replay_strategy_mix_coef=args.level_replay_strategy_mix_coef)

    is_minigrid = args.env_name.startswith('MiniGrid')

    actor_critic = model_for_env_name(args, envs)
    actor_critic.to(device)

    rollouts = RolloutStorage(
        args.num_steps, args.num_processes,
        envs.observation_space.shape, envs.action_space,
        actor_critic.recurrent_hidden_state_size)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    aug_id = None
    if args.algo == 'ucb' or args.use_ucb:
        print('Using UCB')
        aug_id = data_augs.Identity
        aug_list = [aug_to_func[t](batch_size=batch_size)
                    for t in list(aug_to_func.keys())]
        mix_alpha = None if not args.use_mixreg else args.mixreg_alpha
        agent = algo.UCBDrAC(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            aug_list=aug_list,
            aug_id=aug_id,
            aug_coef=args.aug_coef,
            num_aug_types=len(list(aug_to_func.keys())),
            ucb_exploration_coef=args.ucb_exploration_coef,
            ucb_window_length=args.ucb_window_length,
            mix_alpha=mix_alpha,
            log_grad_norm=args.log_grad_norm)
    elif args.algo == 'mixreg' or args.use_mixreg:
        agent = algo.MixRegPPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            log_grad_norm=args.log_grad_norm,
            mix_alpha=args.mixreg_alpha)
    else:
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm,
            log_grad_norm=args.log_grad_norm)

    level_seeds = torch.zeros(args.num_processes)
    if level_sampler:
        obs, level_seeds = envs.reset()
    else:
        obs = envs.reset()
    level_seeds = level_seeds.unsqueeze(-1)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    def checkpoint():
        if args.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        checkpoint_states = {
            "model_state_dict": actor_critic.state_dict(),
            "optimizer_state_dict": agent.optimizer.state_dict(),
            "rollouts": rollouts,
            "episode_rewards": episode_rewards,
            "level_sampler": level_sampler,
            "current_update_count": current_update_count,
        }
        if hasattr(agent, 'bandit_state_dict'):
            checkpoint_states.update({
                "bandit_state_dict": agent.bandit_state_dict()
            })
        torch.save(checkpoint_states, checkpointpath)

    # Load checkpoint (resume a preempted run)
    if args.checkpoint and os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath)
        actor_critic.load_state_dict(checkpoint_states['model_state_dict'])
        agent.optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        rollouts = checkpoint_states["rollouts"]
        episode_rewards = checkpoint_states["episode_rewards"]
        level_sampler = checkpoint_states["level_sampler"]
        current_update_count = checkpoint_states["current_update_count"]
        initial_update_count = current_update_count
        last_logged_update_count_at_restart = plogger.latest_tick() + 1  # ticks are 0-indexed updates
        if hasattr(agent, 'load_bandit_state_dict'):
            agent.load_bandit_state_dict(checkpoint_states["bandit_state_dict"])
        logging.info(f"Resuming preempted job from update: {current_update_count}\n")

    timer = timeit.default_timer
    update_start_time = timer()
    agent_id = 0  # np.random.choice(range(actor_critic.ensemble_size))
    for j in range(initial_update_count, num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = rollouts.obs[step]
                if aug_id:
                    obs_id = aug_id(obs_id)
                value, action, action_log_dist, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step], agent_id=agent_id)
                action_log_prob = action_log_dist.gather(-1, action)
                uncertainties = actor_critic.get_uncertainty(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Reset all done levels by sampling from level sampler
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                if level_sampler:
                    level_seeds[i][0] = info['level_seed']

            # If done, then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(
                obs, recurrent_hidden_states, action, action_log_prob,
                action_log_dist, value, reward, masks, bad_masks,
                uncertainties, level_seeds)

        with torch.no_grad():
            obs_id = rollouts.obs[-1]
            if aug_id:
                obs_id = aug_id(obs_id)
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)

        # Update level sampler
        if level_sampler:
            level_sampler.update_with_rollouts(rollouts)
        if secondary_level_sampler:
            secondary_level_sampler.update_with_rollouts(rollouts)

        if args.use_ucb and j > 0:
            agent.update_ucb_values(rollouts)
        value_loss, action_loss, dist_entropy, info = agent.update(rollouts)
        rollouts.after_update()
        if level_sampler:
            level_sampler.after_update()
        if secondary_level_sampler:
            secondary_level_sampler.after_update()

        current_update_count = j + 1

        # ==== Everything below here is for logging + checkpointing ====
        if current_update_count <= last_logged_update_count_at_restart:
            continue

        # Log stats every log_interval updates, or on the last update
        if (j % args.log_interval == 0 and len(episode_rewards) > 1) or j == num_updates - 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps

            update_end_time = timer()
            num_interval_updates = 1 if j == 0 else args.log_interval
            sps = num_interval_updates * (args.num_processes * args.num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time

            logging.info(f"\nUpdate {j} done, {total_num_steps} steps\n")
            logging.info(f"\nEvaluating on {args.num_test_seeds} test levels...\n")
            # evaluate returns (episode_rewards, transitions); only the
            # rewards are used here.
            eval_episode_rewards, _ = evaluate(
                args, actor_critic, args.num_test_seeds, device, aug_id=aug_id)

            logging.info(f"\nEvaluating on {args.num_test_seeds} train levels...\n")
            train_eval_episode_rewards, _ = evaluate(
                args, actor_critic,
                args.num_test_seeds,  # Use same number of levels for evaluating train and test seeds
                device, start_level=0, num_levels=args.num_train_seeds,
                seeds=seeds, aug_id=aug_id)

            stats = {
                "step": total_num_steps,
                "pg_loss": action_loss,
                "value_loss": value_loss,
                "dist_entropy": dist_entropy,
                "train:mean_episode_return": np.mean(episode_rewards),
                "train:median_episode_return": np.median(episode_rewards),
                "test:mean_episode_return": np.mean(eval_episode_rewards),
                "test:median_episode_return": np.median(eval_episode_rewards),
                "train_eval:mean_episode_return": np.mean(train_eval_episode_rewards),
                "train_eval:median_episode_return": np.median(train_eval_episode_rewards),
                "sps": sps,
            }
            if args.log_grad_norm:
                stats.update({
                    "mean_grad_norm": np.mean(info['grad_norms'])
                })
            if is_minigrid:
                stats["train:success_rate"] = np.mean(np.array(episode_rewards) > 0)
                stats["train_eval:success_rate"] = np.mean(np.array(train_eval_episode_rewards) > 0)
                stats["test:success_rate"] = np.mean(np.array(eval_episode_rewards) > 0)

            if j == num_updates - 1:
                logging.info(f"\nLast update: Evaluating on {args.final_num_test_seeds} test levels...\n")
                final_eval_episode_rewards, _ = evaluate(
                    args, actor_critic, args.final_num_test_seeds, device)
                mean_final_eval_episode_rewards = np.mean(final_eval_episode_rewards)
                median_final_eval_episode_rewards = np.median(final_eval_episode_rewards)
                plogger.log_final_test_eval({
                    'num_test_seeds': args.final_num_test_seeds,
                    'mean_episode_return': mean_final_eval_episode_rewards,
                    'median_episode_return': median_final_eval_episode_rewards
                })

            plogger.log(stats)
            if args.verbose:
                stdout_logger.writekvs(stats)

        # Log level weights
        if level_sampler and j % args.weight_log_interval == 0:
            plogger.log_level_weights(level_sampler.sample_weights(), level_sampler.seeds)

        # Checkpoint
        timer = timeit.default_timer
        if last_checkpoint_time is None:
            last_checkpoint_time = timer()
        try:
            # args.save_interval is in minutes.
            if j == num_updates - 1 or \
                    (args.save_interval > 0 and timer() - last_checkpoint_time > args.save_interval * 60):
                checkpoint()
                last_checkpoint_time = timer()
                logging.info(f"\nSaved checkpoint after update {current_update_count}")
        except KeyboardInterrupt:
            return
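
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the training loop above): how a rank-based
# score transform, as described in the Prioritized Level Replay paper, turns
# per-level scores into replay probabilities. The real logic lives in the
# repo's LevelSampler, configured via `level_replay_score_transform`,
# `level_replay_temperature`, and `staleness_coef` above; the standalone
# helper below is an approximation for intuition only.
def _example_plr_replay_weights(scores, staleness, temperature=0.1, staleness_coef=0.1):
    """Return replay probabilities from level scores and staleness counts.

    scores: array of per-level learning-potential scores (higher = replayed more).
    staleness: array counting updates since each level was last sampled.
    """
    import numpy as np
    # Rank transform: weight_i = (1 / rank_i) ** (1 / temperature),
    # so lower temperature concentrates replay on the top-ranked levels.
    ranks = np.empty(len(scores))
    ranks[np.argsort(-np.asarray(scores, dtype=float))] = np.arange(1, len(scores) + 1)
    score_weights = (1.0 / ranks) ** (1.0 / temperature)
    score_weights /= score_weights.sum()
    # Mix in a staleness distribution so neglected levels are still revisited.
    staleness = np.asarray(staleness, dtype=float)
    if staleness.sum() > 0:
        staleness_weights = staleness / staleness.sum()
    else:
        staleness_weights = np.full_like(score_weights, 1.0 / len(score_weights))
    return (1 - staleness_coef) * score_weights + staleness_coef * staleness_weights
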
def evaluate(
        args,
        actor_critic,
        num_episodes,
        device,
        num_processes=1,
        deterministic=False,
        start_level=0,
        num_levels=0,
        seeds=None,
        level_sampler=None,
        progressbar=None,
        aug_id=None):
    actor_critic.eval()

    # if level_sampler:
    #     start_level = level_sampler.seed_range()[0]
    #     num_levels = 1

    eval_envs, level_sampler = make_lr_venv(
        num_envs=num_processes, env_name=args.env_name,
        seeds=seeds, device=device,
        num_levels=num_levels, start_level=start_level,
        no_ret_normalization=args.no_ret_normalization,
        distribution_mode=args.distribution_mode,
        paint_vel_info=args.paint_vel_info,
        level_sampler=level_sampler)

    eval_episode_rewards = []
    if level_sampler:
        obs, _ = eval_envs.reset()
    else:
        obs = eval_envs.reset()

    eval_recurrent_hidden_states = torch.zeros(
        num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    eval_masks = torch.ones(num_processes, 1, device=device)

    transitions = []
    while len(eval_episode_rewards) < num_episodes:
        with torch.no_grad():
            # Apply the identity augmentation if one was provided
            # (used by the DrAC variants in train above).
            obs_id = aug_id(obs) if aug_id else obs
            _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                obs_id, eval_recurrent_hidden_states, eval_masks,
                deterministic=deterministic)

        obs, r, done, infos = eval_envs.step(action)
        transitions.append((get_np(r), get_np(done), infos))
        eval_masks = torch.tensor(
            [[0.0] if done_ else [1.0] for done_ in done],
            dtype=torch.float32, device=device)

        for info in infos:
            if 'episode' in info.keys():
                eval_episode_rewards.append(info['episode']['r'])
                if progressbar:
                    progressbar.update(1)

    eval_envs.close()
    if progressbar:
        progressbar.close()

    if args.verbose:
        print("Last {} test episodes: mean/median reward {:.1f}/{:.1f}\n"
              .format(len(eval_episode_rewards),
                      np.mean(eval_episode_rewards),
                      np.median(eval_episode_rewards)))

    return eval_episode_rewards, transitions
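
# Example usage of `evaluate` (a hypothetical snippet; the argument values are
# placeholders, and `args` must carry the env/eval fields referenced above):
#
#   rewards, transitions = evaluate(
#       args, actor_critic, num_episodes=10, device=device,
#       num_processes=4, deterministic=True)
#   print(np.mean(rewards))
#
# Note that `evaluate` returns both the per-episode returns and the raw
# (reward, done, info) transitions, so callers that only need returns should
# unpack and discard the second element.
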
def train(args, seeds):
    global last_checkpoint_time

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if 'cuda' in device.type:
        print('Using CUDA\n')
    torch.set_num_threads(1)

    utils.seed(args.seed)

    # Configure logging
    if args.xpid is None:
        args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))
    plogger = FileWriter(
        xpid=args.xpid,
        xp_args=args.__dict__,
        rootdir=log_dir,
        seeds=seeds,
    )
    stdout_logger = HumanOutputFormat(sys.stdout)

    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar")))

    # Configure actor envs
    start_level = 0
    if args.full_train_distribution:
        num_levels = 0
        level_sampler_args = None
        seeds = None
    else:
        num_levels = 1
        level_sampler_args = dict(
            num_actors=args.num_processes,
            strategy=args.level_replay_strategy,
            replay_schedule=args.level_replay_schedule,
            score_transform=args.level_replay_score_transform,
            temperature=args.level_replay_temperature,
            eps=args.level_replay_eps,
            rho=args.level_replay_rho,
            nu=args.level_replay_nu,
            alpha=args.level_replay_alpha,
            staleness_coef=args.staleness_coef,
            staleness_transform=args.staleness_transform,
            staleness_temperature=args.staleness_temperature)

    envs, level_sampler = make_lr_venv(
        num_envs=args.num_processes, env_name=args.env_name,
        seeds=seeds, device=device,
        num_levels=num_levels, start_level=start_level,
        no_ret_normalization=args.no_ret_normalization,
        distribution_mode=args.distribution_mode,
        paint_vel_info=args.paint_vel_info,
        level_sampler_args=level_sampler_args)

    is_minigrid = args.env_name.startswith('MiniGrid')

    actor_critic = model_for_env_name(args, envs)
    actor_critic.to(device)
    print(actor_critic)

    rollouts = RolloutStorage(
        args.num_steps, args.num_processes,
        envs.observation_space.shape, envs.action_space,
        actor_critic.recurrent_hidden_state_size)

    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)

    def checkpoint():
        if args.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": actor_critic.state_dict(),
                "optimizer_state_dict": agent.optimizer.state_dict(),
                "args": vars(args),
            },
            checkpointpath,
        )

    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm,
        env_name=args.env_name)

    level_seeds = torch.zeros(args.num_processes)
    if level_sampler:
        obs, level_seeds = envs.reset()
    else:
        obs = envs.reset()
    level_seeds = level_seeds.unsqueeze(-1)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    timer = timeit.default_timer
    update_start_time = timer()
    for j in range(num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = rollouts.obs[step]
                value, action, action_log_dist, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
                action_log_prob = action_log_dist.gather(-1, action)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            # Reset all done levels by sampling from level sampler
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                if level_sampler:
                    level_seeds[i][0] = info['level_seed']

            # If done, then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(
                obs, recurrent_hidden_states, action, action_log_prob,
                action_log_dist, value, reward, masks, bad_masks, level_seeds)

        with torch.no_grad():
            obs_id = rollouts.obs[-1]
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)

        # Update level sampler
        if level_sampler:
            level_sampler.update_with_rollouts(rollouts)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()
        if level_sampler:
            level_sampler.after_update()

        # Log stats every log_interval updates, or on the last update
        if (j % args.log_interval == 0 and len(episode_rewards) > 1) or j == num_updates - 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps

            update_end_time = timer()
            num_interval_updates = 1 if j == 0 else args.log_interval
            sps = num_interval_updates * (args.num_processes * args.num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time

            logging.info(f"\nUpdate {j} done, {total_num_steps} steps\n")
            logging.info(f"\nEvaluating on {args.num_test_seeds} test levels...\n")
            eval_episode_rewards, transitions = evaluate(
                args, actor_critic, args.num_test_seeds, device)
            plogger._save_data(transitions, f'test_trajectories_{j}.pkl')

            logging.info(f"\nEvaluating on {args.num_test_seeds} train levels...\n")
            train_eval_episode_rewards, transitions = evaluate(
                args, actor_critic, args.num_test_seeds, device,
                start_level=0, num_levels=args.num_train_seeds,
                seeds=seeds, level_sampler=level_sampler)

            stats = {
                "step": total_num_steps,
                "pg_loss": action_loss,
                "value_loss": value_loss,
                "dist_entropy": dist_entropy,
                "train:mean_episode_return": np.mean(episode_rewards),
                "train:median_episode_return": np.median(episode_rewards),
                "test:mean_episode_return": np.mean(eval_episode_rewards),
                "test:median_episode_return": np.median(eval_episode_rewards),
                "train_eval:mean_episode_return": np.mean(train_eval_episode_rewards),
                "train_eval:median_episode_return": np.median(train_eval_episode_rewards),
                "sps": sps,
            }
            if is_minigrid:
                stats["train:success_rate"] = np.mean(np.array(episode_rewards) > 0)
                stats["train_eval:success_rate"] = np.mean(np.array(train_eval_episode_rewards) > 0)
                stats["test:success_rate"] = np.mean(np.array(eval_episode_rewards) > 0)

            if j == num_updates - 1:
                logging.info(
                    f"\nLast update: Evaluating on {args.final_num_test_seeds} test levels...\n")
                final_eval_episode_rewards, transitions = evaluate(
                    args, actor_critic, args.final_num_test_seeds, device)
                mean_final_eval_episode_rewards = np.mean(final_eval_episode_rewards)
                median_final_eval_episode_rewards = np.median(final_eval_episode_rewards)
                plogger.log_final_test_eval({
                    'num_test_seeds': args.final_num_test_seeds,
                    'mean_episode_return': mean_final_eval_episode_rewards,
                    'median_episode_return': median_final_eval_episode_rewards
                })

            plogger.log(stats)
            if args.verbose:
                stdout_logger.writekvs(stats)

        # Log level weights
        if level_sampler and j % args.weight_log_interval == 0:
            plogger.log_level_weights(level_sampler.sample_weights())

        # Checkpoint
        timer = timeit.default_timer
        if last_checkpoint_time is None:
            last_checkpoint_time = timer()
        try:
            # args.save_interval is in minutes.
            if j == num_updates - 1 or \
                    (args.save_interval > 0 and timer() - last_checkpoint_time > args.save_interval * 60):
                checkpoint()
                last_checkpoint_time = timer()
        except KeyboardInterrupt:
            return
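
# ---------------------------------------------------------------------------
# Illustrative sketch of the return computation that
# `rollouts.compute_returns(next_value, gamma, gae_lambda)` performs above:
# standard Generalized Advantage Estimation. This standalone helper is an
# approximation for intuition, not the RolloutStorage implementation.
def _example_gae_returns(rewards, values, next_value, masks, gamma=0.999, gae_lambda=0.95):
    """Compute GAE(lambda) returns for a single actor.

    rewards, values, masks: sequences of floats of length T, where masks are
    0.0 at steps where the episode ended and 1.0 otherwise; next_value
    bootstraps V(s_T) at the end of the rollout.
    """
    values = list(values) + [next_value]
    returns = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        # Recursive accumulation: A_t = delta_t + gamma * lambda * mask_t * A_{t+1}
        gae = delta + gamma * gae_lambda * masks[t] * gae
        # Return target = advantage + value baseline
        returns[t] = gae + values[t]
    return returns
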
def evaluate_saved_model(
        args,
        result_dir,
        xpid,
        num_episodes=10,
        seeds=None,
        deterministic=False,
        verbose=False,
        progressbar=False,
        num_processes=1):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if 'cuda' in device.type:
        print('Using CUDA\n')

    if verbose:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if args.xpid is None:
        checkpointpath = os.path.expandvars(
            os.path.expanduser(os.path.join(result_dir, "latest", "model.tar")))
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser(os.path.join(result_dir, xpid, "model.tar")))

    # Set up level sampler
    if seeds is None:
        seeds = [int.from_bytes(os.urandom(4), byteorder="little")
                 for _ in range(num_episodes)]

    dummy_env, _ = make_lr_venv(
        num_envs=num_processes, env_name=args.env_name,
        seeds=None, device=device,
        num_levels=1, start_level=1,
        no_ret_normalization=args.no_ret_normalization,
        distribution_mode=args.distribution_mode,
        paint_vel_info=args.paint_vel_info)

    level_sampler = LevelSampler(
        seeds,
        dummy_env.observation_space, dummy_env.action_space,
        strategy='sequential')

    model = model_for_env_name(args, dummy_env)

    pbar = None
    if progressbar:
        pbar = tqdm(total=num_episodes)

    if torch.cuda.is_available():
        map_location = lambda storage, loc: storage.cuda()
    else:
        map_location = 'cpu'
    checkpoint = torch.load(checkpointpath, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])

    num_processes = min(num_processes, num_episodes)

    # evaluate returns (episode_rewards, transitions); only the rewards are
    # needed here.
    eval_episode_rewards, _ = evaluate(
        args, model, num_episodes,
        device=device,
        num_processes=num_processes,
        deterministic=deterministic,
        level_sampler=level_sampler,
        progressbar=pbar)

    mean_return = np.mean(eval_episode_rewards)
    median_return = np.median(eval_episode_rewards)

    logging.info(
        "Average returns over %i episodes: %.2f", num_episodes, mean_return)
    logging.info(
        "Median returns over %i episodes: %.2f", num_episodes, median_return)

    return mean_return, median_return
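
# Example invocation of `evaluate_saved_model` (hypothetical: `parse_args` and
# the paths below are placeholders for however this repo builds its `args`):
#
#   if __name__ == "__main__":
#       args = parse_args()
#       mean_ret, median_ret = evaluate_saved_model(
#           args, result_dir="~/logs/level-replay", xpid=args.xpid,
#           num_episodes=100, progressbar=True)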