Example 1
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    assert isinstance(affinity, list)  # One for each GPU.
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=GpuWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = SyncRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
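Scripts like Example 1, which take (slot_affinity_code, log_dir, run_ID, config_key), are normally invoked by a separate launcher rather than run by hand. Below is a minimal launcher sketch, assuming rlpyt's standard launching utilities (encode_affinity, VariantLevel, make_variants, run_experiments); the script path, core/GPU counts, and game list are illustrative only, and the affinity keyword names vary slightly across rlpyt versions (e.g. n_cpu_core vs. n_cpu_cores).

from rlpyt.utils.launching.affinity import encode_affinity
from rlpyt.utils.launching.exp_launcher import run_experiments
from rlpyt.utils.launching.variant import make_variants, VariantLevel

# Hypothetical path to a train script like the one in Example 1.
script = "experiments/scripts/atari/pg/train/atari_ff_a2c_gpu.py"
affinity_code = encode_affinity(
    n_cpu_core=8,  # Cores shared across all concurrent runs.
    n_gpu=1,
    n_socket=1,
)

# One variant level: sweep over the Atari game; keys index into the config dict.
games = ["pong", "seaquest"]
values = list(zip(games))
dir_names = ["{}".format(*v) for v in values]
variant_levels = [VariantLevel(keys=[("env", "game")], values=values, dir_names=dir_names)]
variants, log_dirs = make_variants(*variant_levels)

run_experiments(
    script=script,
    affinity_code=affinity_code,
    experiment_title="atari_a2c_example",
    runs_per_setting=2,       # Two run_IDs per variant.
    variants=variants,
    log_dirs=log_dirs,
    common_args=("0",),       # config_key forwarded to build_and_train.
)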
Example 2
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = get_affinity(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
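Examples 1-4 (and several later ones) look up configs[config_key] in a shared config module and then overlay the variant loaded from log_dir via update_config. A representative sketch of such a module, assuming rlpyt's convention of one sub-dict per component; the hyperparameter values are illustrative, not tuned.

# Hypothetical configs module, e.g. .../configs/atari/pg/atari_a2c.py
configs = dict()

config = dict(
    agent=dict(),
    model=dict(),
    optim=dict(),
    algo=dict(
        discount=0.99,
        learning_rate=7e-4,
        value_loss_coeff=0.5,
        entropy_loss_coeff=0.01,
        clip_grad_norm=1.,
    ),
    env=dict(game="pong", episodic_lives=True),
    sampler=dict(
        batch_T=5,  # Time-steps per sampler iteration.
        batch_B=32,  # Parallel environments.
        max_decorrelation_steps=1000,
    ),
    runner=dict(
        n_steps=50e6,
        log_interval_steps=5e5,
    ),
)
configs["0"] = config  # Selected by the config_key command-line argument.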
Example 3
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    # variant = load_variant(log_dir)
    # config = update_config(config, variant)

    sampler = CpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):  # Might have to flatten config
        runner.train()
Example 4
def build_and_train(slot_affinity_code, log_dir, run_ID):
    # (Or load from a central store of configs.)
    config = dict(
        env=dict(game="pong"),
        algo=dict(learning_rate=7e-4),
        sampler=dict(batch_B=16),
    )

    affinity = get_affinity(slot_affinity_code)
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        batch_T=5,
        # batch_B=16,  # Get from config.
        max_decorrelation_steps=400,
        **config["sampler"])
    algo = A2C(**config["algo"])  # Run with defaults.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    name = "a2c_" + config["env"]["game"]
    log_dir = "example_6"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 5
def build_and_train(game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    mid_batch_reset=False,
                    n_parallel=2):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    Collector = GpuResetCollector if mid_batch_reset else GpuWaitResetCollector
    print(f"To satisfy mid_batch_reset=={mid_batch_reset}, using {Collector}.")

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game,
                        num_img_obs=1),  # Learn on individual frames.
        CollectorCls=Collector,
        batch_T=20,  # Longer sampling/optimization horizon for recurrence.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariLstmAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_4"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 6
def build_and_train(game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )

    sampler = Sampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        batch_T=5,  # 5 time-steps per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_3"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
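Examples 5 and 6 are standalone scripts; they are typically driven from the command line by a small argparse block such as the following sketch. The flag names simply mirror the keyword arguments of build_and_train in Example 6 above, and the defaults are illustrative.

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--game', help='Atari game', default='pong')
    parser.add_argument('--run_ID', help='run identifier (logging)', type=int, default=0)
    parser.add_argument('--cuda_idx', help='GPU to use (None for CPU)', type=int, default=None)
    parser.add_argument('--sample_mode', help='serial, cpu, gpu, or alternating', default='serial')
    parser.add_argument('--n_parallel', help='number of sampler worker CPUs', type=int, default=2)
    args = parser.parse_args()
    build_and_train(
        game=args.game,
        run_ID=args.run_ID,
        cuda_idx=args.cuda_idx,
        sample_mode=args.sample_mode,
        n_parallel=args.n_parallel,
    )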
Example 7
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = AlternatingSampler(EnvCls=gym_make,
                                 env_kwargs=config["env"],
                                 **config["sampler"])
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = MujocoFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 8
def build_and_train(game="pong", run_ID=0):
    # It may be possible to skip the intermediate affinity-code step,
    # but so far runs have always gone through it.
    # Change these inputs to match local machine and desired parallelism.
    affinity_code = encode_affinity(
        n_cpu_cores=16,  # Use 16 cores across all experiments.
        n_gpu=8,  # Use 8 gpus across all experiments.
        hyperthread_offset=24,  # If machine has 24 cores.
        n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )
    slot_affinity_code = prepend_run_slot(run_slot=0, affinity_code=affinity_code)
    affinity = get_affinity(slot_affinity_code)
    # breakpoint()  # (Debug stop to inspect the decoded affinity; keep commented for full runs.)

    sampler = GpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        CollectorCls=WaitResetCollector,
        batch_T=5,
        batch_B=16,
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariFfAgent()
    runner = MultiGpuRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_7"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 9
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["algo_name"] = 'A2C'
    t_env = pomdp_interface(**config["env"])  # Instantiate once just to read the env's discount.
    config["algo"]["discount"] = t_env.discount

    sampler = GpuSampler(EnvCls=pomdp_interface,
                         env_kwargs=config["env"],
                         **config["sampler"])
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = PomdpFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 10
def findOptimalAgent(reward, run_ID=0):
    """
    Find the optimal agent for the MDP (see Config for the
    specification) under a custom reward function, using
    rlpyt's implementation of A2C.
    """
    cpus = list(range(C.N_PARALLEL))
    affinity = dict(cuda_idx=C.CUDA_IDX, workers_cpus=cpus)
    sampler = SerialSampler(EnvCls=rlpyt_make,
                            env_kwargs=dict(id=C.ENV, reward=reward),
                            batch_T=C.BATCH_T,
                            batch_B=C.BATCH_B,
                            max_decorrelation_steps=400,
                            eval_env_kwargs=dict(id=C.ENV),
                            eval_n_envs=5,
                            eval_max_steps=2500)
    algo = A2C(discount=C.DISCOUNT,
               learning_rate=C.LR,
               value_loss_coeff=C.VALUE_LOSS_COEFF,
               entropy_loss_coeff=C.ENTROPY_LOSS_COEFF)
    agent = CategoricalPgAgent(AcrobotNet)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=C.N_STEPS,
        log_interval_steps=C.LOG_STEP,
        affinity=affinity,
    )
    name = "a2c_" + C.ENV.lower()
    log_dir = name
    with logger_context(log_dir,
                        run_ID,
                        name,
                        snapshot_mode='last',
                        override_prefix=True):
        runner.train()
    return agent
Example 11
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["algo_name"] = 'A2C_RNN'

    env = BatchPOMDPEnv(batch_B=config["sampler"]["batch_B"], **config["env"])
    config["algo"]["discount"] = env.discount
    sampler = BatchPOMDPSampler(env=env, **config["sampler"])

    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = PomdpRnnAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 12
def build_and_train(env_id="POMDP-hallway-episodic-v0",
                    run_ID=0,
                    cuda_idx=None,
                    n_parallel=6,
                    fomdp=False):
    EnvCls = BatchPOMDPEnv
    SamplerCls = BatchPOMDPSampler
    batch_B = 30
    batch_T = 100
    env_args = dict(fomdp=fomdp, id=env_id, time_limit=100, batch_B=batch_B)
    env = EnvCls(**env_args)
    gamma = env.discount
    affinity = dict(cuda_idx=cuda_idx,
                    workers_cpus=list(range(n_parallel)),
                    alternating=False)
    lr = 1e-3
    po = np.array([1, 0, 0, 1, 0], dtype=bool)
    # Model kwargs
    # model_kwargs = dict()
    # model_kwargs = dict(hidden_sizes=[64, 64], shared_processor=False)
    model_kwargs = dict(hidden_sizes=[64, 64],
                        rnn_type='gru',
                        rnn_size=256,
                        rnn_placement=1,
                        shared_processor=False,
                        layer_norm=True,
                        prev_action='All',
                        prev_reward='All')
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, shared_processor=False, use_interest=False, use_diversity=False, use_attention=False)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=True, use_diversity=False,
    #                     use_attention=False, rnn_type='gru', rnn_size=256, rnn_placement=1, shared_processor=False, layer_norm=True, prev_option=po)
    # Samplers
    sampler = SamplerCls(env, batch_T, max_decorrelation_steps=0)

    # Algos (swapping out discount)
    algo = A2C(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = A2OC(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = PPO(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = PPOC(discount=gamma, learning_rate=lr, clip_grad_norm=2.)

    # Agents
    # agent = PomdpFfAgent(model_kwargs=model_kwargs)
    agent = PomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcFfAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpOcRnnAgent(model_kwargs=model_kwargs)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(env_id=env_id,
                  fomdp=fomdp,
                  algo_name=algo.__class__.__name__,
                  learning_rate=lr,
                  sampler=sampler.__class__.__name__,
                  model=model_kwargs)
    name = algo.NAME + '_' + env_id
    log_dir = "pomdps"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 13
def build_and_train(windx,
                    windy,
                    game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    num_envs=2,
                    eval=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    alt_train=False,
                    n_steps=50e6,
                    max_episode_length=np.inf,
                    b_size=5,
                    max_decor_steps=10):
    # Commented-out model kwargs (agents currently run with model defaults):
    # player_model_kwargs = dict(hidden_sizes=[32], lstm_size=16, nonlinearity=torch.nn.ReLU, normalize_observation=False,
    #                            norm_obs_clip=10, norm_obs_var_clip=1e-6)
    # observer_model_kwargs = dict(hidden_sizes=[128], lstm_size=16, nonlinearity=torch.nn.ReLU,
    #                              normalize_observation=False, norm_obs_clip=10, norm_obs_var_clip=1e-6)
    player_reward_shaping = None
    observer_reward_shaping = None
    window_size = np.asarray([windx, windy])

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    else:
        raise ValueError(f"Unsupported sample_mode: {sample_mode}")
    env_kwargs = dict(env_name=game,
                      window_size=window_size,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      max_episode_length=max_episode_length)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapperAtari,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )

    player_algo = A2C()
    observer_algo = A2C()
    player = AtariLstmAgent()  #model_kwargs=player_model_kwargs)
    observer = CWTO_AtariLstmAgent()  #model_kwargs=observer_model_kwargs)
    agent = CWTO_AgentWrapper(player, observer, alt=alt, train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    log_dir = os.getcwd() + "/cwto_logs/" + game
    with logger_context(log_dir, run_ID, game, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
Example 14
def build_and_train(env_id="catch/0", run_ID=0, cuda_idx=None, n_parallel=6):
    EnvCls = BSuiteEnv
    n_episodes = 1e4
    env_args = dict(id=env_id)
    affinity = dict(cuda_idx=cuda_idx,
                    workers_cpus=list(range(n_parallel)),
                    alternating=True)
    lr = 1e-3

    # Model kwargs
    # model_kwargs = dict()
    # model_kwargs = dict(hidden_sizes=[64, 64])
    model_kwargs = dict(hidden_sizes=[64, 64],
                        rnn_type='gru',
                        rnn_size=256,
                        rnn_placement=1,
                        shared_processor=True,
                        layer_norm=True)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=False, use_diversity=False, use_attention=False)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=False, use_diversity=False,
    #                     use_attention=False, rnn_type='gru', rnn_size=128)

    # Samplers
    # sampler = AlternatingSampler(
    #     EnvCls=EnvCls,
    #     env_kwargs=env_args,
    #     eval_env_kwargs=env_args,
    #     batch_T=20,  # One time-step per sampler iteration.
    #     batch_B=30,  # One environment (i.e. sampler Batch dimension).
    #     max_decorrelation_steps=0,
    #     eval_n_envs=5,
    #     eval_max_steps=int(25e3),
    #     eval_max_trajectories=30
    # )
    #
    sampler = SerialSampler(
        EnvCls=EnvCls,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,
        batch_T=32,  # 32 time-steps per sampler iteration.
        batch_B=1,  # One environment (i.e. sampler Batch dimension).
        max_decorrelation_steps=0,
        # eval_n_envs=2,
        # eval_max_steps=int(51e2),
        # eval_max_trajectories=5,
    )

    # Algos (swapping out discount)
    algo = A2C(learning_rate=lr, clip_grad_norm=2.)
    # algo = A2OC(discount=gamma, learning_rate=lr, clip_grad_norm=2.)

    # Agents
    # agent = BsuiteFfAgent(model_kwargs=model_kwargs)
    agent = BsuiteRnnAgent(model_kwargs=model_kwargs)
    # agent = BsuiteOcFfAgent(model_kwargs=model_kwargs)
    # agent = BsuiteOcRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingBsuiteRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingBsuiteOcRnnAgent(model_kwargs=model_kwargs)
    runner = EpisodicRlRunner(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e8,
        n_episodes=1e4,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(env_id=env_id,
                  algo_name=algo.__class__.__name__,
                  learning_rate=lr)
    name = algo.NAME + '_' + env_id
    log_dir = "Bsuites"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 15
def build_and_train(env_id="POMDP-hallway-episodic-v0", run_ID=0, cuda_idx=None, n_parallel=1, fomdp=False):
    EnvCls = pomdp_interface
    env_args = dict(fomdp=fomdp, id=env_id, time_limit=100)
    test_instance = EnvCls(**env_args)
    gamma = test_instance.discount
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)), alternating=True)
    lr = 1e-3
    po = np.array([1, 0, 0, 1, 0], dtype=bool)
    # Model kwargs
    # model_kwargs = dict()
    # model_kwargs = dict(hidden_sizes=[64, 64], shared_processor=False)
    model_kwargs = dict(hidden_sizes=[64, 64],
                        rnn_type='gru',
                        rnn_size=256,
                        rnn_placement=1,
                        shared_processor=False,
                        layer_norm=True,
                        prev_action=3,
                        prev_reward=3)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, shared_processor=False, use_interest=False, use_diversity=False, use_attention=False)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=True, use_diversity=False,
    #                     use_attention=False, rnn_type='gru', rnn_size=256, rnn_placement=1, shared_processor=False, layer_norm=True, prev_option=po)

    # Samplers
    sampler = GpuSampler(
        EnvCls=EnvCls,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,
        batch_T=20,  # 20 time-steps per sampler iteration.
        batch_B=30,  # 30 parallel environments (sampler batch dimension).
        max_decorrelation_steps=0,
        eval_n_envs=5,
        eval_max_steps=int(25e3),
        eval_max_trajectories=30
    )
    # sampler = AlternatingSampler(
    #     EnvCls=EnvCls,
    #     env_kwargs=env_args,
    #     eval_env_kwargs=env_args,
    #     batch_T=20,  # One time-step per sampler iteration.
    #     batch_B=30,  # One environment (i.e. sampler Batch dimension).
    #     max_decorrelation_steps=0,
    #     eval_n_envs=5,
    #     eval_max_steps=int(25e3),
    #     eval_max_trajectories=30
    # )
    #
    # sampler = SerialSampler(
    #     EnvCls=EnvCls,
    #     env_kwargs=env_args,
    #     eval_env_kwargs=env_args,
    #     batch_T=20,  # One time-step per sampler iteration.
    #     batch_B=30,  # One environment (i.e. sampler Batch dimension).
    #     max_decorrelation_steps=0,
    #     # eval_n_envs=2,
    #     # eval_max_steps=int(51e2),
    #     # eval_max_trajectories=5,
    # )

    # Algos (swapping out discount)
    algo = A2C(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = A2OC(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = PPO(discount=gamma, learning_rate=lr, clip_grad_norm=2.)
    # algo = PPOC(discount=gamma, learning_rate=lr, clip_grad_norm=2.)

    # Agents
    # agent = PomdpFfAgent(model_kwargs=model_kwargs)
    agent = PomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcFfAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpOcRnnAgent(model_kwargs=model_kwargs)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(env_id=env_id,
                  fomdp=fomdp,
                  algo_name=algo.__class__.__name__,
                  learning_rate=lr,
                  sampler=sampler.__class__.__name__,
                  model=model_kwargs)
    name = algo.NAME + '_' + env_id
    log_dir = "pomdps"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 16
def build_and_train(env_id="Taxi-v3", run_ID=0, cuda_idx=None, n_parallel=6, fomdp=False):
    EnvCls = gym_make if fomdp else make_po_taxi
    env_args = dict(id=env_id)
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)), alternating=True)
    lr = 1e-3

    # Model kwargs
    # model_kwargs = dict()
    model_kwargs = dict(hidden_sizes=[64, 64])
    # model_kwargs = dict(hidden_sizes=[64, 64], rnn_type='gru', rnn_size=128)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=False, use_diversity=False, use_attention=False)
    # model_kwargs = dict(hidden_sizes=[64, 64], option_size=4, use_interest=False, use_diversity=False,
    #                     use_attention=False, rnn_type='gru', rnn_size=128)

    # Samplers
    # sampler = AlternatingSampler(
    #     EnvCls=EnvCls,
    #     env_kwargs=env_args,
    #     eval_env_kwargs=env_args,
    #     batch_T=20,  # One time-step per sampler iteration.
    #     batch_B=30,  # One environment (i.e. sampler Batch dimension).
    #     max_decorrelation_steps=0,
    #     eval_n_envs=5,
    #     eval_max_steps=int(25e3),
    #     eval_max_trajectories=30
    # )
    #
    sampler = SerialSampler(
        EnvCls=EnvCls,
        env_kwargs=env_args,
        eval_env_kwargs=env_args,
        batch_T=20,  # 20 time-steps per sampler iteration.
        batch_B=30,  # 30 parallel environments (sampler batch dimension).
        max_decorrelation_steps=0,
        # eval_n_envs=2,
        # eval_max_steps=int(51e2),
        # eval_max_trajectories=5,
    )

    # Algos (swapping out discount)
    algo = A2C(discount=0.9, learning_rate=lr, entropy_loss_coeff=0.01, normalize_rewards='reward')
    # algo = A2OC(discount=0.618, learning_rate=lr, clip_grad_norm=2.)

    # Agents
    agent = PomdpFfAgent(model_kwargs=model_kwargs)
    # agent = PomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcFfAgent(model_kwargs=model_kwargs)
    # agent = PomdpOcRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpRnnAgent(model_kwargs=model_kwargs)
    # agent = AlternatingPomdpOcRnnAgent(model_kwargs=model_kwargs)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(env_id=env_id, fomdp=fomdp, algo_name=algo.__class__.__name__, learning_rate=lr)
    name = algo.NAME + '_' + env_id
    log_dir = "pomdps"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 17
def start_experiment(args):

    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(args.log_dir + '/git.txt', 'w') as git_file:
        branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)
    
    if args.sample_mode == 'gpu':
        # affinity = dict(num_gpus=args.num_gpus, workers_cpus=list(range(args.num_cpus)))
        if args.num_gpus > 0:
            # import ipdb; ipdb.set_trace()
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # Use 16 cores across all experiments.
                n_gpu=args.num_gpus,  # Use 8 gpus across all experiments.
                # contexts_per_gpu=2,
                # hyperthread_offset=72,  # If machine has 24 cores.
                # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.

                # cpu_per_run=1,
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))
    
    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete") # clean up json files for video recorder
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg), curiosity_step_kwargs=dict())
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['feature_space'] = args.feature_space
    elif args.curiosity_alg == 'micm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['ensemble_mode'] = args.ensemble_mode
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    
    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:            
            agent = AtariLstmAgent(
                        initial_model_state_dict=initial_model_state_dict,
                        model_kwargs=model_args,
                        no_extrinsic=args.no_extrinsic,
                        dual_model=args.dual_model,
                        )
        else:
            agent = AtariFfAgent(initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        algo = PPO(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint
                gae_lambda=args.gae_lambda,
                minibatches=args.minibatches,  # if recurrent, batch_B must be >= this; otherwise batch_B*batch_T must be >= this
                epochs=args.epochs,
                ratio_clip=args.ratio_clip,
                linear_lr_schedule=args.linear_lr,
                normalize_advantage=args.normalize_advantage,
                normalize_reward=args.normalize_reward,
                curiosity_type=args.curiosity_alg,
                policy_loss_type=args.policy_loss_type
                )
    elif args.alg == 'a2c':
        algo = A2C(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict,
                gae_lambda=args.gae_lambda,
                normalize_advantage=args.normalize_advantage
                )

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment-specific; potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,  
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000
            )
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps
            )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000
            )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset
            )

    if args.sample_mode == 'gpu':
        if args.lstm:
            collector_class = GpuWaitResetCollector
        else:
            collector_class = GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
        )
    else:
        if args.lstm:
            collector_class = CpuWaitResetCollector
        else:
            collector_class = CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit, # timesteps in a trajectory episode
            batch_B=args.num_envs, # environments distributed across workers
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
            )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #     
    if args.eval_envs > 0:
        runner = (MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )
    else:
        runner = (MinibatchRl if args.num_gpus <= 1 else SyncRl)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
def build_and_train(env_id="GridEnv-v1",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    args={}):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )

    sampler = Sampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id, stochastic=True, p=0.15, size=(7, 7)),
        batch_T=5,  # 5 time-steps per sampler iteration.
        batch_B=16,
        max_decorrelation_steps=1000,
        eval_n_envs=0,
    )

    algo = A2C(learning_rate=args.lr)

    agentCls, agent_basis = get_agent_cls_grid(args.network)
    agent = agentCls(
        model_kwargs={
            'basis': agent_basis,
            'channels': args.channels,
            'kernel_sizes': args.filters,
            'paddings': args.paddings,
            'fc_sizes': args.fcs,
            'strides': args.strides
        })
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=2e6,
        log_interval_steps=10e3,
        affinity=affinity,
    )

    config = dict(env_id=env_id,
                  lr=args.lr,
                  network=args.network,
                  fcs=str(args.fcs),
                  channels=args.channels,
                  strides=args.strides,
                  paddings=args.paddings)
    name = f"{args.folder}_{args.network}"
    log_dir = f"{args.folder}_{args.network}"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example 19
def start_experiment(args):

    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        assert args.num_gpus > 0
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete"
                  )  # clean up json files for video recorder
        checkpoint = torch.load(
            os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg))
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs'][
            'feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs'][
            'prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs'][
            'forward_loss_wt'] = args.forward_loss_wt
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs'][
            'feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs'][
            'prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs'][
            'forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs'][
            'feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['num_predictors'] = args.num_predictors
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs'][
            'feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs'][
            'prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs'][
            'drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(
                initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(
                initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic)
        else:
            agent = AtariFfAgent(
                initial_model_state_dict=initial_model_state_dict)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            initial_optim_state_dict=initial_optim_state_dict,  # None if not reloading a checkpoint
            gae_lambda=args.gae_lambda,
            minibatches=args.minibatches,  # if recurrent, batch_B must be >= this; otherwise batch_B*batch_T must be >= this
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg)
    elif args.alg == 'a2c':
        algo = A2C(discount=args.discount,
                   learning_rate=args.lr,
                   value_loss_coeff=args.v_loss_coeff,
                   entropy_loss_coeff=args.entropy_loss_coeff,
                   OptimCls=torch.optim.Adam,
                   optim_kwargs=None,
                   clip_grad_norm=args.grad_norm_bound,
                   initial_optim_state_dict=initial_optim_state_dict,
                   gae_lambda=args.gae_lambda,
                   normalize_advantage=args.normalize_advantage)

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment-specific; potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=args.normalize_obs,
                        normalize_obs_steps=10000)
    elif 'deepmind' in args.env.lower():  # pycolab deepmind environments
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(game=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=args.normalize_obs,
                        normalize_obs_steps=10000,
                        log_heatmaps=args.log_heatmaps,
                        logdir=args.log_dir,
                        obs_type=args.obs_type,
                        max_steps_per_episode=args.max_episode_steps)
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(id=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=False,
                        normalize_obs_steps=10000)
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
        )

    if args.sample_mode == 'gpu':
        if args.lstm:
            collector_class = GpuWaitResetCollector
        else:
            collector_class = GpuResetCollector
        sampler = GpuSampler(EnvCls=env_cl,
                             env_kwargs=env_args,
                             eval_env_kwargs=env_args,
                             batch_T=args.timestep_limit,
                             batch_B=args.num_envs,
                             max_decorrelation_steps=0,
                             TrajInfoCls=traj_info_cl,
                             eval_n_envs=args.eval_envs,
                             eval_max_steps=args.eval_max_steps,
                             eval_max_trajectories=args.eval_max_traj,
                             record_freq=args.record_freq,
                             log_dir=args.log_dir,
                             CollectorCls=collector_class)
    else:
        if args.lstm:
            collector_class = CpuWaitResetCollector
        else:
            collector_class = CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # timesteps in a trajectory episode
            batch_B=args.num_envs,  # environments distributed across workers
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class)

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    if args.eval_envs > 0:
        runner = MinibatchRlEval(algo=algo,
                                 agent=agent,
                                 sampler=sampler,
                                 n_steps=args.iterations,
                                 affinity=affinity,
                                 log_interval_steps=args.log_interval,
                                 log_dir=args.log_dir,
                                 pretrain=args.pretrain)
    else:
        runner = MinibatchRl(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             n_steps=args.iterations,
                             affinity=affinity,
                             log_interval_steps=args.log_interval,
                             log_dir=args.log_dir,
                             pretrain=args.pretrain)

    with logger_context(args.log_dir,
                        config,
                        snapshot_mode="last",
                        use_summary_writer=True):
        runner.train()