Example #1
0
def build_and_train(game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    mid_batch_reset=False,
                    n_parallel=2):
    """Train an A2C agent with an LSTM Atari model using a GPU parallel sampler.

    Args:
        game: Atari game name; passed to the environment, stored in the
            logged config and appended to the run name (``"a2c_" + game``).
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the affinity dict.
        mid_batch_reset: When truthy ``GpuResetCollector`` is used, otherwise
            ``GpuWaitResetCollector``.
        n_parallel: The affinity lists worker CPUs ``0 .. n_parallel - 1``.
    """
    worker_cpus = list(range(n_parallel))
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=worker_cpus)

    if mid_batch_reset:
        Collector = GpuResetCollector
    else:
        Collector = GpuWaitResetCollector
    print(f"To satisfy mid_batch_reset=={mid_batch_reset}, using {Collector}.")

    # One image per observation: the agent learns on individual frames.
    env_kwargs = dict(game=game, num_img_obs=1)
    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=env_kwargs,
        CollectorCls=Collector,
        batch_T=20,  # Longer sampling/optimization horizon for recurrence.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )

    # Algorithm and agent both run with their default settings.
    algo = A2C()
    agent = AtariLstmAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )

    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_4"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #2
0
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Run PPO with an LSTM Atari agent from a named config plus a variant.

    Args:
        slot_affinity_code: Passed to ``affinity_from_code`` to obtain the
            affinity given to the runner.
        log_dir: Directory the variant is loaded from (``load_variant``) and
            that is handed to ``logger_context``.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key selecting the base configuration in ``configs``.
    """
    affinity = affinity_from_code(slot_affinity_code)

    # Base config selected by key, then updated with the variant from log_dir.
    base_config = configs[config_key]
    config = update_config(base_config, load_variant(log_dir))

    sampler = GpuSampler(EnvCls=AtariEnv,
                         env_kwargs=config["env"],
                         CollectorCls=WaitResetCollector,
                         TrajInfoCls=AtariTrajInfo,
                         **config["sampler"])
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])

    # The run is named after the configured game.
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Run A2C with an LSTM Atari agent on a CPU sampler from a named config.

    Unlike the variant-based launchers, the selected entry of ``configs`` is
    used as-is; no variant is loaded from ``log_dir``.

    Args:
        slot_affinity_code: Passed to ``affinity_from_code`` to obtain the
            affinity given to the runner.
        log_dir: Directory handed to ``logger_context``.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key selecting the configuration in ``configs``.
    """
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]

    sampler = CpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )

    # Run name: game name immediately followed by the entropy coefficient.
    entropy_coeff = config["algo"]["entropy_loss_coeff"]
    name = config["env"]["game"] + str(entropy_coeff)
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #4
0
def build_and_train(game="academy_empty_goal_close", run_ID=1, cuda_idx=None):
    """Train PPO on a football environment with a serial sampler and a coach.

    Most settings are read from ``args``, a name that is not defined inside
    this function and must be available in the enclosing scope.

    Args:
        game: Scenario name passed to both the training and the evaluation
            environments.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the runner's affinity.
    """
    env_vector_size = args.envVectorSize
    coach = Coach(
        envOptions=args.envOptions,
        vectorSize=env_vector_size,
        algo='Bandit',
        initialQ=args.initialQ,
        beta=args.beta,
    )

    game_kwargs = dict(game=game)
    sampler = SerialSampler(EnvCls=create_single_football_env,
                            env_kwargs=game_kwargs,
                            eval_env_kwargs=dict(game=game),
                            batch_T=5,  # Five time-steps per sampler iteration.
                            batch_B=env_vector_size,
                            max_decorrelation_steps=0,
                            eval_n_envs=args.evalNumOfEnvs,
                            eval_max_steps=int(10e3),
                            eval_max_trajectories=5,
                            coach=coach,
                            eval_env=args.evalEnv)

    algo = PPO(minibatches=1)  # Defaults except for a single minibatch.
    agent = AtariLstmAgent()  # TODO: move to ff
    runner = MinibatchRlEval(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             n_steps=args.numOfSteps,
                             log_interval_steps=1e3,
                             affinity=dict(cuda_idx=cuda_idx))

    name = args.name
    log_dir = "example_1"
    # Every attribute of args is logged as a parameter of the run.
    with logger_context(log_dir,
                        run_ID,
                        name,
                        log_params=vars(args),
                        snapshot_mode="last"):
        runner.train()
