def build_and_train(slot_affinity_code, log_dir, run_ID):
    # (Or load from a central store of configs.)
    config = dict(
        env=dict(game="pong"),
        algo=dict(learning_rate=7e-4),
        sampler=dict(batch_B=16),
    )
    affinity = get_affinity(slot_affinity_code)
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        batch_T=5,
        # batch_B=16,  # Get from config.
        max_decorrelation_steps=400,
        **config["sampler"],
    )
    algo = A2C(**config["algo"])  # Run with defaults.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    name = "a2c_" + config["env"]["game"]
    log_dir = "example_6"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
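# Scripts in this style are not run by hand: an experiment launcher spawns them
# with the run-slot affinity code, the variant log directory, and the run ID as
# positional command-line arguments. A minimal sketch of the matching entry
# point, assuming the function above lives alone in its own script file:
if __name__ == "__main__":
    import sys
    build_and_train(*sys.argv[1:])  # slot_affinity_code, log_dir, run_ID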
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    assert isinstance(affinity, list)  # One for each GPU.
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=GpuWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"],
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = SyncRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"],
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
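# For context: the load_variant / config_key pattern above is driven from a
# launcher script. A hedged sketch of that launcher side, assuming rlpyt's
# launching utilities (encode_affinity, VariantLevel, make_variants,
# run_experiments) behave as in its bundled launch scripts; the script path and
# config key here are hypothetical:
from rlpyt.utils.launching.affinity import encode_affinity
from rlpyt.utils.launching.exp_launcher import run_experiments
from rlpyt.utils.launching.variant import VariantLevel, make_variants

affinity_code = encode_affinity(
    n_cpu_core=16,
    n_gpu=2,
    gpu_per_run=2,  # Multi-GPU runs for SyncRl.
)
# One variant level sweeping the game; keys index into the nested config dict,
# so ("env", "game") overrides config["env"]["game"] via update_config().
variant_levels = [VariantLevel(
    keys=[("env", "game")],
    values=[("pong",), ("breakout",)],
    dir_names=["pong", "breakout"],
)]
variants, log_dirs = make_variants(*variant_levels)
run_experiments(
    script="train_script.py",  # Hypothetical path to the file defining build_and_train.
    affinity_code=affinity_code,
    experiment_title="a2c_atari",
    runs_per_setting=2,
    variants=variants,
    log_dirs=log_dirs,
    common_args=("a2c_config_key",),  # Passed through to the script as config_key.
)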
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = get_affinity(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"],
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"],
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    # variant = load_variant(log_dir)
    # config = update_config(config, variant)

    sampler = CpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"],
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"],
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):  # Might have to flatten config.
        runner.train()
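# The "might have to flatten config" note above refers to loggers that only
# accept flat key-value parameter dicts. A hedged sketch of such a flattener
# (the name and dotted-key convention are assumptions, not part of rlpyt):
def flatten_config(config, prefix=""):
    """Flatten nested config dicts into {"env.game": ..., ...} style keys."""
    flat = {}
    for key, value in config.items():
        name = prefix + str(key)
        if isinstance(value, dict):
            flat.update(flatten_config(value, prefix=name + "."))
        else:
            flat[name] = value
    return flat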
def build_and_train(level="nav_maze_random_goal_01", run_ID=0, cuda_idx=None):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(8)))
    sampler = SerialSampler(
        EnvCls=DeepmindLabEnv,
        env_kwargs=dict(level=level),
        eval_env_kwargs=dict(level=level),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=5,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = PPO()
    agent = AtariFfAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(level=level)
    name = "lab_ppo"
    log_dir = "lab_example_3"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
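# This serial DeepMind Lab example can be run directly; a hedged argparse entry
# point in the style of rlpyt's example scripts (flag defaults are assumptions):
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--level', help='DeepMind Lab level', default='nav_maze_random_goal_01')
    parser.add_argument('--run_ID', help='run identifier (logging)', type=int, default=0)
    parser.add_argument('--cuda_idx', help='gpu to use', type=int, default=None)
    args = parser.parse_args()
    build_and_train(level=args.level, run_ID=args.run_ID, cuda_idx=args.cuda_idx)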
def build_and_train(game="pong", run_ID=0, cuda_idx=None, sample_mode="serial", n_parallel=2): affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))) gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}" if sample_mode == "serial": Sampler = SerialSampler # (Ignores workers_cpus.) print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.") elif sample_mode == "cpu": Sampler = CpuSampler print( f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing." ) elif sample_mode == "gpu": Sampler = GpuSampler print( f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing." ) elif sample_mode == "alternating": Sampler = AlternatingSampler affinity["workers_cpus"] += affinity["workers_cpus"] # (Double list) affinity["alternating"] = True # Sampler will check for this. print( f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing." ) sampler = Sampler( EnvCls=AtariEnv, TrajInfoCls=AtariTrajInfo, env_kwargs=dict(game=game), batch_T=5, # 5 time-steps per sampler iteration. batch_B=16, # 16 parallel environments. max_decorrelation_steps=400, ) algo = A2C() # Run with defaults. agent = AtariFfAgent() runner = MinibatchRl( algo=algo, agent=agent, sampler=sampler, n_steps=50e6, log_interval_steps=1e5, affinity=affinity, ) config = dict(game=game) name = "a2c_" + game log_dir = "example_3" with logger_context(log_dir, run_ID, name, config): runner.train()
def build_and_train(game="pong", run_ID=0): # Seems like we should be able to skip the intermediate step of the code, # but so far have just always run that way. # Change these inputs to match local machine and desired parallelism. affinity_code = encode_affinity( n_cpu_cores=16, # Use 16 cores across all experiments. n_gpu=8, # Use 8 gpus across all experiments. hyperthread_offset=24, # If machine has 24 cores. n_socket=2, # Presume CPU socket affinity to lower/upper half GPUs. gpu_per_run=2, # How many GPUs to parallelize one run across. # cpu_per_run=1, ) slot_affinity_code = prepend_run_slot(run_slot=0, affinity_code=affinity_code) affinity = get_affinity(slot_affinity_code) breakpoint() sampler = GpuParallelSampler( EnvCls=AtariEnv, env_kwargs=dict(game=game), CollectorCls=WaitResetCollector, batch_T=5, batch_B=16, max_decorrelation_steps=400, ) algo = A2C() # Run with defaults. agent = AtariFfAgent() runner = MultiGpuRl( algo=algo, agent=agent, sampler=sampler, n_steps=50e6, log_interval_steps=1e5, affinity=affinity, ) config = dict(game=game) name = "a2c_" + game log_dir = "example_7" with logger_context(log_dir, run_ID, name, config): runner.train()
def build_and_train(game="breakout", run_ID=0, cuda_idx=None, sample_mode="serial", n_parallel=2): affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))) gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}" if sample_mode == "serial": Sampler = SerialSampler # (Ignores workers_cpus.) print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.") elif sample_mode == "gpu": Sampler = GpuSampler print( f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing." ) env_kwargs = dict(game=game, repeat_action_probability=0.25, horizon=int(18e3)) sampler = Sampler( EnvCls=AtariEnv, TrajInfoCls=AtariTrajInfo, # default traj info + GameScore env_kwargs=env_kwargs, batch_T=128, batch_B=1, max_decorrelation_steps=0) algo = PPO(minibatches=4, epochs=4, entropy_loss_coeff=0.001, learning_rate=0.0001, gae_lambda=0.95, discount=0.999) base_model_kwargs = dict( # Same front-end architecture as RND model, different fc kwarg name channels=[32, 64, 64], kernel_sizes=[8, 4, 4], strides=[(4, 4), (2, 2), (1, 1)], paddings=[0, 0, 0], fc_sizes=[512] # Automatically applies nonlinearity=torch.nn.ReLU in this case, # but can't specify due to rlpyt limitations ) agent = AtariFfAgent(model_kwargs=base_model_kwargs) runner = MinibatchRl( algo=algo, agent=agent, sampler=sampler, n_steps=int( 49152e4 ), # this is 30k rollouts per environment at (T, B) = (128, 128) log_interval_steps=int(1e3), affinity=affinity) config = dict(game=game) name = "ppo_" + game log_dir = "baseline" with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"): runner.train()
def start_experiment(args):
    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(args.log_dir + '/git.txt', 'w') as git_file:
        branch = subprocess.check_output(
            ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        if args.num_gpus > 0:
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # Cores available across all experiments.
                n_gpu=args.num_gpus,  # GPUs available across all experiments.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.
                # contexts_per_gpu=2,
                # hyperthread_offset=72,  # Set to the machine's physical core count.
                # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
                # cpu_per_run=1,
            )
            print('Made multi-GPU affinity.')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # Potentially reload models.
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete")  # Clean up json files for the video recorder.
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ------------------------------- POLICY ------------------------------- #
    model_args = dict(
        curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg),
        curiosity_step_kwargs=dict(),
    )
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['feature_space'] = args.feature_space
    elif args.curiosity_alg == 'micm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['ensemble_mode'] = args.ensemble_mode
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model,
            )
        else:
            agent = AtariFfAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model,
            )

    # ---------------------------- LEARNING ALG ----------------------------- #
    if args.alg == 'ppo':
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint.
            gae_lambda=args.gae_lambda,
            minibatches=args.minibatches,  # If recurrent: batch_B must be at least this; if not recurrent: batch_B * batch_T must be.
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            curiosity_type=args.curiosity_alg,
            policy_loss_type=args.policy_loss_type,
        )
    elif args.alg == 'a2c':
        algo = A2C(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,
            gae_lambda=args.gae_lambda,
            normalize_advantage=args.normalize_advantage,
        )

    # ------------------------------ SAMPLER -------------------------------- #
    # Environment setup.
    traj_info_cl = TrajInfo  # Environment-specific; potentially overridden below.
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
        )
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps,
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000,
        )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset,
        )

    if args.sample_mode == 'gpu':
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )
    else:
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # Time-steps per sampler iteration.
            batch_B=args.num_envs,  # Environments distributed across workers.
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )

    # ------------------------------- RUNNER -------------------------------- #
    if args.eval_envs > 0:
        RunnerCls = MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval
    else:
        RunnerCls = MinibatchRl if args.num_gpus <= 1 else SyncRl
    runner = RunnerCls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
def start_experiment(args):
    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        assert args.num_gpus > 0
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # Potentially reload models.
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete")  # Clean up json files for the video recorder.
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ------------------------------- POLICY ------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg))
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['num_predictors'] = args.num_predictors
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
            )
        else:
            agent = AtariFfAgent(initial_model_state_dict=initial_model_state_dict)

    # ---------------------------- LEARNING ALG ----------------------------- #
    if args.alg == 'ppo':
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint.
            gae_lambda=args.gae_lambda,
            minibatches=args.minibatches,  # If recurrent: batch_B must be at least this; if not recurrent: batch_B * batch_T must be.
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg,
        )
    elif args.alg == 'a2c':
        algo = A2C(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,
            gae_lambda=args.gae_lambda,
            normalize_advantage=args.normalize_advantage,
        )

    # ------------------------------ SAMPLER -------------------------------- #
    # Environment setup.
    traj_info_cl = TrajInfo  # Environment-specific; potentially overridden below.
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
        )
    elif 'deepmind' in args.env.lower():  # Pycolab deepmind environments.
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            max_steps_per_episode=args.max_episode_steps,
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000,
        )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
        )

    if args.sample_mode == 'gpu':
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )
    else:
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # Time-steps per sampler iteration.
            batch_B=args.num_envs,  # Environments distributed across workers.
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )

    # ------------------------------- RUNNER -------------------------------- #
    RunnerCls = MinibatchRlEval if args.eval_envs > 0 else MinibatchRl
    runner = RunnerCls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
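# start_experiment consumes a large argparse namespace. A hedged sketch of an
# entry point showing only a representative subset of the flags read above;
# names match the attribute accesses, but defaults here are assumptions, and the
# remaining flags (lr, discount, eval_envs, timestep_limit, num_envs, etc.) are
# elided:
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='breakout')
    parser.add_argument('--log_dir', type=str, default='results/run_0')
    parser.add_argument('--sample_mode', type=str, default='cpu', choices=['cpu', 'gpu'])
    parser.add_argument('--num_cpus', type=int, default=8)
    parser.add_argument('--num_gpus', type=int, default=0)
    parser.add_argument('--alg', type=str, default='ppo', choices=['ppo', 'a2c'])
    parser.add_argument('--curiosity_alg', type=str, default='none')
    parser.add_argument('--pretrain', type=str, default='None')
    # ... remaining flags omitted here.
    args = parser.parse_args()
    start_experiment(args)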