コード例 #1
0
ファイル: example_6a.py プロジェクト: wwxFromTju/rlpyt
def build_and_train(slot_affinity_code, log_dir, run_ID):
    """Build an A2C / Atari training run from a config plus variant and train.

    The base config defined below is merged with the variant loaded from
    ``log_dir`` (via ``load_variant`` / ``update_config``) before any
    component is constructed.

    Args:
        slot_affinity_code: Encoded affinity string, decoded by
            ``get_affinity``.
        log_dir: Directory the variant is loaded from.  Note that the
            directory handed to ``logger_context`` is the fixed string
            ``"example_6"``, not this argument.
        run_ID: Run identifier forwarded to ``logger_context``.
    """
    # (Or load from a central store of configs.)
    config = dict(
        env=dict(game="pong"),
        algo=dict(learning_rate=7e-4),
        sampler=dict(batch_B=16),
    )

    affinity = get_affinity(slot_affinity_code)
    variant = load_variant(log_dir)
    # ``config`` is a plain local here.  A ``global config`` declaration after
    # the assignment above is a SyntaxError ("assigned to before global
    # declaration"), so none is used.
    config = update_config(config, variant)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        batch_T=5,
        # batch_B=16,  # Get from config.
        max_decorrelation_steps=400,
        **config["sampler"])
    algo = A2C(**config["algo"])  # Hyperparameters come from the merged config.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    name = "a2c_" + config["env"]["game"]
    # The incoming log_dir was only needed to locate the variant; logging
    # itself goes to this fixed directory name.
    log_dir = "example_6"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
コード例 #2
0
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Assemble a SyncRl A2C run on Atari from a named config and train it.

    Args:
        slot_affinity_code: Encoded affinity string; ``affinity_from_code``
            must decode it to a list.
        log_dir: Directory the variant is loaded from and logs are written to.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key into the module-level ``configs`` mapping.
    """
    affinity = affinity_from_code(slot_affinity_code)
    assert isinstance(affinity, list)  # One for each GPU.

    # Base config selected by key, then overridden by the stored variant.
    cfg = update_config(configs[config_key], load_variant(log_dir))

    gpu_sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=cfg["env"],
        CollectorCls=GpuWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **cfg["sampler"]
    )
    a2c = A2C(optim_kwargs=cfg["optim"], **cfg["algo"])
    ff_agent = AtariFfAgent(model_kwargs=cfg["model"], **cfg["agent"])
    sync_runner = SyncRl(
        algo=a2c,
        agent=ff_agent,
        sampler=gpu_sampler,
        affinity=affinity,
        **cfg["runner"]
    )

    run_name = cfg["env"]["game"]
    with logger_context(log_dir, run_ID, run_name, cfg):
        sync_runner.train()
コード例 #3
0
ファイル: atari_ff_a2c_gpu.py プロジェクト: wwxFromTju/rlpyt
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Assemble a MinibatchRl A2C run on Atari from a named config and train.

    Args:
        slot_affinity_code: Encoded affinity string, decoded by
            ``get_affinity``.
        log_dir: Directory the variant is loaded from and logs are written to.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key into the module-level ``configs`` mapping.
    """
    affinity = get_affinity(slot_affinity_code)

    # Base config selected by key, then overridden by the stored variant.
    cfg = update_config(configs[config_key], load_variant(log_dir))

    parallel_sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=cfg["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **cfg["sampler"]
    )
    a2c = A2C(optim_kwargs=cfg["optim"], **cfg["algo"])
    ff_agent = AtariFfAgent(model_kwargs=cfg["model"], **cfg["agent"])
    trainer = MinibatchRl(
        algo=a2c,
        agent=ff_agent,
        sampler=parallel_sampler,
        affinity=affinity,
        **cfg["runner"]
    )

    run_name = cfg["env"]["game"]
    with logger_context(log_dir, run_ID, run_name, cfg):
        trainer.train()
