def build_and_train(game="pong", run_ID=0, cuda_idx=None, mid_batch_reset=False, n_parallel=2): affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))) Collector = GpuResetCollector if mid_batch_reset else GpuWaitResetCollector print(f"To satisfy mid_batch_reset=={mid_batch_reset}, using {Collector}.") sampler = GpuParallelSampler( EnvCls=AtariEnv, env_kwargs=dict(game=game, num_img_obs=1), # Learn on individual frames. CollectorCls=Collector, batch_T=20, # Longer sampling/optimization horizon for recurrence. batch_B=16, # 16 parallel environments. max_decorrelation_steps=400, ) algo = A2C() # Run with defaults. agent = AtariLstmAgent() runner = MinibatchRl( algo=algo, agent=agent, sampler=sampler, n_steps=50e6, log_interval_steps=1e5, affinity=affinity, ) config = dict(game=game) name = "a2c_" + game log_dir = "example_4" with logger_context(log_dir, run_ID, name, config): runner.train()
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"]
    )
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
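# Scripts with this (slot_affinity_code, log_dir, run_ID, config_key)
# signature are normally invoked by rlpyt's launch utilities, which pass the
# arguments positionally on the command line; a typical entry point (a sketch
# of that convention) is simply:
if __name__ == "__main__":
    import sys
    build_and_train(*sys.argv[1:])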
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    # variant = load_variant(log_dir)
    # config = update_config(config, variant)
    sampler = CpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"] + str(config["algo"]["entropy_loss_coeff"])
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
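# For reference, a hypothetical shape for one entry of the `configs` dict
# consumed above (the real keys and values live in the project's config
# module; this sketch only shows the expected nesting):
example_configs = dict(
    a2c_lstm=dict(
        env=dict(game="pong"),
        sampler=dict(batch_T=20, batch_B=16, max_decorrelation_steps=400),
        optim=dict(),
        algo=dict(entropy_loss_coeff=0.01),
        model=dict(),
        agent=dict(),
        runner=dict(n_steps=50e6, log_interval_steps=1e5),
    ),
)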
def build_and_train(game="academy_empty_goal_close", run_ID=1, cuda_idx=None): env_vector_size = args.envVectorSize coach = Coach(envOptions=args.envOptions, vectorSize=env_vector_size, algo='Bandit', initialQ=args.initialQ, beta=args.beta) sampler = SerialSampler( EnvCls=create_single_football_env, env_kwargs=dict(game=game), eval_env_kwargs=dict(game=game), batch_T=5, # Four time-steps per sampler iteration. batch_B=env_vector_size, max_decorrelation_steps=0, eval_n_envs=args.evalNumOfEnvs, eval_max_steps=int(10e3), eval_max_trajectories=5, coach=coach, eval_env=args.evalEnv, ) algo = PPO(minibatches=1) # Run with defaults. agent = AtariLstmAgent() # TODO: move to ff runner = MinibatchRlEval( algo=algo, agent=agent, sampler=sampler, n_steps=args.numOfSteps, log_interval_steps=1e3, affinity=dict(cuda_idx=cuda_idx), ) name = args.name log_dir = "example_1" with logger_context(log_dir, run_ID, name, log_params=vars(args), snapshot_mode="last"): runner.train()
def start_experiment(args):
    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(args.log_dir + '/git.txt', 'w') as git_file:
        branch = subprocess.check_output(
            ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        if args.num_gpus > 0:
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # Cores used across all experiments.
                n_gpu=args.num_gpus,  # GPUs used across all experiments.
                # contexts_per_gpu=2,
                # hyperthread_offset=72,
                # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.
                # cpu_per_run=1,
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # Potentially reload models.
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete")  # Clean up json files for video recorder.
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg),
        curiosity_step_kwargs=dict())
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
            forward_model=args.forward_model,
            feature_space=args.feature_space,
        )
    elif args.curiosity_alg == 'micm':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
            forward_model=args.forward_model,
            ensemble_mode=args.ensemble_mode,
            device=args.sample_mode,
        )
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            ensemble_size=args.ensemble_size,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
            device=args.sample_mode,
            forward_model=args.forward_model,
        )
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            pred_horizon=args.pred_horizon,
            prediction_beta=args.prediction_beta,
            batch_norm=args.batch_norm,
            device=args.sample_mode,
        )
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            prediction_beta=args.prediction_beta,
            drop_probability=args.drop_probability,
            gamma=args.discount,
            device=args.sample_mode,
        )
    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = \
            args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model,
            )
        else:
            agent = AtariFfAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model,
            )

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint.
            gae_lambda=args.gae_lambda,
            minibatches=args.minibatches,  # Recurrent: batch_B must be at least this; non-recurrent: batch_B * batch_T must be at least this.
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            curiosity_type=args.curiosity_alg,
            policy_loss_type=args.policy_loss_type,
        )
    elif args.alg == 'a2c':
        algo = A2C(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,
            gae_lambda=args.gae_lambda,
            normalize_advantage=args.normalize_advantage,
        )

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #
    # Environment setup.
    traj_info_cl = TrajInfo  # Environment-specific; potentially overridden below.
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
        )
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps,
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000,
        )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset,
        )

    if args.sample_mode == 'gpu':
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )
    else:
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # Timesteps per sampler iteration.
            batch_B=args.num_envs,  # Environments distributed across workers.
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    if args.eval_envs > 0:
        RunnerCls = MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval
    else:
        RunnerCls = MinibatchRl if args.num_gpus <= 1 else SyncRl
    runner = RunnerCls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir, config, snapshot_mode="last",
            use_summary_writer=True):
        runner.train()
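# How start_experiment is typically driven (a sketch; the real project defines
# many more flags than the handful shown, so the ellipsis comment stands in
# for the rest):
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', default='breakout')
    parser.add_argument('--log_dir', default='./results/run_0')
    parser.add_argument('--sample_mode', default='cpu')
    parser.add_argument('--num_cpus', type=int, default=8)
    parser.add_argument('--num_gpus', type=int, default=0)
    parser.add_argument('--pretrain', default='None')
    parser.add_argument('--curiosity_alg', default='none')
    # ... remaining algorithm/sampler/env flags omitted here.
    start_experiment(parser.parse_args())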
def build_and_train(windx, windy, game="pong", run_ID=0, cuda_idx=None,
        sample_mode="serial", n_parallel=2, num_envs=2, eval=False,
        train_mask=[True, True], wandb_log=False, save_models_to_wandb=False,
        log_interval_steps=1e5, alt_train=False, n_steps=50e6,
        max_episode_length=np.inf, b_size=5, max_decor_steps=10):
    # def envs:
    # player_model_kwargs = dict(hidden_sizes=[32], lstm_size=16,
    #     nonlinearity=torch.nn.ReLU, normalize_observation=False,
    #     norm_obs_clip=10, norm_obs_var_clip=1e-6)
    # observer_model_kwargs = dict(hidden_sizes=[128], lstm_size=16,
    #     nonlinearity=torch.nn.ReLU, normalize_observation=False,
    #     norm_obs_clip=10, norm_obs_var_clip=1e-6)
    player_reward_shaping = None
    observer_reward_shaping = None
    window_size = np.asarray([windx, windy])
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        eval_collector_cl = SerialEvalCollector if eval else None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        eval_collector_cl = CpuEvalCollector if eval else None
        print(f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing.")
    env_kwargs = dict(env_name=game, window_size=window_size,
        player_reward_shaping=player_reward_shaping,
        observer_reward_shaping=observer_reward_shaping,
        max_episode_length=max_episode_length)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapperAtari,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )
    player_algo = A2C()
    observer_algo = A2C()
    player = AtariLstmAgent()  # model_kwargs=player_model_kwargs)
    observer = CWTO_AtariLstmAgent()  # model_kwargs=observer_model_kwargs)
    agent = CWTO_AgentWrapper(player, observer, alt=alt, train_mask=train_mask)
    RunnerCl = MinibatchRlEval if eval else MinibatchRl
    runner = RunnerCl(
        player_algo=player_algo,
        observer_algo=observer_algo,
        agent=agent,
        sampler=sampler,
        n_steps=n_steps,
        log_interval_steps=log_interval_steps,
        affinity=affinity,
        wandb_log=wandb_log,
        alt_train=alt_train,
    )
    config = dict(domain=game)
    log_dir = os.getcwd() + "/cwto_logs/" + game
    with logger_context(log_dir, run_ID, game, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
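# Example direct invocation (a sketch; CWTO_EnvWrapperAtari and the other
# CWTO_* classes are project-local, so this assumes they are importable):
if __name__ == "__main__":
    build_and_train(windx=3, windy=3, game="pong", run_ID=0, cuda_idx=None,
        sample_mode="serial", n_parallel=2, num_envs=2)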
def start_experiment(args):
    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        assert args.num_gpus > 0
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # Potentially reload models.
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete")  # Clean up json files for video recorder.
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg))
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
        )
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            ensemble_size=args.ensemble_size,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
            device=args.sample_mode,
        )
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            pred_horizon=args.pred_horizon,
            batch_norm=args.batch_norm,
            num_predictors=args.num_predictors,
            device=args.sample_mode,
        )
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            prediction_beta=args.prediction_beta,
            drop_probability=args.drop_probability,
            gamma=args.discount,
            device=args.sample_mode,
        )

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
            )
        else:
            agent = AtariFfAgent(initial_model_state_dict=initial_model_state_dict)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint.
            gae_lambda=args.gae_lambda,
            minibatches=args.minibatches,  # Recurrent: batch_B must be at least this; non-recurrent: batch_B * batch_T must be at least this.
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg,
        )
    elif args.alg == 'a2c':
        algo = A2C(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,
            gae_lambda=args.gae_lambda,
            normalize_advantage=args.normalize_advantage,
        )

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #
    # Environment setup.
    traj_info_cl = TrajInfo  # Environment-specific; potentially overridden below.
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
        )
    elif 'deepmind' in args.env.lower():  # Pycolab deepmind environments.
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            max_steps_per_episode=args.max_episode_steps,
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000,
        )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
        )

    if args.sample_mode == 'gpu':
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )
    else:
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # Timesteps per sampler iteration.
            batch_B=args.num_envs,  # Environments distributed across workers.
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class,
        )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    RunnerCls = MinibatchRlEval if args.eval_envs > 0 else MinibatchRl
    runner = RunnerCls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir, config, snapshot_mode="last",
            use_summary_writer=True):
        runner.train()