def start_experiment(args):
    """Set up and run one training experiment described by ``args``.

    In order, the function:
      1. creates ``args.log_dir`` and writes ``arguments.json`` (all of
         ``vars(args)``) and ``git.txt`` (``<branch>/<commit>``) into it;
      2. builds the CPU/GPU affinity;
      3. optionally loads optimizer and model state from a checkpoint;
      4. builds the agent, the learning algorithm, the sampler and the runner;
      5. calls ``runner.train()`` inside ``logger_context``.

    Args:
        args: Attribute container (for example a parsed argument namespace).
            ``vars(args)`` must be JSON-serializable. The attributes read are
            the ones referenced in the body; curiosity-specific attributes
            are only read for the selected ``args.curiosity_alg``.

    Raises:
        ValueError: If ``args.alg`` is neither ``'ppo'`` nor ``'a2c'``, or if
            ``args.env`` matches none of the environment families handled
            here. (Without this check the same inputs fail later with an
            unbound local variable.)
        subprocess.CalledProcessError: If one of the ``git rev-parse`` calls
            exits with a non-zero status.
    """
    args_json = json.dumps(vars(args), indent=4)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'arguments.json'), 'w') as jsonfile:
        jsonfile.write(args_json)

    # Query git before opening the file so that a failing git command does
    # not leave an empty git.txt behind.
    branch = subprocess.check_output(
        ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
    commit = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
    with open(os.path.join(args.log_dir, 'git.txt'), 'w') as git_file:
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)

    # ----------------------------------------------------- AFFINITY ----------------------------------------------------- #
    if args.sample_mode == 'gpu':
        if args.num_gpus > 0:
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # CPU cores available to the experiment.
                n_gpu=args.num_gpus,  # GPUs available to the experiment.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        # Clean up json files for the video recorder. The argument list is
        # passed without a shell, so a log_dir containing spaces or shell
        # metacharacters is handled as a single path. The exit status is
        # ignored, as it was with os.system.
        # NOTE(review): this also matches the arguments.json written above.
        subprocess.run(
            ['find', args.log_dir, '-name', '*.json', '-delete'], check=False)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(
            os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    # For each curiosity algorithm: the keyword arguments handed to the
    # model, in order. A key is read from the attribute of the same name on
    # args unless it is listed in renamed_sources.
    curiosity_arg_names = {
        'icm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                'forward_loss_wt', 'forward_model', 'feature_space'),
        'micm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                 'forward_loss_wt', 'forward_model', 'ensemble_mode',
                 'device'),
        'disagreement': ('feature_encoding', 'ensemble_size', 'batch_norm',
                         'prediction_beta', 'forward_loss_wt', 'device',
                         'forward_model'),
        'ndigo': ('feature_encoding', 'pred_horizon', 'prediction_beta',
                  'batch_norm', 'device'),
        'rnd': ('feature_encoding', 'prediction_beta', 'drop_probability',
                'gamma', 'device'),
    }
    renamed_sources = {'device': 'sample_mode', 'gamma': 'discount'}

    curiosity_kwargs = dict(curiosity_alg=args.curiosity_alg)
    for key in curiosity_arg_names.get(args.curiosity_alg, ()):
        curiosity_kwargs[key] = getattr(args, renamed_sources.get(key, key))

    curiosity_step_kwargs = dict()
    if args.curiosity_alg != 'none':
        curiosity_step_kwargs['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    model_args = dict(curiosity_kwargs=curiosity_kwargs,
                      curiosity_step_kwargs=curiosity_step_kwargs)

    if args.env in _MUJOCO_ENVS:
        # Mujoco agents only receive the (optional) initial weights.
        agent_cls = MujocoLstmAgent if args.lstm else MujocoFfAgent
        agent = agent_cls(initial_model_state_dict=initial_model_state_dict)
    else:
        agent_cls = AtariLstmAgent if args.lstm else AtariFfAgent
        agent = agent_cls(
            initial_model_state_dict=initial_model_state_dict,
            model_kwargs=model_args,
            no_extrinsic=args.no_extrinsic,
            dual_model=args.dual_model,
        )

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg not in ('ppo', 'a2c'):
        raise ValueError(
            f"Unsupported learning algorithm {args.alg!r}; expected 'ppo' or 'a2c'.")

    # Keyword arguments given to both PPO and A2C.
    shared_algo_kwargs = dict(
        discount=args.discount,
        learning_rate=args.lr,
        value_loss_coeff=args.v_loss_coeff,
        entropy_loss_coeff=args.entropy_loss_coeff,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        clip_grad_norm=args.grad_norm_bound,
        initial_optim_state_dict=initial_optim_state_dict,  # None unless a checkpoint was loaded
        gae_lambda=args.gae_lambda,
        normalize_advantage=args.normalize_advantage,
    )
    if args.alg == 'ppo':
        algo = PPO(
            # if recurrent: batch_B needs to be at least equal, if not
            # recurrent: batch_B*batch_T needs to be at least equal to this
            minibatches=args.minibatches,
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_reward=args.normalize_reward,
            curiosity_type=args.curiosity_alg,
            policy_loss_type=args.policy_loss_type,
            **shared_algo_kwargs
        )
    else:
        algo = A2C(**shared_algo_kwargs)

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    common_env_args = dict(
        no_extrinsic=args.no_extrinsic,
        no_negative_reward=args.no_negative_reward,
        normalize_obs=args.normalize_obs,
        normalize_obs_steps=10000,
    )
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env, **common_env_args)
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps,
            **common_env_args
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(id=args.env, **common_env_args)
        # Observation normalization is always off for Mujoco environments,
        # whatever args.normalize_obs says.
        env_args['normalize_obs'] = False
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset,
            **common_env_args
        )
    else:
        raise ValueError(f"Unsupported environment {args.env!r}.")

    # Recurrent agents get the "wait reset" collector, feed-forward agents
    # the plain reset collector; the sampler arguments are the same for the
    # GPU and the CPU sampler.
    if args.sample_mode == 'gpu':
        sampler_cls = GpuSampler
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
    else:
        sampler_cls = CpuSampler
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
    sampler = sampler_cls(
        EnvCls=env_cl,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,  # evaluation uses the training settings
        batch_T=args.timestep_limit,  # timesteps in a trajectory episode
        batch_B=args.num_envs,  # environments distributed across workers
        max_decorrelation_steps=0,
        TrajInfoCls=traj_info_cl,
        eval_n_envs=args.eval_envs,
        eval_max_steps=args.eval_max_steps,
        eval_max_trajectories=args.eval_max_traj,
        record_freq=args.record_freq,
        log_dir=args.log_dir,
        CollectorCls=collector_class,
    )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    # Evaluating runner when evaluation environments are requested; the
    # Sync* runners are selected when more than one GPU is configured.
    if args.eval_envs > 0:
        runner_cls = MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval
    else:
        runner_cls = MinibatchRl if args.num_gpus <= 1 else SyncRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