コード例 #4
0
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Assemble a CPU-sampled A2C run on Atari from a named config and train.

    Unlike sibling launchers, the config is used exactly as stored under
    ``config_key``; no variant from ``log_dir`` is merged in.

    Args:
        slot_affinity_code: Encoded affinity string, decoded by
            ``affinity_from_code``.
        log_dir: Directory handed to ``logger_context``.
        run_ID: Run identifier forwarded to ``logger_context``.
        config_key: Key into the module-level ``configs`` mapping.
    """
    affinity = affinity_from_code(slot_affinity_code)
    cfg = configs[config_key]

    cpu_sampler = CpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=cfg["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **cfg["sampler"]
    )
    a2c = A2C(optim_kwargs=cfg["optim"], **cfg["algo"])
    ff_agent = AtariFfAgent(model_kwargs=cfg["model"], **cfg["agent"])
    trainer = MinibatchRl(
        algo=a2c,
        agent=ff_agent,
        sampler=cpu_sampler,
        affinity=affinity,
        **cfg["runner"]
    )

    run_name = cfg["env"]["game"]
    # Might have to flatten config before handing it to the logger.
    with logger_context(log_dir, run_ID, run_name, cfg):
        trainer.train()
コード例 #5
0
def build_and_train(level="nav_maze_random_goal_01", run_ID=0, cuda_idx=None):
    """Train PPO on a DeepMind Lab level with a serial sampler and evaluation.

    Args:
        level: Level name passed to the environment for both training and
            evaluation.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the affinity dict.
    """
    level_kwargs = dict(level=level)
    affinity = {"cuda_idx": cuda_idx, "workers_cpus": [*range(8)]}

    serial_sampler = SerialSampler(
        EnvCls=DeepmindLabEnv,
        env_kwargs=dict(level_kwargs),
        eval_env_kwargs=dict(level_kwargs),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=5,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    ppo = PPO()
    ff_agent = AtariFfAgent()
    trainer = MinibatchRlEval(
        algo=ppo,
        agent=ff_agent,
        sampler=serial_sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )

    # Logged config is a fresh dict holding just the level name.
    with logger_context("lab_example_3", run_ID, "lab_ppo", dict(level_kwargs),
                        snapshot_mode="last"):
        trainer.train()
コード例 #6
0
ファイル: example_3.py プロジェクト: kevinghst/rl_ul
def build_and_train(game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2):
    """Train A2C on an Atari game with a selectable sampler arrangement.

    Args:
        game: Atari game name passed to the environment.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the affinity dict;
            ``None`` is reported as "CPU" in the printed message.
        sample_mode: One of ``"serial"``, ``"cpu"``, ``"gpu"`` or
            ``"alternating"``; selects the sampler class.
        n_parallel: Number of worker CPU indices placed in the affinity
            (doubled in ``"alternating"`` mode).

    Raises:
        ValueError: If ``sample_mode`` is not one of the supported values.
    """
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )
    else:
        # Without this branch ``Sampler`` would stay unbound and the failure
        # would surface below as an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown sample_mode {sample_mode!r}; expected 'serial', 'cpu', "
            "'gpu' or 'alternating'.")

    sampler = Sampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        batch_T=5,  # 5 time-steps per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_3"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
コード例 #7
0
def build_and_train(game="pong", run_ID=0):
    """Train A2C on an Atari game with a multi-GPU runner.

    The affinity is produced by encoding a hardware description, prepending
    run slot 0, and decoding the result with ``get_affinity``.

    Args:
        game: Atari game name passed to the environment.
        run_ID: Run identifier forwarded to ``logger_context``.
    """
    # Seems like we should be able to skip the intermediate step of the code,
    # but so far have just always run that way.
    # Change these inputs to match local machine and desired parallelism.
    affinity_code = encode_affinity(
        n_cpu_cores=16,  # Use 16 cores across all experiments.
        n_gpu=8,  # Use 8 gpus across all experiments.
        hyperthread_offset=24,  # If machine has 24 cores.
        n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )
    slot_affinity_code = prepend_run_slot(run_slot=0, affinity_code=affinity_code)
    affinity = get_affinity(slot_affinity_code)
    # No debugger trap here: a stray breakpoint would stall every
    # non-interactive run before training starts.

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        CollectorCls=WaitResetCollector,
        batch_T=5,
        batch_B=16,
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariFfAgent()
    runner = MultiGpuRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_7"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
コード例 #8
0
ファイル: baseline.py プロジェクト: lowellw6/intrinsic-rl
def build_and_train(game="breakout",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2):
    """Train a baseline PPO agent on an Atari game.

    Args:
        game: Atari game name passed to the environment.
        run_ID: Run identifier forwarded to ``logger_context``.
        cuda_idx: Value stored under ``cuda_idx`` in the affinity dict;
            ``None`` is reported as "CPU" in the printed message.
        sample_mode: ``"serial"`` or ``"gpu"``; selects the sampler class.
        n_parallel: Number of worker CPU indices placed in the affinity.

    Raises:
        ValueError: If ``sample_mode`` is neither ``"serial"`` nor ``"gpu"``.
    """
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    else:
        # Without this branch ``Sampler`` would stay unbound and the failure
        # would surface below as an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown sample_mode {sample_mode!r}; expected 'serial' or 'gpu'.")

    env_kwargs = dict(game=game,
                      repeat_action_probability=0.25,
                      horizon=int(18e3))

    sampler = Sampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=env_kwargs,
        batch_T=128,
        batch_B=1,
        max_decorrelation_steps=0)

    algo = PPO(minibatches=4,
               epochs=4,
               entropy_loss_coeff=0.001,
               learning_rate=0.0001,
               gae_lambda=0.95,
               discount=0.999)

    base_model_kwargs = dict(  # Same front-end architecture as RND model, different fc kwarg name
        channels=[32, 64, 64],
        kernel_sizes=[8, 4, 4],
        strides=[(4, 4), (2, 2), (1, 1)],
        paddings=[0, 0, 0],
        fc_sizes=[512]
        # Automatically applies nonlinearity=torch.nn.ReLU in this case,
        # but can't specify due to rlpyt limitations
    )
    agent = AtariFfAgent(model_kwargs=base_model_kwargs)

    # 49152e4 == 30_000 * 128 * 128, i.e. 30k rollouts per environment at
    # (T, B) = (128, 128).
    # NOTE(review): the sampler above is built with batch_B=1, not 128, so
    # this step budget does not match the sampler as configured - confirm
    # which value is intended.
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(49152e4),
        log_interval_steps=int(1e3),
        affinity=affinity)

    config = dict(game=game)
    name = "ppo_" + game
    log_dir = "baseline"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
コード例 #9
0
def start_experiment(args):
    """Configure and launch one training run from parsed arguments.

    In order: validates ``args.alg``; writes ``arguments.json`` and
    ``git.txt`` into ``args.log_dir``; builds the affinity; optionally loads
    a checkpoint; builds agent, algorithm, sampler and runner; then calls
    ``runner.train()`` inside ``logger_context``.

    Args:
        args: Argument namespace (must support ``vars``) carrying the
            attributes read below.

    Raises:
        ValueError: If ``args.alg`` is not ``'ppo'`` or ``'a2c'``, or if
            ``args.env`` matches none of the supported environment families.
        subprocess.CalledProcessError: If either git query exits non-zero.
    """
    # Fail fast; otherwise ``algo`` would be left unbound and only surface
    # as an UnboundLocalError when the runner is constructed.
    if args.alg not in ('ppo', 'a2c'):
        raise ValueError(
            f"Unsupported alg {args.alg!r}; expected 'ppo' or 'a2c'.")

    args_json = json.dumps(vars(args), indent=4)
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'arguments.json'), 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(os.path.join(args.log_dir, 'git.txt'), 'w') as git_file:
        branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        if args.num_gpus > 0:
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,
                n_gpu=args.num_gpus,
                # contexts_per_gpu=2,
                # hyperthread_offset=72,
                # n_socket=2,
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.
                # cpu_per_run=1,
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        # Clean up json files for the video recorder.  Done in-process rather
        # than through a shell string so log dirs containing spaces or shell
        # metacharacters are handled correctly.
        # NOTE: this also removes the arguments.json written above.
        for dirpath, _dirnames, filenames in os.walk(args.log_dir):
            for filename in filenames:
                if filename.endswith('.json'):
                    os.remove(os.path.join(dirpath, filename))
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    # Curiosity kwargs forwarded per algorithm, in order.  Each kwarg is read
    # from the same-named attribute of ``args`` unless renamed below.
    curiosity_fields = {
        'icm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                'forward_loss_wt', 'forward_model', 'feature_space'),
        'micm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                 'forward_loss_wt', 'forward_model', 'ensemble_mode',
                 'device'),
        'disagreement': ('feature_encoding', 'ensemble_size', 'batch_norm',
                         'prediction_beta', 'forward_loss_wt', 'device',
                         'forward_model'),
        'ndigo': ('feature_encoding', 'pred_horizon', 'prediction_beta',
                  'batch_norm', 'device'),
        'rnd': ('feature_encoding', 'prediction_beta', 'drop_probability',
                'gamma', 'device'),
    }
    renamed_sources = {'device': 'sample_mode', 'gamma': 'discount'}

    curiosity_kwargs = dict(curiosity_alg=args.curiosity_alg)
    for key in curiosity_fields.get(args.curiosity_alg, ()):
        curiosity_kwargs[key] = getattr(args, renamed_sources.get(key, key))
    model_args = dict(curiosity_kwargs=curiosity_kwargs, curiosity_step_kwargs=dict())

    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        agent_cls = MujocoLstmAgent if args.lstm else MujocoFfAgent
        agent = agent_cls(initial_model_state_dict=initial_model_state_dict)
    else:
        agent_cls = AtariLstmAgent if args.lstm else AtariFfAgent
        agent = agent_cls(
            initial_model_state_dict=initial_model_state_dict,
            model_kwargs=model_args,
            no_extrinsic=args.no_extrinsic,
            dual_model=args.dual_model,
        )

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    # Hyperparameters shared by both algorithms.
    algo_kwargs = dict(
        discount=args.discount,
        learning_rate=args.lr,
        value_loss_coeff=args.v_loss_coeff,
        entropy_loss_coeff=args.entropy_loss_coeff,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        clip_grad_norm=args.grad_norm_bound,
        initial_optim_state_dict=initial_optim_state_dict,  # None when not reloading a checkpoint
        gae_lambda=args.gae_lambda,
        normalize_advantage=args.normalize_advantage,
    )
    if args.alg == 'ppo':
        algo = PPO(
            # if recurrent: batch_B needs to be at least equal, if not
            # recurrent: batch_B*batch_T needs to be at least equal to this
            minibatches=args.minibatches,
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_reward=args.normalize_reward,
            curiosity_type=args.curiosity_alg,
            policy_loss_type=args.policy_loss_type,
            **algo_kwargs
        )
    else:  # 'a2c' (validated at the top)
        algo = A2C(**algo_kwargs)

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    shared_env_args = dict(
        no_extrinsic=args.no_extrinsic,
        no_negative_reward=args.no_negative_reward,
        normalize_obs=args.normalize_obs,
        normalize_obs_steps=10000,
    )
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env, **shared_env_args)
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            **shared_env_args,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps
        )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        # Observation normalisation is always off for this family.
        env_args = {'id': args.env, **shared_env_args, 'normalize_obs': False}
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            **shared_env_args,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset
        )
    else:
        # Previously ``env_cl`` was left unbound and the sampler construction
        # failed with an UnboundLocalError.
        raise ValueError(f"Unsupported environment {args.env!r}.")

    if args.sample_mode == 'gpu':
        sampler_cls = GpuSampler
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
    else:
        sampler_cls = CpuSampler
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
    sampler = sampler_cls(
        EnvCls=env_cl,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,
        batch_T=args.timestep_limit,  # timesteps in a trajectory episode
        batch_B=args.num_envs,  # environments distributed across workers
        max_decorrelation_steps=0,
        TrajInfoCls=traj_info_cl,
        eval_n_envs=args.eval_envs,
        eval_max_steps=args.eval_max_steps,
        eval_max_trajectories=args.eval_max_traj,
        record_freq=args.record_freq,
        log_dir=args.log_dir,
        CollectorCls=collector_class
    )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    single_gpu = args.num_gpus <= 1
    if args.eval_envs > 0:
        runner_cls = MinibatchRlEval if single_gpu else SyncRlEval
    else:
        runner_cls = MinibatchRl if single_gpu else SyncRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.iterations,
        affinity=affinity,
        log_interval_steps=args.log_interval,
        log_dir=args.log_dir,
        pretrain=args.pretrain
    )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
コード例 #10
0
def start_experiment(args):
    """Configure and launch one training run from parsed arguments.

    In order: validates ``args.alg``; writes ``arguments.json`` into
    ``args.log_dir``; builds the affinity; optionally loads a checkpoint;
    builds agent, algorithm, sampler and runner; then calls
    ``runner.train()`` inside ``logger_context``.

    Args:
        args: Argument namespace (must support ``vars``) carrying the
            attributes read below.

    Raises:
        ValueError: If ``args.alg`` is not ``'ppo'`` or ``'a2c'``, if GPU
            sampling is requested with ``args.num_gpus <= 0``, or if
            ``args.env`` matches none of the supported environment families.
    """
    # Fail fast; otherwise ``algo`` would be left unbound and only surface
    # as an UnboundLocalError when the runner is constructed.
    if args.alg not in ('ppo', 'a2c'):
        raise ValueError(
            f"Unsupported alg {args.alg!r}; expected 'ppo' or 'a2c'.")

    args_json = json.dumps(vars(args), indent=4)
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'arguments.json'), 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if args.num_gpus <= 0:
            raise ValueError(
                "sample_mode 'gpu' requires num_gpus > 0, got "
                f"{args.num_gpus!r}.")
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        # Clean up json files for the video recorder.  Done in-process rather
        # than through a shell string so log dirs containing spaces or shell
        # metacharacters are handled correctly.
        # NOTE: this also removes the arguments.json written above.
        for dirpath, _dirnames, filenames in os.walk(args.log_dir):
            for filename in filenames:
                if filename.endswith('.json'):
                    os.remove(os.path.join(dirpath, filename))
        checkpoint = torch.load(
            os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    # Curiosity kwargs forwarded per algorithm, in order.  Each kwarg is read
    # from the same-named attribute of ``args`` unless renamed below.
    curiosity_fields = {
        'icm': ('feature_encoding', 'batch_norm', 'prediction_beta',
                'forward_loss_wt'),
        'disagreement': ('feature_encoding', 'ensemble_size', 'batch_norm',
                         'prediction_beta', 'forward_loss_wt', 'device'),
        'ndigo': ('feature_encoding', 'pred_horizon', 'batch_norm',
                  'num_predictors', 'device'),
        'rnd': ('feature_encoding', 'prediction_beta', 'drop_probability',
                'gamma', 'device'),
    }
    renamed_sources = {'device': 'sample_mode', 'gamma': 'discount'}

    curiosity_kwargs = dict(curiosity_alg=args.curiosity_alg)
    for key in curiosity_fields.get(args.curiosity_alg, ()):
        curiosity_kwargs[key] = getattr(args, renamed_sources.get(key, key))
    model_args = dict(curiosity_kwargs=curiosity_kwargs)

    if args.env in _MUJOCO_ENVS:
        agent_cls = MujocoLstmAgent if args.lstm else MujocoFfAgent
        agent = agent_cls(initial_model_state_dict=initial_model_state_dict)
    elif args.lstm:
        agent = AtariLstmAgent(
            initial_model_state_dict=initial_model_state_dict,
            model_kwargs=model_args,
            no_extrinsic=args.no_extrinsic)
    else:
        # NOTE(review): unlike the LSTM agent, the feed-forward agent is not
        # given ``model_args`` - confirm whether curiosity settings are meant
        # to be ignored on this path.
        agent = AtariFfAgent(
            initial_model_state_dict=initial_model_state_dict)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    # Hyperparameters shared by both algorithms.
    algo_kwargs = dict(
        discount=args.discount,
        learning_rate=args.lr,
        value_loss_coeff=args.v_loss_coeff,
        entropy_loss_coeff=args.entropy_loss_coeff,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        clip_grad_norm=args.grad_norm_bound,
        initial_optim_state_dict=initial_optim_state_dict,  # None when not reloading a checkpoint
        gae_lambda=args.gae_lambda,
        normalize_advantage=args.normalize_advantage,
    )
    if args.alg == 'ppo':
        # A kernel_mu of exactly 0 means no kernel parameters are passed.
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            # if recurrent: batch_B needs to be at least equal, if not
            # recurrent: batch_B*batch_T needs to be at least equal to this
            minibatches=args.minibatches,
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg,
            **algo_kwargs)
    else:  # 'a2c' (validated at the top)
        algo = A2C(**algo_kwargs)

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    shared_env_args = dict(
        no_extrinsic=args.no_extrinsic,
        no_negative_reward=args.no_negative_reward,
        normalize_obs=args.normalize_obs,
        normalize_obs_steps=10000,
    )
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env, **shared_env_args)
    elif 'deepmind' in args.env.lower():  # pycolab deepmind environments
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(game=args.env,
                        **shared_env_args,
                        log_heatmaps=args.log_heatmaps,
                        logdir=args.log_dir,
                        obs_type=args.obs_type,
                        max_steps_per_episode=args.max_episode_steps)
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        # Observation normalisation is always off for this family.
        env_args = {'id': args.env, **shared_env_args, 'normalize_obs': False}
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            **shared_env_args,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
        )
    else:
        # Previously ``env_cl`` was left unbound and the sampler construction
        # failed with an UnboundLocalError.
        raise ValueError(f"Unsupported environment {args.env!r}.")

    if args.sample_mode == 'gpu':
        sampler_cls = GpuSampler
        collector_class = GpuWaitResetCollector if args.lstm else GpuResetCollector
    else:
        sampler_cls = CpuSampler
        collector_class = CpuWaitResetCollector if args.lstm else CpuResetCollector
    sampler = sampler_cls(
        EnvCls=env_cl,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,
        batch_T=args.timestep_limit,  # timesteps in a trajectory episode
        batch_B=args.num_envs,  # environments distributed across workers
        max_decorrelation_steps=0,
        TrajInfoCls=traj_info_cl,
        eval_n_envs=args.eval_envs,
        eval_max_steps=args.eval_max_steps,
        eval_max_trajectories=args.eval_max_traj,
        record_freq=args.record_freq,
        log_dir=args.log_dir,
        CollectorCls=collector_class)

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    runner_cls = MinibatchRlEval if args.eval_envs > 0 else MinibatchRl
    runner = runner_cls(algo=algo,
                        agent=agent,
                        sampler=sampler,
                        n_steps=args.iterations,
                        affinity=affinity,
                        log_interval_steps=args.log_interval,
                        log_dir=args.log_dir,
                        pretrain=args.pretrain)

    with logger_context(args.log_dir,
                        config,
                        snapshot_mode="last",
                        use_summary_writer=True):
        runner.train()