Example #1
def build_and_train(env_id="CartPole-v0", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        eval_env_kwargs=dict(id=env_id),
        batch_T=1,  # One time-step per sampler iteration.
        batch_B=1,  # One environment (i.e. sampler Batch dimension).
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(51e3),
        eval_max_trajectories=50,
    )
    algo = PPO()  # Run with defaults.
    agent = RecurrentCategoricalPgAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e6,
        log_interval_steps=1e4,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(env_id=env_id)
    name = "ppo_" + env_id
    log_dir = "ppo_test"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
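The function above is typically launched from an `if __name__ == "__main__"` block; a minimal sketch of such a wrapper (the argparse flags below are illustrative, not part of the original example):

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", help="environment ID", default="CartPole-v0")
    parser.add_argument("--run_ID", help="run identifier (logging)", type=int, default=0)
    parser.add_argument("--cuda_idx", help="GPU to use (None for CPU)", type=int, default=None)
    args = parser.parse_args()
    build_and_train(
        env_id=args.env_id,
        run_ID=args.run_ID,
        cuda_idx=args.cuda_idx,
    )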
Example #2
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **config["sampler"]
    )
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
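This script, like several of the later launch scripts here, pulls a nested entry from a `configs` registry and unpacks each sub-dict into one component. A sketch of the shape such an entry needs for the code above; the keys mirror what the script reads, while the entry name and concrete values are placeholders, not taken from any real config file:

configs = dict(
    ppo_atari=dict(  # hypothetical config_key
        env=dict(game="pong"),                   # passed to AtariEnv as env_kwargs
        sampler=dict(batch_T=64, batch_B=32),    # unpacked into GpuSampler
        optim=dict(),                            # optimizer kwargs forwarded to PPO
        algo=dict(learning_rate=2.5e-4),         # unpacked into PPO
        model=dict(),                            # model kwargs for AtariLstmAgent
        agent=dict(),                            # remaining agent kwargs
        runner=dict(n_steps=50e6, log_interval_steps=1e5),  # unpacked into MinibatchRl
    ),
)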
Example #3
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = CpuSampler(
        EnvCls=gym_make,
        env_kwargs=config["env"],
        CollectorCls=CpuResetCollector,
        **config["sampler"]
    )
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = MujocoFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        seed=int(run_ID) * 1000,
        **config["runner"]
    )
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #4
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    if slot_affinity_code == 'None':
        # affinity = affinity_from_code(run_slot_affinity_code)
        slot_affinity_code = prepend_run_slot(0, affinity_code)
        affinity = affinity_from_code(slot_affinity_code)
    else:
        affinity = affinity_from_code(slot_affinity_code)

    config = configs[config_key]

    # load variant of experiment (there may not be a variant, though)
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = CpuSampler(
        EnvCls=make_env,
        env_kwargs={},
        CollectorCls=CpuResetCollector,
        **config["sampler"]
    )
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = MujocoFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #5
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    eval_env_config = config["env"].copy()
    eval_env_config["start_level"] = config["env"]["num_levels"] + 100
    eval_env_config["num_levels"] = 100
    sampler = GpuSampler(
        EnvCls=make,
        env_kwargs=config["env"],
        CollectorCls=GpuResetCollector,
        eval_env_kwargs=eval_env_config,
        **config["sampler"]
    )
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = RADPgAgent(ModelCls=RADModel, model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]

    with logger_context(log_dir, run_ID, name, config, snapshot_mode='last'):
        runner.train()
Example #6
def build_and_train(level="nav_maze_random_goal_01", run_ID=0, cuda_idx=None):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(8)))
    sampler = SerialSampler(
        EnvCls=DeepmindLabEnv,
        env_kwargs=dict(level=level),
        eval_env_kwargs=dict(level=level),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=5,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = PPO()
    agent = AtariFfAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=affinity,
    )
    config = dict(level=level)
    name = "lab_ppo"
    log_dir = "lab_example_3"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #7
def makePPOExperiment(config):
    return PPO(discount=config["discount"],
               entropy_loss_coeff=config["entropy_bonus"],
               gae_lambda=config["lambda"],
               minibatches=config["minibatches_per_epoch"],
               epochs=config["epochs_per_rollout"],
               ratio_clip=config["ppo_clip"],
               learning_rate=config["learning_rate"],
               normalize_advantage=True)
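The factory above reads exactly these keys from `config`; a hypothetical dict that satisfies it (the values below are common PPO defaults used only for illustration):

config = {
    "discount": 0.99,
    "entropy_bonus": 0.01,
    "lambda": 0.95,
    "minibatches_per_epoch": 4,
    "epochs_per_rollout": 4,
    "ppo_clip": 0.2,
    "learning_rate": 2.5e-4,
}
algo = makePPOExperiment(config)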
Example #8
def build_and_train(game="doom_benchmark", run_ID=0, cuda_idx=None, n_parallel=-1,
                    n_env=-1, n_timestep=-1, sample_mode=None):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))

    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing.")
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing.")

        # !!!
        # COMMENT: to use alternating sampler here we had to comment lines 126-127 in action_server.py
        # if "bootstrap_value" in self.samples_np.agent:
        #     self.bootstrap_value_pair[alt][:] = self.agent.value(*agent_inputs_pair[alt])
        # otherwise it crashes
        # !!!

    sampler = Sampler(
        EnvCls=VizdoomEnv,
        env_kwargs=dict(game=game),
        batch_T=n_timestep,
        batch_B=n_env,
        max_decorrelation_steps=0,
    )
    algo = PPO(minibatches=1, epochs=1)

    agent = DoomLstmAgent()

    # AsyncRL might give better performance, but in the current version
    # PPO + AsyncRL does not appear to work (not implemented).
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "ppo_" + game + str(n_env)
    log_dir = "doom_ppo"

    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #9
def build_and_train(
    slot_affinity_code="0slt_1gpu_1cpu",
    log_dir="test",
    run_ID="0",
    config_key="ppo_16env",
    experiment_title="exp",
    snapshot_mode="none",
    snapshot_gap=None,
):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    # Hack so that the first part of the log_dir matches the source of the model
    model_base_dir = config["pretrain"]["model_dir"]
    if model_base_dir is not None:
        raw_log_dir = log_dir.split(experiment_title)[-1].lstrip(
            "/")  # get rid of ~/GitRepos/adam/rlpyt/data/local/<timestamp>/
        model_sub_dir = raw_log_dir.split("/RlFromUl/")[
            0]  # keep the UL part, which comes first
        config["agent"]["state_dict_filename"] = osp.join(
            model_base_dir, model_sub_dir, "run_0/params.pkl")
    pprint.pprint(config)

    sampler = AlternatingSampler(
        EnvCls=DmlabEnv,
        env_kwargs=config["env"],
        CollectorCls=GpuWaitResetCollector,
        # TrajInfoCls=AtariTrajInfo,
        # eval_env_kwargs=config["env"],  # Same args!
        **config["sampler"])
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = DmlabPgLstmAlternatingAgent(model_kwargs=config["model"],
                                        **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])
    name = config["env"]["level"]
    if snapshot_gap is not None:
        snapshot_gap = int(snapshot_gap)
    with logger_context(
            log_dir,
            run_ID,
            name,
            config,
            snapshot_mode=snapshot_mode,
            snapshot_gap=snapshot_gap,
    ):
        runner.train()
Example #10
def build_and_train(game="pong",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )

    sampler = Sampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        batch_T=5,  # 5 time-steps per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    algo = PPO()  # Run with defaults.
    agent = AtariFfAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e3,  #50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "ppo" + game
    log_dir = "example_3_test"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
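One caveat with the if/elif chain above (the same pattern recurs in several other examples here): an unrecognized `sample_mode` leaves `Sampler` undefined, and the script only fails later with a NameError. A table-driven sketch that fails fast instead, assuming the same sampler classes are already imported:

SAMPLER_CLASSES = {
    "serial": SerialSampler,
    "cpu": CpuSampler,
    "gpu": GpuSampler,
    "alternating": AlternatingSampler,
}
try:
    Sampler = SAMPLER_CLASSES[sample_mode]
except KeyError:
    raise ValueError(f"Unknown sample_mode: {sample_mode!r}")
if sample_mode == "alternating":
    affinity["workers_cpus"] += affinity["workers_cpus"]  # Double list, as above.
    affinity["alternating"] = True  # Sampler will check for this.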
Example #11
def build_and_train():
    p = psutil.Process()
    cpus = p.cpu_affinity()
    affinity = dict(cuda_idx=None,
                    master_cpus=cpus,
                    workers_cpus=list([x] for x in cpus),
                    set_affinity=True)
    sampler = CpuSampler(
        EnvCls=_make_env,
        env_kwargs=dict(rank=0),
        max_decorrelation_steps=0,
        batch_T=6000,
        batch_B=len(cpus),  # One parallel environment per available CPU core.
    )

    model_kwargs = dict(model_kwargs=dict(hidden_sizes=[256, 256]))
    ppo_config = {
        "discount": 0.98,
        "entropy_loss_coeff": 0.01,
        "learning_rate": 0.00025,
        "value_loss_coeff": 0.5,
        "clip_grad_norm": 0.5,
        "minibatches": 40,
        "gae_lambda": 0.95,
        "ratio_clip": 0.2,
        "epochs": 4
    }

    algo = PPO(**ppo_config)
    agent = MujocoFfAgent(**model_kwargs)

    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(60e6),
        log_interval_steps=int(1e6),
        affinity=affinity,
    )
    config = dict(rank=0, env_id='picking')
    name = "ppo_rlpyt_pushing"
    log_dir = os.path.join(os.path.dirname(__file__), name)
    with logger_context(log_dir,
                        0,
                        name,
                        config,
                        use_summary_writer=True,
                        snapshot_mode='all'):
        runner.train()
Example #12
    def run(self):
        config = self.getConfig()
        sampler = CpuSampler(EnvCls=make_env,
                             env_kwargs={"num_levels": config["num_levels"]},
                             batch_T=256,
                             batch_B=8,
                             max_decorrelation_steps=0)

        optim_args = (dict(weight_decay=config["l2_penalty"])
                      if "l2_penalty" in config else None)

        algo = PPO(discount=config["discount"],
                   entropy_loss_coeff=config["entropy_bonus"],
                   gae_lambda=config["lambda"],
                   minibatches=config["minibatches_per_epoch"],
                   epochs=config["epochs_per_rollout"],
                   ratio_clip=config["ppo_clip"],
                   learning_rate=config["learning_rate"],
                   normalize_advantage=True,
                   optim_kwargs=optim_args)
        agent = ImpalaAgent(
            model_kwargs={
                "in_channels": config["in_channels"],
                "out_channels": config["out_channels"],
                "hidden_size": config["hidden_size"]
            })

        affinity = dict(cuda_idx=0,
                        workers_cpus=list(range(config["workers"])))

        runner = MinibatchRl(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             n_steps=25e6,
                             log_interval_steps=500,
                             affinity=affinity,
                             seed=42069)
        log_dir = "./logs"
        name = config["name"]
        run_ID = name
        with logger_context(log_dir,
                            run_ID,
                            name,
                            config,
                            use_summary_writer=True):
            runner.train()
        torch.save(agent.state_dict(), "./" + name + ".pt")
        wandb.save("./" + name + ".pt")
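The saved weights can be fed back into a fresh agent on a later run; a minimal sketch, assuming `ImpalaAgent` accepts rlpyt's usual `initial_model_state_dict` argument (the same hook Examples #18 and #20 use) and is built with the same model_kwargs:

state_dict = torch.load("./" + name + ".pt")
agent = ImpalaAgent(
    model_kwargs={
        "in_channels": config["in_channels"],
        "out_channels": config["out_channels"],
        "hidden_size": config["hidden_size"],
    },
    initial_model_state_dict=state_dict,  # assumed to follow rlpyt's BaseAgent convention
)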
Example #13
def build_and_train(log_dir, run_ID, config_key):
    # affinity = affinity_from_code(run_slot_affinity_code)
    slot_affinity_code = prepend_run_slot(0, affinity_code)
    affinity = affinity_from_code(slot_affinity_code)

    config = configs[config_key]

    sampler = CpuSampler(EnvCls=make_env,
                         env_kwargs={},
                         CollectorCls=CpuResetCollector,
                         **config["sampler"])
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = MultiFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #14
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    env = IsaacGymEnv(config['env']['task'])  # Make env
    import torch.nn as nn
    config["model"]["hidden_nonlinearity"] = getattr(nn, config["model"]["hidden_nonlinearity"])  # Replace string with proper activation
    sampler = IsaacSampler(env, **config["sampler"])
    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = MujocoFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = "ppo_nv_" + config["env"]["task"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #15
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["algo_name"] = 'PPO_RNN'

    env = BatchPOMDPEnv(batch_B=config["sampler"]["batch_B"], **config["env"])
    config["algo"]["discount"] = env.discount
    sampler = BatchPOMDPSampler(env=env, **config["sampler"])

    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = AlternatingPomdpRnnAgent(model_kwargs=config["model"],
                                     **config["agent"])
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity,
                         **config["runner"])
    name = config["env"]["id"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #16
def build_and_train(game="pong", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        # EnvCls=MyEnv,
        # env_kwargs=dict(),
        # batch_T=4,  # Four time-steps per sampler iteration.
        # batch_B=1,
        # max_decorrelation_steps=0,
        # eval_n_envs=10,
        # eval_env_kwargs=dict(),
        # eval_max_steps=int(10e3),
        # eval_max_trajectories=5,
        EnvCls=CanvasEnv,
        env_kwargs=dict(),
        batch_T=1,  # One time-step per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
    )
    algo = PPO()
    agent = CategoricalPgAgent(
        ModelCls=MyModel,
        model_kwargs=dict(image_shape=(1, CANVAS_WIDTH, CANVAS_WIDTH),
                          output_size=N_ACTIONS),
        initial_model_state_dict=None,
    )
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )

    config = dict()
    name = "dqn_" + game
    log_dir = "example_1"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #17
def build_and_train(game="academy_empty_goal_close", run_ID=1, cuda_idx=None):
    env_vector_size = args.envVectorSize
    coach = Coach(envOptions=args.envOptions,
                  vectorSize=env_vector_size,
                  algo='Bandit',
                  initialQ=args.initialQ,
                  beta=args.beta)
    sampler = SerialSampler(
        EnvCls=create_single_football_env,
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=5,  # Five time-steps per sampler iteration.
        batch_B=env_vector_size,
        max_decorrelation_steps=0,
        eval_n_envs=args.evalNumOfEnvs,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
        coach=coach,
        eval_env=args.evalEnv,
    )
    algo = PPO(minibatches=1)  # Run with defaults.
    agent = AtariLstmAgent()  # TODO: move to ff
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=args.numOfSteps,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    name = args.name
    log_dir = "example_1"
    with logger_context(log_dir,
                        run_ID,
                        name,
                        log_params=vars(args),
                        snapshot_mode="last"):
        runner.train()
Example #18
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)

    sampler = GpuSampler(
        EnvCls=gym.make,
        env_kwargs=config["env"],
        CollectorCls=GpuResetCollector,
        eval_env_kwargs=config["eval_env"],
        **config["sampler"]
    )
    if config["checkpoint"]:
        model_state_dict = torch.load(config["checkpoint"])
    else:
        model_state_dict = None

    algo = PPO(optim_kwargs=config["optim"], **config["algo"])
    agent = CategoricalPgAgent(
        ModelCls=BaselinePolicy,
        model_kwargs=config["model"],
        initial_model_state_dict=model_state_dict,
        **config["agent"]
    )
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]

    with logger_context(log_dir, run_ID, name, config, snapshot_mode='last'):
        runner.train()
def build_and_train(game="doom_benchmark",
                    run_ID=0,
                    cuda_idx=None,
                    n_parallel=-1,
                    n_env=-1,
                    n_timestep=-1,
                    sample_mode=None):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))

    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )

        # !!!
        # COMMENT: to use alternating sampler here we had to comment lines 126-127 in action_server.py
        # if "bootstrap_value" in self.samples_np.agent:
        #     self.bootstrap_value_pair[alt][:] = self.agent.value(*agent_inputs_pair[alt])
        # otherwise it crashes
        # !!!

    sampler = Sampler(
        EnvCls=DmlabEnv,
        env_kwargs=dict(game=game),
        batch_T=n_timestep,
        batch_B=n_env,
        max_decorrelation_steps=0,
    )
    # Using decorrelation here completely destroys performance: episodes reset at
    # different times, so the learner ends up waiting for the last 1-2 workers to
    # finish, wasting a lot of time. This should not be an issue with an
    # asynchronous implementation, but that is not supported at the moment.

    algo = PPO(minibatches=1, epochs=1)

    agent = DoomLstmAgent()

    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "ppo_" + game
    log_dir = "dmlab_ppo"

    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #20
def start_experiment(args):

    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(args.log_dir + '/git.txt', 'w') as git_file:
        branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)
    
    if args.sample_mode == 'gpu':
        # affinity = dict(num_gpus=args.num_gpus, workers_cpus=list(range(args.num_cpus)))
        if args.num_gpus > 0:
            # import ipdb; ipdb.set_trace()
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # Total CPU cores to use across all experiments.
                n_gpu=args.num_gpus,  # Total GPUs to use across all experiments.
                # contexts_per_gpu=2,
                # hyperthread_offset=72,  # If machine has 24 cores.
                # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.

                # cpu_per_run=1,
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))
    
    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete") # clean up json files for video recorder
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg), curiosity_step_kwargs=dict())
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['feature_space'] = args.feature_space
    elif args.curiosity_alg == 'micm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['ensemble_mode'] = args.ensemble_mode
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    
    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:            
            agent = AtariLstmAgent(
                        initial_model_state_dict=initial_model_state_dict,
                        model_kwargs=model_args,
                        no_extrinsic=args.no_extrinsic,
                        dual_model=args.dual_model,
                        )
        else:
            agent = AtariFfAgent(initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        algo = PPO(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict,  # is None if not reloading a checkpoint
                gae_lambda=args.gae_lambda,
                minibatches=args.minibatches,  # if recurrent: batch_B must be at least this; if not recurrent: batch_B*batch_T must be at least this
                epochs=args.epochs,
                ratio_clip=args.ratio_clip,
                linear_lr_schedule=args.linear_lr,
                normalize_advantage=args.normalize_advantage,
                normalize_reward=args.normalize_reward,
                curiosity_type=args.curiosity_alg,
                policy_loss_type=args.policy_loss_type
                )
    elif args.alg == 'a2c':
        algo = A2C(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict,
                gae_lambda=args.gae_lambda,
                normalize_advantage=args.normalize_advantage
                )

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,  
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000
            )
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps
            )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000
            )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset
            )

    if args.sample_mode == 'gpu':
        if args.lstm:
            collector_class = GpuWaitResetCollector
        else:
            collector_class = GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
        )
    else:
        if args.lstm:
            collector_class = CpuWaitResetCollector
        else:
            collector_class = CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit, # timesteps in a trajectory episode
            batch_B=args.num_envs, # environments distributed across workers
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
            )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #     
    if args.eval_envs > 0:
        runner = (MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )
    else:
        runner = (MinibatchRl if args.num_gpus <= 1 else SyncRl)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
def build_and_train(env_id="CartPole-v1",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    args={}):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(
            f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing."
        )

    sampler = Sampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=env_id),
        eval_env_kwargs=dict(id=env_id),
        batch_T=5,  # 5 time-steps per sampler iteration.
        batch_B=16,  # 16 parallel environments.
        max_decorrelation_steps=400,
        eval_n_envs=25,
        eval_max_steps=12500)

    algo = PPO(learning_rate=args.lr)

    agentCls, agent_basis = get_agent_cls_cartpole(args.network)

    agent = agentCls(model_kwargs={
        'fc_sizes': args.fcs,
        'gain_type': args.gain_type,
        'basis': agent_basis
    })
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e5,
        log_interval_steps=5e2,
        affinity=affinity,
    )

    config = dict(env_id=env_id,
                  lr=args.lr,
                  gain_type=args.gain_type,
                  debug=False,
                  network=args.network,
                  fcs=str(args.fcs))

    name = f"{args.folder}_{args.network}"
    log_dir = f"{args.folder}_{args.network}"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #22
sampler = CpuSampler(
    EnvCls=treechop_env_creator,
    env_kwargs=dict(env_id=env_id),
    batch_T=2048,  # 2048 time-steps per sampler iteration.
    batch_B=2,  # Two environments (i.e. sampler batch dimension).
    max_decorrelation_steps=0,
)

algo = PPO(
    discount=0.99,
    learning_rate=2.5e-4,
    value_loss_coeff=0.5,
    entropy_loss_coeff=0.01,
    clip_grad_norm=0.5,
    initial_optim_state_dict=None,
    gae_lambda=0.95,
    minibatches=4,
    epochs=10,
    ratio_clip=0.1,
    linear_lr_schedule=False,
    normalize_advantage=False,
)

runner = MinibatchRl(
    algo=algo,
    agent=agent,
    sampler=sampler,
    n_steps=150,
    log_interval_steps=1,
    affinity=dict(cuda_idx=None,
                  workers_cpus=[i for i in range(12)]),
)
Example #23
    }

    PPO_kwargs={
        'learning_rate': 3e-4,
        'clip_vf_loss': False,
        'entropy_loss_coeff': 0.,
        'discount': 0.99,
        'linear_lr_schedule': False,
        'epochs': 10,
        'clip_grad_norm': 2.,
        'minibatches': 2,
        'normalize_rewards': None,
        'value_loss_coeff': 2.
    }
    agent = MujocoFfAgent(model_kwargs=model_kwargs)
    algo = PPO(**PPO_kwargs)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e8,
        log_interval_steps=1e4,
        affinity=affinity,
        transfer=True,
        transfer_iter=transfer_iter,
        # log_traj_window=10
    )
    config = dict(task=task)
    name = "ppo_nt_nv_" + task
    log_dir = "example_2a"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #24
def build_and_train(game="breakout",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(
            f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing."
        )

    env_kwargs = dict(game=game,
                      repeat_action_probability=0.25,
                      horizon=int(18e3))

    sampler = Sampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=env_kwargs,
        batch_T=128,
        batch_B=1,
        max_decorrelation_steps=0)

    algo = PPO(minibatches=4,
               epochs=4,
               entropy_loss_coeff=0.001,
               learning_rate=0.0001,
               gae_lambda=0.95,
               discount=0.999)

    base_model_kwargs = dict(  # Same front-end architecture as RND model, different fc kwarg name
        channels=[32, 64, 64],
        kernel_sizes=[8, 4, 4],
        strides=[(4, 4), (2, 2), (1, 1)],
        paddings=[0, 0, 0],
        fc_sizes=[512]
        # Automatically applies nonlinearity=torch.nn.ReLU in this case,
        # but can't specify due to rlpyt limitations
    )
    agent = AtariFfAgent(model_kwargs=base_model_kwargs)

    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(
            49152e4
        ),  # this is 30k rollouts per environment at (T, B) = (128, 128)
        log_interval_steps=int(1e3),
        affinity=affinity)

    config = dict(game=game)
    name = "ppo_" + game
    log_dir = "baseline"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
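As a sanity check on the n_steps comment above: 30,000 rollouts of (T, B) = (128, 128) steps each works out to exactly the value passed (a quick check, not part of the original script):

assert 128 * 128 * 30_000 == int(49152e4)  # 491,520,000 environment steps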
Example #25
def build_and_train(game="cartpole",
                    run_ID=0,
                    cuda_idx=None,
                    sample_mode="serial",
                    n_parallel=2,
                    eval=False,
                    serial=False,
                    train_mask=[True, True],
                    wandb_log=False,
                    save_models_to_wandb=False,
                    log_interval_steps=1e5,
                    observation_mode="agent",
                    inc_player_last_act=False,
                    alt_train=False,
                    eval_perf=False,
                    n_steps=50e6,
                    one_agent=False):
    # define environments:
    if observation_mode == "agent":
        fully_obs = False
        rand_obs = False
    elif observation_mode == "random":
        fully_obs = False
        rand_obs = True
    elif observation_mode == "full":
        fully_obs = True
        rand_obs = False

    n_serial = None
    if game == "cartpole":
        work_env = gym.make
        env_name = 'CartPole-v1'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, -4.8000002e+00, -3.4028235e+38, -4.1887903e-01,
            -3.4028235e+38
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 4.8000002e+00, 3.4028235e+38, 4.1887903e-01,
            3.4028235e+38
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        player_act_space.shape = (1, )
        print(player_act_space)
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 1),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_cartpole
        observer_reward_shaping = observer_reward_shaping_cartpole
        max_decor_steps = 20
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[24],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "hiv":
        work_env = wn.gym.make
        env_name = 'HIV-v0'
        cont_act = False
        state_space_low = np.asarray(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, np.inf, np.inf, np.inf, np.inf,
            np.inf, np.inf
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 3),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hiv
        observer_reward_shaping = observer_reward_shaping_hiv
        max_decor_steps = 10
        b_size = 32
        num_envs = 8
        max_episode_length = 100
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[64],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "heparin":
        work_env = HeparinEnv
        env_name = 'Heparin-Simulator'
        cont_act = False
        state_space_low = np.asarray([
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 18728.926, 72.84662, 0.0, 0.0,
            0.0, 0.0, 0.0
        ])
        state_space_high = np.asarray([
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.7251439e+04, 1.0664291e+02,
            200.0, 8.9383472e+02, 1.0025734e+02, 1.5770737e+01, 4.7767456e+01
        ])
        # state_space_low = np.asarray([0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18728.926,72.84662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
        # state_space_high = np.asarray([1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.7251439e+04,1.0664291e+02,0.0000000e+00,8.9383472e+02,1.4476662e+02,1.3368750e+02,1.6815166e+02,1.0025734e+02,1.5770737e+01,4.7767456e+01,7.7194958e+00])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = work_env(env_name).action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = player_reward_shaping_hep
        observer_reward_shaping = observer_reward_shaping_hep
        max_decor_steps = 3
        b_size = 20
        num_envs = 8
        max_episode_length = 20
        player_model_kwargs = dict(hidden_sizes=[32],
                                   lstm_size=16,
                                   nonlinearity=torch.nn.ReLU,
                                   normalize_observation=False,
                                   norm_obs_clip=10,
                                   norm_obs_var_clip=1e-6)
        observer_model_kwargs = dict(hidden_sizes=[128],
                                     lstm_size=16,
                                     nonlinearity=torch.nn.ReLU,
                                     normalize_observation=False,
                                     norm_obs_clip=10,
                                     norm_obs_var_clip=1e-6)

    elif game == "halfcheetah":
        assert not serial
        assert not one_agent
        work_env = gym.make
        env_name = 'HalfCheetah-v2'
        cont_act = True
        temp_env = work_env(env_name)
        state_space_low = np.concatenate([
            np.zeros(temp_env.observation_space.low.shape),
            temp_env.observation_space.low
        ])
        state_space_high = np.concatenate([
            np.ones(temp_env.observation_space.high.shape),
            temp_env.observation_space.high
        ])
        obs_space = Box(state_space_low, state_space_high, dtype=np.float32)
        player_act_space = temp_env.action_space
        if inc_player_last_act:
            observer_obs_space = Box(np.append(state_space_low, 0),
                                     np.append(state_space_high, 4),
                                     dtype=np.float32)
        else:
            observer_obs_space = obs_space
        player_reward_shaping = None
        observer_reward_shaping = None
        temp_env.close()
        max_decor_steps = 0
        b_size = 20
        num_envs = 8
        max_episode_length = np.inf
        player_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_model_kwargs = dict(hidden_sizes=[256, 256])
        player_q_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_q_model_kwargs = dict(hidden_sizes=[256, 256])
        player_v_model_kwargs = dict(hidden_sizes=[256, 256])
        observer_v_model_kwargs = dict(hidden_sizes=[256, 256])
    if game == "halfcheetah":
        observer_act_space = Box(
            low=state_space_low[:int(len(state_space_low) / 2)],
            high=state_space_high[:int(len(state_space_high) / 2)])
    else:
        if serial:
            n_serial = int(len(state_space_high) / 2)
            observer_act_space = Discrete(2)
            observer_act_space.shape = (1, )
        else:
            if one_agent:
                observer_act_space = IntBox(
                    low=0,
                    high=player_act_space.n *
                    int(2**int(len(state_space_high) / 2)))
            else:
                observer_act_space = IntBox(low=0,
                                            high=int(2**int(
                                                len(state_space_high) / 2)))

    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        alt = False
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        if eval:
            eval_collector_cl = SerialEvalCollector
        else:
            eval_collector_cl = None
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        alt = False
        Sampler = CpuSampler
        if eval:
            eval_collector_cl = CpuEvalCollector
        else:
            eval_collector_cl = None
        print(
            f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing."
        )
    env_kwargs = dict(work_env=work_env,
                      env_name=env_name,
                      obs_spaces=[obs_space, observer_obs_space],
                      action_spaces=[player_act_space, observer_act_space],
                      serial=serial,
                      player_reward_shaping=player_reward_shaping,
                      observer_reward_shaping=observer_reward_shaping,
                      fully_obs=fully_obs,
                      rand_obs=rand_obs,
                      inc_player_last_act=inc_player_last_act,
                      max_episode_length=max_episode_length,
                      cont_act=cont_act)
    if eval:
        eval_env_kwargs = env_kwargs
        eval_max_steps = 1e4
        num_eval_envs = num_envs
    else:
        eval_env_kwargs = None
        eval_max_steps = None
        num_eval_envs = 0
    sampler = Sampler(
        EnvCls=CWTO_EnvWrapper,
        env_kwargs=env_kwargs,
        batch_T=b_size,
        batch_B=num_envs,
        max_decorrelation_steps=max_decor_steps,
        eval_n_envs=num_eval_envs,
        eval_CollectorCls=eval_collector_cl,
        eval_env_kwargs=eval_env_kwargs,
        eval_max_steps=eval_max_steps,
    )
    if game == "halfcheetah":
        player_algo = SAC()
        observer_algo = SACBeta()
        player = SacAgent(ModelCls=PiMlpModel,
                          QModelCls=QofMuMlpModel,
                          model_kwargs=player_model_kwargs,
                          q_model_kwargs=player_q_model_kwargs,
                          v_model_kwargs=player_v_model_kwargs)
        observer = SacAgentBeta(ModelCls=PiMlpModelBeta,
                                QModelCls=QofMuMlpModel,
                                model_kwargs=observer_model_kwargs,
                                q_model_kwargs=observer_q_model_kwargs,
                                v_model_kwargs=observer_v_model_kwargs)
    else:
        player_model = CWTO_LstmModel
        observer_model = CWTO_LstmModel

        player_algo = PPO()
        observer_algo = PPO()
        player = CWTO_LstmAgent(ModelCls=player_model,
                                model_kwargs=player_model_kwargs,
                                initial_model_state_dict=None)
        observer = CWTO_LstmAgent(ModelCls=observer_model,
                                  model_kwargs=observer_model_kwargs,
                                  initial_model_state_dict=None)
    if one_agent:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask,
                                  one_agent=one_agent,
                                  nplayeract=player_act_space.n)
    else:
        agent = CWTO_AgentWrapper(player,
                                  observer,
                                  serial=serial,
                                  n_serial=n_serial,
                                  alt=alt,
                                  train_mask=train_mask)

    if eval:
        RunnerCl = MinibatchRlEval
    else:
        RunnerCl = MinibatchRl

    runner = RunnerCl(player_algo=player_algo,
                      observer_algo=observer_algo,
                      agent=agent,
                      sampler=sampler,
                      n_steps=n_steps,
                      log_interval_steps=log_interval_steps,
                      affinity=affinity,
                      wandb_log=wandb_log,
                      alt_train=alt_train)
    config = dict(domain=game)
    if game == "halfcheetah":
        name = "sac_" + game
    else:
        name = "ppo_" + game
    log_dir = os.getcwd() + "/cwto_logs/" + name
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
    if save_models_to_wandb:
        agent.save_models_to_wandb()
    if eval_perf:
        eval_n_envs = 10
        eval_envs = [CWTO_EnvWrapper(**env_kwargs) for _ in range(eval_n_envs)]
        set_envs_seeds(eval_envs, make_seed())
        eval_collector = SerialEvalCollector(envs=eval_envs,
                                             agent=agent,
                                             TrajInfoCls=TrajInfo_obs,
                                             max_T=1000,
                                             max_trajectories=10,
                                             log_full_obs=True)
        traj_infos_player, traj_infos_observer = eval_collector.collect_evaluation(
            runner.get_n_itr())
        observations = []
        player_actions = []
        returns = []
        observer_actions = []
        for traj in traj_infos_player:
            observations.append(np.array(traj.Observations))
            player_actions.append(np.array(traj.Actions))
            returns.append(traj.Return)
        for traj in traj_infos_observer:
            observer_actions.append(
                np.array([
                    obs_action_translator(act, eval_envs[0].power_vec,
                                          eval_envs[0].obs_size)
                    for act in traj.Actions
                ]))

        # Save evaluation results to disk:
        with open('eval_observations.pkl', "wb") as f:
            pickle.dump(observations, f)
        with open('eval_returns.pkl', "wb") as f:
            pickle.dump(returns, f)
        with open('eval_player_actions.pkl', "wb") as f:
            pickle.dump(player_actions, f)
        with open('eval_observer_actions.pkl', "wb") as f:
            pickle.dump(observer_actions, f)
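
The evaluation block above writes four pickle files into the working directory; a minimal sketch for reading them back in a later analysis session (file names as used above, the helper function itself is hypothetical):

import pickle

def load_eval_results():
    """Reload the pickled evaluation artifacts written by the snippet above."""
    results = {}
    for key in ("observations", "returns", "player_actions", "observer_actions"):
        with open(f"eval_{key}.pkl", "rb") as f:
            results[key] = pickle.load(f)
    return results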
Example No. 26
    def run(self, run_ID=0):
        config = self.getConfig()
        sampler = GpuSampler(EnvCls=make_env,
                             env_kwargs={
                                 "num_levels": config["num_levels"],
                                 "env": config['env']
                             },
                             CollectorCls=GpuResetCollector,
                             batch_T=256,
                             batch_B=config["envs_per_worker"],
                             max_decorrelation_steps=1000)

        optim_args = (dict(weight_decay=config["l2_penalty"])
                      if "l2_penalty" in config else None)

        algo = PPO(value_loss_coeff=0.5,
                   clip_grad_norm=0.5,
                   discount=config["discount"],
                   entropy_loss_coeff=config["entropy_bonus"],
                   gae_lambda=config["lambda"],
                   minibatches=config["minibatches_per_epoch"],
                   epochs=config["epochs_per_rollout"],
                   ratio_clip=config["ppo_clip"],
                   learning_rate=config["learning_rate"],
                   normalize_advantage=True,
                   optim_kwargs=optim_args)

        if config["arch"] == 'impala':
            agent = ImpalaAgent(
                model_kwargs={
                    "in_channels": [3, 16, 32],
                    "out_channels": [16, 32, 32],
                    "hidden_size": 256
                })
        elif config["arch"] == 'lstm':
            agent = NatureRecurrentAgent(model_kwargs={
                "hidden_sizes": [512],
                "lstm_size": 256
            })
        else:
            agent = OriginalNatureAgent(
                model_kwargs={
                    "batchNorm": config["batchNorm"],
                    "dropout": config["dropout"],
                    "augment_obs": config["augment_obs"],
                    "use_maxpool": config["maxpool"],
                    "hidden_sizes": config["hidden_sizes"],
                    "arch": config["arch"]
                })

        affinity = dict(cuda_idx=0, workers_cpus=list(range(8)))

        runner = MinibatchRl(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             n_steps=config["total_timesteps"],
                             log_interval_steps=500,
                             affinity=affinity)
        log_dir = "./logs"
        name = config["name"]
        with logger_context(log_dir,
                            run_ID,
                            name,
                            config,
                            use_summary_writer=True,
                            override_prefix=False):
            runner.train()
        torch.save(agent.state_dict(), "./" + name + ".pt")
        wandb.save("./" + name + ".pt")
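
run() reads every hyperparameter from getConfig(), which is not shown in this snippet. The dict below lists each key the method references, with placeholder values only (illustrative, not the authors' settings):

example_config = {
    "name": "procgen_ppo_run",    # run name used for logging and the saved .pt file
    "env": "coinrun",             # placeholder environment id passed to make_env
    "num_levels": 500,
    "envs_per_worker": 32,
    "discount": 0.999,
    "entropy_bonus": 0.01,
    "lambda": 0.95,               # GAE lambda
    "minibatches_per_epoch": 8,
    "epochs_per_rollout": 3,
    "ppo_clip": 0.2,
    "learning_rate": 5e-4,
    "arch": "impala",             # "impala", "lstm", or anything else for the Nature CNN
    "batchNorm": False,           # the remaining keys are only read by the Nature CNN branch
    "dropout": 0.0,
    "augment_obs": False,
    "maxpool": True,
    "hidden_sizes": [512],
    "total_timesteps": 25_000_000,
    # "l2_penalty": 1e-4,         # optional: enables weight decay in the optimizer
}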
Example No. 27
def train_holodeck_ppo(argv):

    # Create gif directory
    image_path = os.path.join(FLAGS.image_dir, FLAGS.name)
    if not os.path.exists(image_path):
        os.makedirs(image_path)

    # Load saved checkpoint
    if FLAGS.checkpoint is not None:
        checkpoint = torch.load(FLAGS.checkpoint)
        model_state_dict = checkpoint['agent_state_dict']
        optim_state_dict = checkpoint['optimizer_state_dict']
    else:
        model_state_dict = None
        optim_state_dict = None

    # Get environment info for agent
    env_kwargs = {
        'max_steps': FLAGS.eps_length,
        'gif_freq': FLAGS.gif_freq,
        'steps_per_action': FLAGS.steps_per_action,
        'image_dir': image_path,
        'viewport': FLAGS.viewport
    }
    if os.path.isfile(FLAGS.scenario):
        with open(FLAGS.scenario) as f:
            env_kwargs['scenario_cfg'] = json.load(f)
    else:
        env_kwargs['scenario_name'] = FLAGS.scenario
    env = HolodeckEnv(**env_kwargs)

    # Instantiate sampler
    sampler = SerialSampler(
        EnvCls=HolodeckEnv,
        batch_T=FLAGS.sampler_steps,
        batch_B=FLAGS.num_workers,
        env_kwargs=env_kwargs,
        max_decorrelation_steps=0,
    )

    # Instantiate algo and agent
    algo = PPO(initial_optim_state_dict=optim_state_dict)

    AgentClass = (GaussianPgAgent if env.is_action_continuous
                  else CategoricalPgAgent)
    agent = AgentClass(initial_model_state_dict=model_state_dict,
                       ModelCls=PpoHolodeckModel,
                       model_kwargs={
                           'img_size': env.img_size,
                           'lin_size': env.lin_size,
                           'action_size': env.action_size,
                           'is_continuous': env.is_action_continuous,
                           'hidden_size': FLAGS.hidden_size
                       })

    # Instantiate runner
    runner = MinibatchRl(algo=algo,
                         agent=agent,
                         sampler=sampler,
                         n_steps=FLAGS.n_steps,
                         log_interval_steps=FLAGS.log_steps,
                         affinity=dict(cuda_idx=FLAGS.cuda_idx))

    # Run
    params = {
        'run_id': FLAGS.run_id,
        'cuda_idx': FLAGS.cuda_idx,
        'n_steps': FLAGS.n_steps,
        'log_steps': FLAGS.log_steps,
        'eps_length': FLAGS.eps_length,
        'sampler_steps': FLAGS.sampler_steps,
        'steps_per_action': FLAGS.steps_per_action,
        'num_workers': FLAGS.num_workers,
        'gif_freq': FLAGS.gif_freq,
        'hidden_size': FLAGS.hidden_size,
        'viewport': FLAGS.viewport,
        'name': FLAGS.name,
        'checkpoint': FLAGS.checkpoint,
        'image_dir': FLAGS.image_dir,
        'scenario': FLAGS.scenario
    }
    with logger_context(FLAGS.name,
                        FLAGS.run_id,
                        FLAGS.name,
                        snapshot_mode='all',
                        log_params=params):
        runner.train()
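
train_holodeck_ppo(argv) appears to follow the absl.flags / app.run convention and reads everything from FLAGS. A sketch of matching flag definitions, assuming absl is in use; the flag names are taken from the FLAGS references above, while defaults and help strings are illustrative only:

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('name', 'holodeck_ppo', 'Run name (used for logging and gif output).')
flags.DEFINE_string('scenario', 'scenario.json', 'Holodeck scenario name, or path to a scenario JSON file.')
flags.DEFINE_string('checkpoint', None, 'Optional checkpoint (params.pkl) to resume from.')
flags.DEFINE_string('image_dir', './images', 'Root directory for saved gifs.')
flags.DEFINE_integer('run_id', 0, 'Run index passed to the logger.')
flags.DEFINE_integer('cuda_idx', None, 'GPU index, or None to run on CPU.')
flags.DEFINE_integer('n_steps', int(1e6), 'Total environment steps to train for.')
flags.DEFINE_integer('log_steps', int(1e4), 'Steps between logger outputs.')
flags.DEFINE_integer('eps_length', 500, 'Maximum steps per episode.')
flags.DEFINE_integer('sampler_steps', 128, 'batch_T: time-steps per sampler iteration.')
flags.DEFINE_integer('steps_per_action', 1, 'Simulator ticks per agent action.')
flags.DEFINE_integer('num_workers', 4, 'batch_B: parallel environment instances.')
flags.DEFINE_integer('gif_freq', 100, 'Episodes between saved gifs.')
flags.DEFINE_integer('hidden_size', 256, 'Hidden layer size of the policy model.')
flags.DEFINE_boolean('viewport', False, 'Whether to render the Holodeck viewport.')

if __name__ == '__main__':
    app.run(train_holodeck_ppo)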
Example No. 28
def build_and_train(game="doom_benchmark", run_ID=0, cuda_idx=None, n_parallel=-1,
                    n_env=-1, n_timestep=-1, sample_mode=None, total_steps=1):
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))

    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores workers_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing.")
    elif sample_mode == "gpu":
        Sampler = GpuSampler
        print(f"Using GPU parallel sampler (agent in master), {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "alternating":
        Sampler = AlternatingSampler
        affinity["workers_cpus"] += affinity["workers_cpus"]  # (Double list)
        affinity["alternating"] = True  # Sampler will check for this.
        print(f"Using Alternating GPU parallel sampler, {gpu_cpu} for sampling and optimizing.")

        # NOTE: to use the alternating sampler here, lines 126-127 in action_server.py
        # had to be commented out, otherwise it crashes:
        # if "bootstrap_value" in self.samples_np.agent:
        #     self.bootstrap_value_pair[alt][:] = self.agent.value(*agent_inputs_pair[alt])
    else:
        raise ValueError(f"Unknown sample_mode: {sample_mode}")

    sampler = Sampler(
        EnvCls=VizdoomEnv,
        env_kwargs=dict(game=game),
        batch_T=n_timestep,
        batch_B=n_env,
        max_decorrelation_steps=1000,
    )
    algo = PPO(
        learning_rate=0.0001,
        value_loss_coeff=0.5,
        entropy_loss_coeff=0.003,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        clip_grad_norm=4.,
        initial_optim_state_dict=None,
        gae_lambda=0.95,
        minibatches=1,
        epochs=1,
        ratio_clip=0.1,
        linear_lr_schedule=False,
        normalize_advantage=True,
    )

    agent = DoomLstmAgent()

    # Maybe AsyncRL could give better performance?
    # In the current version, however, PPO + AsyncRL does not seem to work (not implemented).
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=total_steps,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "ppo_" + game
    log_dir = "doom_ppo_rlpyt_wall_time_" + game

    with logger_context(log_dir, run_ID, name, config):
        runner.train()
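
A command-line entry point in the style of rlpyt's example scripts could wrap build_and_train as follows; the argument names match the function signature above, but the defaults here are placeholders, not benchmark settings:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--game", default="doom_benchmark")
    parser.add_argument("--run_ID", type=int, default=0)
    parser.add_argument("--cuda_idx", type=int, default=None)
    parser.add_argument("--n_parallel", type=int, default=4)
    parser.add_argument("--n_env", type=int, default=16)
    parser.add_argument("--n_timestep", type=int, default=128)
    parser.add_argument("--sample_mode", default="serial")
    parser.add_argument("--total_steps", type=int, default=int(50e6))
    args = parser.parse_args()
    build_and_train(**vars(args))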
Example No. 29
def start_experiment(args):

    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)

    config = dict(env_id=args.env)

    if args.sample_mode == 'gpu':
        assert args.num_gpus > 0
        affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
        os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))

    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete"
                  )  # clean up json files for video recorder
        checkpoint = torch.load(
            os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg))
    if args.curiosity_alg == 'icm':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt)
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            ensemble_size=args.ensemble_size,
            batch_norm=args.batch_norm,
            prediction_beta=args.prediction_beta,
            forward_loss_wt=args.forward_loss_wt,
            device=args.sample_mode)
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            pred_horizon=args.pred_horizon,
            batch_norm=args.batch_norm,
            num_predictors=args.num_predictors,
            device=args.sample_mode)
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs'].update(
            feature_encoding=args.feature_encoding,
            prediction_beta=args.prediction_beta,
            drop_probability=args.drop_probability,
            gamma=args.discount,
            device=args.sample_mode)

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(
                initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(
                initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:
            agent = AtariLstmAgent(
                initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic)
        else:
            agent = AtariFfAgent(
                initial_model_state_dict=initial_model_state_dict)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        if args.kernel_mu == 0.:
            kernel_params = None
        else:
            kernel_params = (args.kernel_mu, args.kernel_sigma)
        algo = PPO(
            discount=args.discount,
            learning_rate=args.lr,
            value_loss_coeff=args.v_loss_coeff,
            entropy_loss_coeff=args.entropy_loss_coeff,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            clip_grad_norm=args.grad_norm_bound,
            # None if not reloading a checkpoint.
            initial_optim_state_dict=initial_optim_state_dict,
            gae_lambda=args.gae_lambda,
            # If recurrent: batch_B must be at least `minibatches`; if not recurrent:
            # batch_B * batch_T must be at least `minibatches` (see the sketch after this example).
            minibatches=args.minibatches,
            epochs=args.epochs,
            ratio_clip=args.ratio_clip,
            linear_lr_schedule=args.linear_lr,
            normalize_advantage=args.normalize_advantage,
            normalize_reward=args.normalize_reward,
            kernel_params=kernel_params,
            curiosity_type=args.curiosity_alg)
    elif args.alg == 'a2c':
        algo = A2C(discount=args.discount,
                   learning_rate=args.lr,
                   value_loss_coeff=args.v_loss_coeff,
                   entropy_loss_coeff=args.entropy_loss_coeff,
                   OptimCls=torch.optim.Adam,
                   optim_kwargs=None,
                   clip_grad_norm=args.grad_norm_bound,
                   initial_optim_state_dict=initial_optim_state_dict,
                   gae_lambda=args.gae_lambda,
                   normalize_advantage=args.normalize_advantage)
    else:
        raise ValueError(f"Unknown learning algorithm: {args.alg}")

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo  # environment specific - potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(game=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=args.normalize_obs,
                        normalize_obs_steps=10000)
    elif 'deepmind' in args.env.lower():  # pycolab deepmind environments
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(game=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=args.normalize_obs,
                        normalize_obs_steps=10000,
                        log_heatmaps=args.log_heatmaps,
                        logdir=args.log_dir,
                        obs_type=args.obs_type,
                        max_steps_per_episode=args.max_episode_steps)
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(id=args.env,
                        no_extrinsic=args.no_extrinsic,
                        no_negative_reward=args.no_negative_reward,
                        normalize_obs=False,
                        normalize_obs_steps=10000)
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
        )

    if args.sample_mode == 'gpu':
        if args.lstm:
            collector_class = GpuWaitResetCollector
        else:
            collector_class = GpuResetCollector
        sampler = GpuSampler(EnvCls=env_cl,
                             env_kwargs=env_args,
                             eval_env_kwargs=env_args,
                             batch_T=args.timestep_limit,
                             batch_B=args.num_envs,
                             max_decorrelation_steps=0,
                             TrajInfoCls=traj_info_cl,
                             eval_n_envs=args.eval_envs,
                             eval_max_steps=args.eval_max_steps,
                             eval_max_trajectories=args.eval_max_traj,
                             record_freq=args.record_freq,
                             log_dir=args.log_dir,
                             CollectorCls=collector_class)
    else:
        if args.lstm:
            collector_class = CpuWaitResetCollector
        else:
            collector_class = CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,  # timesteps in a trajectory episode
            batch_B=args.num_envs,  # environments distributed across workers
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class)

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #
    if args.eval_envs > 0:
        runner = MinibatchRlEval(algo=algo,
                                 agent=agent,
                                 sampler=sampler,
                                 n_steps=args.iterations,
                                 affinity=affinity,
                                 log_interval_steps=args.log_interval,
                                 log_dir=args.log_dir,
                                 pretrain=args.pretrain)
    else:
        runner = MinibatchRl(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             n_steps=args.iterations,
                             affinity=affinity,
                             log_interval_steps=args.log_interval,
                             log_dir=args.log_dir,
                             pretrain=args.pretrain)

    with logger_context(args.log_dir,
                        config,
                        snapshot_mode="last",
                        use_summary_writer=True):
        runner.train()
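
As noted in the comment on PPO's minibatches argument above, the valid setting depends on whether the agent is recurrent. A small arithmetic check with placeholder numbers (only the size and divisibility relationships matter):

# Placeholder values for illustration.
batch_T, batch_B, minibatches = 128, 32, 4

# Feed-forward agents: samples are flattened over time and batch before splitting,
# so batch_T * batch_B must be at least `minibatches`.
assert (batch_T * batch_B) % minibatches == 0   # 4096 samples -> 4 minibatches of 1024

# Recurrent agents: whole trajectories stay intact, so the split is along batch_B only
# and batch_B itself must be at least `minibatches`.
assert batch_B >= minibatches and batch_B % minibatches == 0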