Example #6
0
def build_and_train(windx,
                    windy,
                    game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    num_envs=2,
                    eval=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    alt_train=False,
                    n_steps=50e6,
                    max_episode_length=np.inf,
                    b_size=5,
                    max_decor_steps=10):
    """Train a player/observer agent pair with A2C on a wrapped Atari env.

    Args:
        windx: First component of the ``window_size`` array given to the env.
        windy: Second component of the ``window_size`` array given to the env.
        game: Passed to the environment as ``env_name``; also used for the
            log directory, the run name and the logged config.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the affinity dict;
            ``None`` is reported as "CPU" in the printed message.
        sample_mode: ``"serial"`` or ``"cpu"``; selects the sampler class.
        n_parallel: The affinity lists worker CPUs ``0 .. n_parallel - 1``.
        num_envs: Sampler ``batch_B``; also the number of evaluation
            environments when ``eval`` is truthy.
        eval: When truthy an evaluation collector, evaluation env settings
            and the evaluating runner are used.
        train_mask: Passed (as a fresh list) to ``CWTO_AgentWrapper``.
        wandb_log: Forwarded to the runner.
        save_models_to_wandb: When truthy ``agent.save_models_to_wandb()`` is
            called after training.
        log_interval_steps: Forwarded to the runner.
        alt_train: Forwarded to the runner.
        n_steps: Forwarded to the runner.
        max_episode_length: Forwarded to the environment.
        b_size: Sampler ``batch_T``.
        max_decor_steps: Sampler ``max_decorrelation_steps``.

    Raises:
        ValueError: If ``sample_mode`` is neither ``"serial"`` nor ``"cpu"``.
            (Without this check such a value fails later with an unbound
            local variable.)
    """
    # No reward shaping is configured for either agent.
    player_reward_shaping = None
    observer_reward_shaping = None
    window_size = np.asarray([windx, windy])

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        eval_collector_cl = SerialEvalCollector if eval else None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        eval_collector_cl = CpuEvalCollector if eval else None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    else:
        raise ValueError(
            f"Unsupported sample_mode {sample_mode!r}; expected 'serial' or 'cpu'.")

    env_kwargs = dict(env_name=game,
                      window_size=window_size,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      max_episode_length=max_episode_length)
    if eval:
        # Evaluation reuses the training env settings and environment count.
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapperAtari,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )

    player_algo = A2C()
    observer_algo = A2C()
    # Both agents are built with their default model settings.
    player = AtariLstmAgent()
    observer = CWTO_AgentLstm() if False else CWTO_AtariLstmAgent()
    # Hand over a copy so the shared default list of this function can never
    # be modified through the wrapper.
    agent = CWTO_AgentWrapper(player,
                              observer,
                              alt=alt,
                              train_mask=list(train_mask))

    RunnerCl = MinibatchRlEval if eval else MinibatchRl
    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    log_dir = os.getcwd() + "/cwto_logs/" + game
    with logger_context(log_dir, run_ID, game, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
def start_experiment(args):
    """Set up and run one training experiment described by ``args``.

    In order, the function:
      1. creates ``args.log_dir`` and writes ``arguments.json`` (all of
         ``vars(args)``) into it;
      2. builds the CPU/GPU affinity;
      3. optionally loads optimizer and model state from a checkpoint;
      4. builds the agent, the learning algorithm, the sampler and the runner;
      5. calls ``runner.train()`` inside ``logger_context``.

    Args:
        args: Attribute container (for example a parsed argument namespace).
            ``vars(args)`` must be JSON-serializable. The attributes read are
            the ones referenced in the body; curiosity-specific attributes
            are only read for the selected ``args.curiosity_alg``.

    Raises:
        AssertionError: If ``args.sample_mode`` is ``'gpu'`` while
            ``args.num_gpus`` is not positive.
        ValueError: If ``args.alg`` is neither ``'ppo'`` nor ``'a2c'``, or if
            ``args.env`` matches none of the environment families handled
            here. (Without this check the same inputs fail later with an
            unbound local variable.)
    """
    # Imported here because nothing else in this function needs it.
    import subprocess

    args_json = json.dumps(vars(args), indent=4)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'arguments.json'), 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    # ----------------------------------------------------- AFFINITY ----------------------------------------------------- #
    if args.sample_mode == 'gpu':
        assert args.num_gpus > 0
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        # Clean up json files for the video recorder. The argument list is
        # passed without a shell, so a log_dir containing spaces or shell
        # metacharacters is handled as a single path. The exit status is
        # ignored, as it was with os.system.
        # NOTE(review): this also matches the arguments.json written above.
        subprocess.run(
            ['find', args.log_dir, '-name', '*.json', '-delete'], check=False)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(
            os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    # For each curiosity algorithm: the keyword arguments handed to the
    # model, in order. A key is read from the attribute of the same name on
    # args unless it is listed in renamed_sources.
    curiosity_arg_names = {
        'icm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                'forward_loss_wt'),
        'disagreement': ('feature_encoding', 'ensemble_size', 'batch_norm',
                         'prediction_beta', 'forward_loss_wt', 'device'),
        'ndigo': ('feature_encoding', 'pred_horizon', 'batch_norm',
                  'num_predictors', 'device'),
        'rnd': ('feature_encoding', 'prediction_beta', 'drop_probability',
                'gamma', 'device'),
    }
    renamed_sources = {'device': 'sample_mode', 'gamma': 'discount'}

    curiosity_kwargs = dict(curiosity_alg=args.curiosity_alg)
    for key in curiosity_arg_names.get(args.curiosity_alg, ()):
        curiosity_kwargs[key] = getattr(args, renamed_sources.get(key, key))
    model_args = dict(curiosity_kwargs=curiosity_kwargs)

    if args.env in _MUJOCO_ENVS:
        # Mujoco agents only receive the (optional) initial weights.
        agent_cls = MujocoLstmAgent if args.lstm else MujocoFfAgent
        agent = agent_cls(initial_model_state_dict=initial_model_state_dict)
    elif args.lstm:
        agent = AtariLstmAgent(
            initial_model_state_dict=initial_model_state_dict,
            model_kwargs=model_args,
            no_extrinsic=args.no_extrinsic)
    else:
        # The feed-forward Atari agent is built without the curiosity
        # model arguments.
        agent = AtariFfAgent(
            initial_model_state_dict=initial_model_state_dict)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg not in ('ppo', 'a2c'):
        raise ValueError(
            f"Unsupported learning algorithm {args.alg!r}; expected 'ppo' or 'a2c'.")

    # Keyword arguments given to both PPO and A2C.
    shared_algo_kwargs = dict(
        discount=args.discount,
        learning_rate=args.lr,
        value_loss_coeff=args.v_loss_coeff,
        entropy_loss_coeff=args.entropy_loss_coeff,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        clip_grad_norm=args.grad_norm_bound,
        initial_optim_state_dict=initial_optim_state_dict,  # None unless a checkpoint was loaded
        gae_lambda=args.gae_lambda,
        normalize_advantage=args.normalize_advantage,
    )
    if args.alg == 'ppo':
        # A kernel mean of exactly zero means no kernel parameters are passed.
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            # if recurrent: batch_B needs to be at least equal, if not
            # recurrent: batch_B*batch_T needs to be at least equal to this
            minibatches=args.minibatches,
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg,
            **shared_algo_kwargs
        )
    else:
        algo = A2C(**shared_algo_kwargs)

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    common_env_args = dict(
        no_extrinsic=args.no_extrinsic,
        no_negative_reward=args.no_negative_reward,
        normalize_obs=args.normalize_obs,
        normalize_obs_steps=10000,
    )
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env, **common_env_args)
    elif 'deepmind' in args.env.lower():  # pycolab deepmind environments
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            max_steps_per_episode=args.max_episode_steps,
            **common_env_args
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(id=args.env, **common_env_args)
        # Observation normalization is always off for Mujoco environments,
        # whatever args.normalize_obs says.
        env_args['normalize_obs'] = False
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            **common_env_args
        )
    else:
        raise ValueError(f"Unsupported environment {args.env!r}.")

    # Recurrent agents get the "wait reset" collector, feed-forward agents
    # the plain reset collector; the sampler arguments are the same for the
    # GPU and the CPU sampler.
    if args.sample_mode == 'gpu':
        sampler_cls = GpuSampler
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
    else:
        sampler_cls = CpuSampler
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
    sampler = sampler_cls(
        EnvCls=env_cl,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,  # evaluation uses the training settings
        batch_T=args.timestep_limit,  # timesteps in a trajectory episode
        batch_B=args.num_envs,  # environments distributed across workers
        max_decorrelation_steps=0,
        TrajInfoCls=traj_info_cl,
        eval_n_envs=args.eval_envs,
        eval_max_steps=args.eval_max_steps,
        eval_max_trajectories=args.eval_max_traj,
        record_freq=args.record_freq,
        log_dir=args.log_dir,
        CollectorCls=collector_class,
    )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    # The evaluating runner is used when evaluation environments are requested.
    runner_cls = MinibatchRlEval if args.eval_envs > 0 else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain,
    )

    with logger_context(args.log_dir,
                        config,
                        snapshot_mode="last",
                        use_summary_writer=True):
        runner.train()