Example No. 1
def build_and_train(game="pong", run_ID=0):
    affinity = make_affinity(n_cpu_core=4,
                             n_gpu=2,
                             async_sample=True,
                             n_socket=1,
                             gpu_per_run=1,
                             sample_gpu_per_run=1)
    config = configs['ernbw']
    config['runner']['log_interval_steps'] = 1e5
    config['env']['game'] = game
    config["eval_env"]["game"] = config["env"]["game"]
    config["algo"]["n_step_return"] = 5
    wandb.config.update(config)
    sampler = AsyncGpuSampler(EnvCls=AtariEnv,
                              env_kwargs=config["env"],
                              CollectorCls=DbGpuWaitResetCollector,
                              TrajInfoCls=AtariTrajInfo,
                              eval_env_kwargs=config["eval_env"],
                              **config["sampler"])
    algo = PizeroCategoricalDQN(optim_kwargs=config["optim"],
                                **config["algo"])
    agent = AtariCatDqnAgent(ModelCls=PizeroCatDqnModel,
                             model_kwargs=config["model"],
                             **config["agent"])
    runner = AsyncRlEvalWandb(algo=algo,
                              agent=agent,
                              sampler=sampler,
                              affinity=affinity,
                              **config["runner"])
    name = "dqn_" + game
    log_dir = "example"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
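The wandb.config.update(config) call above assumes a Weights & Biases run is already active. A minimal, hypothetical launcher for this example (the project name is an assumption, not part of the original code):

import wandb

if __name__ == "__main__":
    # Assumption: a W&B run must exist before build_and_train() touches wandb.config.
    wandb.init(project="rlpyt-rainbow-example")  # hypothetical project name
    build_and_train(game="pong", run_ID=0)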
Example No. 2
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=os.cpu_count(),  # Use every core the OS reports (logical cores).
        n_gpu=1,  # One GPU for this run.
        gpu_per_run=1,
        sample_gpu_per_run=1,
        async_sample=True,
        optim_sample_share_gpu=True)

    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["eval_env"]["game"] = config["env"]["game"]

    sampler = GpuSampler(EnvCls=AtariEnv,
                         env_kwargs=config["env"],
                         CollectorCls=GpuWaitResetCollector,
                         TrajInfoCls=AtariTrajInfo,
                         eval_env_kwargs=config["eval_env"],
                         **config["sampler"])
    algo = CategoricalDQN(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariCatDqnAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRlEval(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             affinity=affinity,
                             **config["runner"])
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
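The (slot_affinity_code, log_dir, run_ID, config_key) signature is the form rlpyt's experiment launchers invoke with command-line arguments, so the train script typically just forwards sys.argv. A minimal sketch, assuming build_and_train lives in the script the launcher targets:

import sys

if __name__ == "__main__":
    # Arguments supplied by the launcher:
    # slot_affinity_code, log_dir, run_ID, config_key.
    build_and_train(*sys.argv[1:])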
Example No. 3
def build_and_train(game="pong", run_ID=0):
    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=8,  # Use 8 cores across all experiments.
        n_gpu=2,  # Use 2 GPUs across all experiments.
        gpu_per_run=1,
        sample_gpu_per_run=1,
        async_sample=True,
        optim_sample_share_gpu=False,
        # hyperthread_offset=24,  # If machine has 24 cores.
        # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        # gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )

    sampler = AsyncGpuSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        batch_T=5,
        batch_B=36,
        max_decorrelation_steps=100,
        eval_env_kwargs=dict(game=game),
        eval_n_envs=2,
        eval_max_steps=int(10e3),
        eval_max_trajectories=4,
    )
    algo = DQN(
        replay_ratio=8,
        min_steps_learn=1e4,
        replay_size=int(1e5)
    )
    agent = AtariDqnAgent()
    runner = AsyncRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=2e6,
        log_interval_steps=1e4,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "async_dqn_" + game
    log_dir = "async_dqn"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
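A typical command-line wrapper for this example, in the style of rlpyt's bundled example scripts (the flag names simply mirror the function's parameters):

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--game", help="Atari game", default="pong")
    parser.add_argument("--run_ID", help="run identifier (logging)", type=int, default=0)
    args = parser.parse_args()
    build_and_train(game=args.game, run_ID=args.run_ID)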
Example No. 4
def build_and_train(level="nav_maze_random_goal_01", run_ID=0, cuda_idx=None):
    config = configs['r2d1']
    config['eval_env'] = dict(level=level)
    config['env'] = dict(level=level)

    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=4,  # Use 4 cores across all experiments.
        n_gpu=1,  # Use 1 GPU across all experiments.
        hyperthread_offset=6,  # If machine has 6 physical cores (12 hyperthreads).
        n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        gpu_per_run=1,  # How many GPUs to parallelize one run across.
    )

    # sampler = GpuSampler(
    #     EnvCls=DeepmindLabEnv,
    #     env_kwargs=config['env'],
    #     eval_env_kwargs=config['eval_env'],
    #     CollectorCls=GpuWaitResetCollector,
    #     TrajInfoCls=LabTrajInfo,
    #     **config["sampler"]
    # )
    sampler = SerialSampler(
        EnvCls=DeepmindLabEnv,
        env_kwargs=config['env'],
        eval_env_kwargs=config['env'],
        batch_T=16,  # Sixteen time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = R2D1(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariR2d1Agent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = "lab_dqn_" + level
    log_dir = "lab_example_2"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example No. 5
def build_and_train(game="pong", run_ID=0):
    # It seems we should be able to skip the intermediate affinity step,
    # but so far we have always run it this way.
    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=16,  # Use 16 cores across all experiments.
        n_gpu=8,  # Use 8 gpus across all experiments.
        hyperthread_offset=24,  # If machine has 24 cores.
        n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )

    sampler = GpuSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        CollectorCls=GpuWaitResetCollector,
        batch_T=5,
        batch_B=16,
        max_decorrelation_steps=400,
    )
    algo = A2C()  # Run with defaults.
    agent = AtariFfAgent()
    runner = SyncRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "a2c_" + game
    log_dir = "example_7"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
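With n_gpu=8 and gpu_per_run=2, make_affinity should partition the machine into four run slots (two GPUs and roughly four cores each), and run_slot=0 selects the first. Below is a sketch of how a second concurrent experiment on the same machine could take the next slot; this reuse pattern is an assumption about the intended workflow, not part of the original example:

# Hypothetical second run sharing the machine: same resource description,
# different slot index.
affinity_slot1 = make_affinity(
    run_slot=1,
    n_cpu_core=16,
    n_gpu=8,
    hyperthread_offset=24,
    n_socket=2,
    gpu_per_run=2,
)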
Example No. 6
def build_and_train(slot_affinity_code=None,
                    log_dir='./data',
                    run_ID=0,
                    serial_mode=True,
                    snapshot: Dict = None,
                    config_update: Dict = None):
    # default configuration
    config = dict(
        sac_kwargs=dict(learning_rate=3e-4,
                        batch_size=512,
                        replay_size=1e6,
                        discount=0.95),
        ppo_kwargs=dict(minibatches=4,
                        learning_rate=2e-1,
                        discount=0.95,
                        linear_lr_schedule=False,
                        OptimCls=SGD,
                        optim_kwargs=dict(momentum=0.9),
                        gae_lambda=0.95,
                        ratio_clip=0.02,
                        entropy_loss_coeff=0,
                        clip_grad_norm=100),
        td3_kwargs=dict(),
        sampler_kwargs=dict(batch_T=32,
                            batch_B=5,
                            env_kwargs=dict(id="TrackEnv-v0"),
                            eval_n_envs=4,
                            eval_max_steps=1e5,
                            eval_max_trajectories=10),
        sac_agent_kwargs=dict(ModelCls=PiMCPModel,
                              QModelCls=QofMCPModel,
                              model_kwargs=dict(freeze_primitives=False)),
        ppo_agent_kwargs=dict(ModelCls=PPOMcpModel,
                              model_kwargs=dict(freeze_primitives=False)),
        runner_kwargs=dict(n_steps=1e9, log_interval_steps=1e5),
        snapshot=snapshot,
        algo='sac')

    # try to update default config
    try:
        variant = load_variant(log_dir)
        config = update_config(config, variant)
    except FileNotFoundError:
        if config_update is not None:
            config = update_config(config, config_update)

    # select correct affinity for configuration
    if slot_affinity_code is None:
        num_cpus = multiprocessing.cpu_count()  # logical cores; halved below where only physical cores are wanted
        num_gpus = len(GPUtil.getGPUs())
        if config['algo'] == 'sac' and not serial_mode:
            affinity = make_affinity(n_cpu_core=num_cpus,
                                     n_gpu=num_gpus,
                                     async_sample=True,
                                     set_affinity=False)
        elif config['algo'] == 'ppo' and not serial_mode:
            affinity = dict(alternating=True,
                            cuda_idx=0,
                            workers_cpus=2 * list(range(num_cpus)),
                            async_sample=True)
        else:
            affinity = make_affinity(n_cpu_core=num_cpus // 2, n_gpu=num_gpus)
    else:
        affinity = affinity_from_code(slot_affinity_code)

    # continue training from saved state_dict if provided
    agent_state_dict = optimizer_state_dict = None
    if config['snapshot'] is not None:
        agent_state_dict = config['snapshot']['agent_state_dict']
        optimizer_state_dict = config['snapshot']['optimizer_state_dict']

    if config['algo'] == 'ppo':
        AgentClass = McpPPOAgent
        AlgoClass = PPO
        RunnerClass = MinibatchRlEval
        SamplerClass = CpuSampler if serial_mode else AlternatingSampler
        algo_kwargs = config['ppo_kwargs']
        agent_kwargs = config['ppo_agent_kwargs']
    elif config['algo'] == 'sac':
        AgentClass = SacAgentSafeLoad
        AlgoClass = SAC
        algo_kwargs = config['sac_kwargs']
        agent_kwargs = config['sac_agent_kwargs']
        if serial_mode:
            SamplerClass = SerialSampler
            RunnerClass = MinibatchRlEval
        else:
            SamplerClass = AsyncCpuSampler
            RunnerClass = AsyncRlEval
            affinity['cuda_idx'] = 0
    else:
        raise NotImplementedError('algorithm not implemented')

    # make debugging easier in serial mode
    if serial_mode:
        config['runner_kwargs']['log_interval_steps'] = 1e3
        config['sac_kwargs']['min_steps_learn'] = 0

    sampler = SamplerClass(**config['sampler_kwargs'],
                           EnvCls=make,
                           eval_env_kwargs=config['sampler_kwargs']['env_kwargs'])
    algo = AlgoClass(**algo_kwargs,
                     initial_optim_state_dict=optimizer_state_dict)
    agent = AgentClass(initial_model_state_dict=agent_state_dict,
                       **agent_kwargs)
    runner = RunnerClass(**config['runner_kwargs'],
                         algo=algo,
                         agent=agent,
                         sampler=sampler,
                         affinity=affinity)
    config_logger(log_dir,
                  name='parkour-training',
                  snapshot_mode='best',
                  log_params=config)
    # start training
    runner.train()
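Because the function accepts a snapshot dict holding agent and optimizer state, a checkpoint saved by a previous run can be fed back in to resume training. A minimal sketch, assuming a params.pkl with the usual rlpyt keys at a hypothetical path:

import torch

# Hypothetical checkpoint written by an earlier run of this script.
checkpoint = torch.load("./data/run_0/params.pkl", map_location="cpu")
build_and_train(
    serial_mode=False,
    snapshot=dict(
        agent_state_dict=checkpoint["agent_state_dict"],
        optimizer_state_dict=checkpoint["optimizer_state_dict"],
    ),
)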
Example No. 7
def build_and_train(id="SurfaceCode-v0", name='run', log_dir='./logs', async_mode=False, restore_path=None):
    num_cpus = multiprocessing.cpu_count()
    num_gpus = 0 #len(GPUtil.getGPUs())
    # print(f"num cpus {num_cpus} num gpus {num_gpus}")
    if num_gpus == 0:
        # affinity = make_affinity(n_cpu_core=num_cpus // 2, n_gpu=0, set_affinity=False)
        affinity = make_affinity(n_cpu_core=num_cpus // 2,
                                 cpu_per_run=num_cpus // 2,
                                 n_gpu=num_gpus,
                                 async_sample=False,
                                 set_affinity=True)
        affinity['workers_cpus'] = tuple(range(num_cpus))
        affinity['master_torch_threads'] = 28
    else:
        affinity = make_affinity(
            run_slot=0,
            n_cpu_core=num_cpus,  # Use all detected physical cores.
            cpu_per_run=num_cpus,
            n_gpu=num_gpus,  # Use all detected GPUs.
            async_sample=async_mode,
            alternating=False,
            set_affinity=True,
        )
    # num_worker_cpus = len(affinity.sampler['workers_cpus'])
    print(f'affinity: {affinity}')
    agent_state_dict = optim_state_dict = None
    if restore_path is not None:
        state_dict = torch.load(restore_path, map_location='cpu')
        agent_state_dict = state_dict['agent_state_dict']
        optim_state_dict = state_dict['optimizer_state_dict']
    if async_mode:
        SamplerCls = AsyncCpuSampler
        RunnerCls = AsyncRlEval
        algo = AsyncVMPO(batch_B=64, batch_T=40, discrete_actions=True, T_target_steps=40, epochs=1, initial_optim_state_dict=optim_state_dict)
        sampler_kwargs = dict(CollectorCls=QecDbCpuResetCollector, eval_CollectorCls=QecCpuEvalCollector)
    else:
        SamplerCls = CpuSampler
        # SamplerCls = SerialSampler
        # RunnerCls = MinibatchRlEval
        RunnerCls = QECSynchronousRunner
        algo = VMPO(discrete_actions=True, epochs=4, minibatches=100, initial_optim_state_dict=optim_state_dict, epsilon_alpha=0.01)
        sampler_kwargs = dict(CollectorCls=QecCpuResetCollector, eval_CollectorCls=QecCpuEvalCollector)

    env_kwargs = dict(error_model='DP', error_rate=0.005, volume_depth=1)

    sampler = SamplerCls(
        EnvCls=make_qec_env,
        env_kwargs=env_kwargs,
        batch_T=40,
        batch_B=64 * 100,
        max_decorrelation_steps=50,
        eval_env_kwargs=env_kwargs,
        eval_n_envs=num_cpus,
        eval_max_steps=int(1e6),
        eval_max_trajectories=num_cpus,
        TrajInfoCls=EnvInfoTrajInfo,
        **sampler_kwargs
    )
    agent = MultiActionVmpoAgent(ModelCls=MultiActionRecurrentQECModel,
                                 model_kwargs=dict(linear_value_output=False),
                                 initial_model_state_dict=agent_state_dict)
    runner = RunnerCls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e10,
        log_interval_steps=1e6,
        affinity=affinity,
    )
    config = dict(game=id)
    config_logger(log_dir, name=name, snapshot_mode='last', log_params=config)
    runner.train()
Example No. 8
def build_and_train(game="aaai_multi", run_ID=0):
    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=8,  # Use 8 cores across all experiments.
        n_gpu=1,  # Use 1 GPU across all experiments.
        sample_gpu_per_run=1,
        async_sample=True,
        optim_sample_share_gpu=True
        # hyperthread_offset=24,  # If machine has 24 cores.
        # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        # gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )

    train_conf = PytConfig([
        Path(JSONS_FOLDER, 'configs', '2v2', 'all_equal.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_horizontally.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_vertically.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_west.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_east.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_north.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_south.json'),
    ])

    eval_conf = PytConfig({
        'all_equal': Path(JSONS_FOLDER, 'configs', '2v2', 'all_equal.json'),
        'more_horizontally': Path(JSONS_FOLDER, 'configs', '2v2', 'more_horizontally.json'),
        'more_vertically': Path(JSONS_FOLDER, 'configs', '2v2', 'more_vertically.json'),
        'more_south': Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_south.json'),
        'more_east': Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_east.json')
    })

    sampler = AsyncGpuSampler(
        EnvCls=Rlpyt_env,
        TrajInfoCls=AaaiTrajInfo,
        env_kwargs={
            'pyt_conf': train_conf,
            'max_steps': 3000
        },
        batch_T=8,
        batch_B=8,
        max_decorrelation_steps=100,
        eval_env_kwargs={
            'pyt_conf': eval_conf,
            'max_steps': 3000
        },
        eval_max_steps=24100,
        eval_n_envs=2,
    )
    algo = DQN(
        replay_ratio=1024,
        double_dqn=True,
        prioritized_replay=True,
        min_steps_learn=5000,
        learning_rate=0.0001,
        target_update_tau=1.0,
        target_update_interval=1000,
        eps_steps=5e4,
        batch_size=512,
        pri_alpha=0.6,
        pri_beta_init=0.4,
        pri_beta_final=1.,
        pri_beta_steps=int(7e4),
        replay_size=int(1e6),
        clip_grad_norm=1.0,
        updates_per_sync=6
    )
    agent = DqnAgent(ModelCls=Frap)
    runner = AsyncRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        log_interval_steps=1000,
        affinity=affinity,
        n_steps=6e5
    )

    config = dict(game=game)
    name = "frap_" + game
    log_dir = Path(PROJECT_ROOT, "saved", "rlpyt", "multi", "frap")

    save_path = Path(log_dir, 'run_{}'.format(run_ID))
    for f in save_path.glob('**/*'):
        print(f)
        f.unlink()

    with logger_context(str(log_dir), run_ID, name, config,
                        snapshot_mode='last', use_summary_writer=True, override_prefix=True):
        runner.train()
Example No. 9
def build_and_train(args, game="", run_ID=0, config=None):
    """
    1. Parse the args object into dictionaries understood by rlpyt
    """
    config['env']['id'] = args.env_name
    config["eval_env"]["id"] = args.env_name

    config["eval_env"]["horizon"] = args.horizon
    config["env"]["horizon"] = args.horizon

    if 'procgen' in args.env_name:
        for k, v in vars(args).items():
            if args.env_name.split('-')[1] in k:
                config['env'][k] = v

    config['model']['frame_stack'] = args.frame_stack
    config['model']['nce_loss'] = args.nce_loss
    config['model']['algo'] = args.algo
    config['model']['env_name'] = args.env_name
    config['model']['dueling'] = args.dueling == 1
    config['algo']['double_dqn'] = args.double_dqn == 1
    config['algo']['prioritized_replay'] = args.prioritized_replay == 1
    config['algo']['n_step_return'] = args.n_step_return
    config['algo']['learning_rate'] = args.learning_rate

    config['runner']['log_interval_steps'] = args.log_interval_steps
    config['cmd_args'] = vars(args)
    """
    2. Create the CatDQN (C51) agent from custom implementation
    """

    agent = AtariCatDqnAgent(ModelCls=AtariCatDqnModel_nce,
                             model_kwargs=config["model"],
                             **config["agent"])
    algo = CategoricalDQN_nce(args=config['cmd_args'],
                              ReplayBufferCls=None,
                              optim_kwargs=config["optim"],
                              **config["algo"])

    if args.mode == 'parallel':
        affinity = make_affinity(n_cpu_core=args.n_cpus,
                                 n_gpu=args.n_gpus,
                                 n_socket=1
                                 # hyperthread_offset=0
                                 )
        """
        Some architectures require the following block to be uncommented; try with and without.
        It is here to allow scheduling of non-sequential CPU IDs.
        """
        # import psutil
        # psutil.Process().cpu_affinity([])
        # cpus = tuple(psutil.Process().cpu_affinity())
        # affinity['all_cpus'] = affinity['master_cpus'] = cpus
        # affinity['workers_cpus'] = tuple([tuple([x]) for x in cpus+cpus])
        # env_kwargs = config['env']

        sampler = GpuSampler(EnvCls=make_env,
                             env_kwargs=config["env"],
                             CollectorCls=GpuWaitResetCollector,
                             TrajInfoCls=AtariTrajInfo,
                             eval_env_kwargs=config["eval_env"],
                             **config["sampler"])
        """
        If you don't have a GPU, use the CpuSampler
        """
        # sampler = CpuSampler(
        #             EnvCls=AtariEnv if args.game is not None else make_env,
        #             env_kwargs=config["env"],
        #             CollectorCls=CpuWaitResetCollector,
        #             TrajInfoCls=AtariTrajInfo,
        #             eval_env_kwargs=config["eval_env"],
        #             **config["sampler"]
        #         )

    elif args.mode == 'serial':
        affinity = make_affinity(
            n_cpu_core=1,  # Serial mode: a single core.
            n_gpu=args.n_gpus,  # Number of GPUs requested on the command line.
            n_socket=1,
        )
        """
        Some architectures require the following block to be uncommented; try with and without.
        """
        # import psutil
        # psutil.Process().cpu_affinity([])
        # cpus = tuple(psutil.Process().cpu_affinity())
        # affinity['all_cpus'] = affinity['master_cpus'] = cpus
        # affinity['workers_cpus'] = tuple([tuple([x]) for x in cpus+cpus])
        # env_kwargs = config['env']

        sampler = SerialSampler(
            EnvCls=make_env,
            env_kwargs=config["env"],
            # CollectorCls=SerialEvalCollector,
            TrajInfoCls=AtariTrajInfo,
            eval_env_kwargs=config["eval_env"],
            **config["sampler"])
    """
    3. Bookkeeping, setting up Comet.ml experiments, etc
    """
    folders_name = [args.output_dir, args.env_name, 'run_' + args.run_ID]
    path = os.path.join(*folders_name)
    os.makedirs(path, exist_ok=True)

    experiment = Experiment(api_key='your_key',
                            auto_output_logging=False,
                            project_name='driml',
                            workspace="your_workspace",
                            disabled=True)
    experiment.add_tag('C51+DIM' if (
        args.lambda_LL > 0 or args.lambda_LG > 0 or args.lambda_GL > 0
        or args.lambda_GG > 0) else 'C51')
    experiment.set_name(args.experiment_name)
    experiment.log_parameters(config)

    MinibatchRlEval.TF_logger = Logger(path,
                                       use_TFX=True,
                                       params=config,
                                       comet_experiment=experiment,
                                       disable_local=True)
    MinibatchRlEval.log_diagnostics = log_diagnostics_custom
    MinibatchRlEval._log_infos = _log_infos
    MinibatchRlEval.evaluate_agent = evaluate_agent
    """
    4. Define the runner as minibatch
    """
    runner = MinibatchRlEval(algo=algo,
                             agent=agent,
                             sampler=sampler,
                             affinity=affinity,
                             **config["runner"])

    runner.algo.opt_info_fields = tuple(
        list(runner.algo.opt_info_fields) + ['lossNCE'] +
        ['action%d' % i for i in range(15)])
    name = args.mode + "_value_based_nce_" + args.env_name
    log_dir = os.path.join(args.output_dir, args.env_name)
    logger.set_snapshot_gap(args.weight_save_interval //
                            config['runner']['log_interval_steps'])
    """
    5. Run the experiment and optionally save network weights
    """

    with experiment.train():
        with logger_context(
                log_dir,
                run_ID,
                name,
                config,
                snapshot_mode=(
                    'last' if args.weight_save_interval == -1 else 'gap'
                )):  # set 'all' to save every it, 'gap' for every X it
            runner.train()
Example No. 10
def start_experiment(args):

    args_json = json.dumps(vars(args), indent=4)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    with open(args.log_dir + '/arguments.json', 'w') as jsonfile:
        jsonfile.write(args_json)
    with open(args.log_dir + '/git.txt', 'w') as git_file:
        branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
        commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
        git_file.write('{}/{}'.format(branch, commit))

    config = dict(env_id=args.env)
    
    if args.sample_mode == 'gpu':
        # affinity = dict(num_gpus=args.num_gpus, workers_cpus=list(range(args.num_cpus)))
        if args.num_gpus > 0:
            # import ipdb; ipdb.set_trace()
            affinity = make_affinity(
                run_slot=0,
                n_cpu_core=args.num_cpus,  # Number of CPU cores to use across all experiments.
                n_gpu=args.num_gpus,  # Number of GPUs to use across all experiments.
                # contexts_per_gpu=2,
                # hyperthread_offset=72,  # If machine has 24 cores.
                # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
                gpu_per_run=args.gpu_per_run,  # How many GPUs to parallelize one run across.

                # cpu_per_run=1,
            )
            print('Make multi-gpu affinity')
        else:
            affinity = dict(cuda_idx=0, workers_cpus=list(range(args.num_cpus)))
            os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    else:
        affinity = dict(workers_cpus=list(range(args.num_cpus)))
    
    # potentially reload models
    initial_optim_state_dict = None
    initial_model_state_dict = None
    if args.pretrain != 'None':
        os.system(f"find {args.log_dir} -name '*.json' -delete") # clean up json files for video recorder
        checkpoint = torch.load(os.path.join(_RESULTS_DIR, args.pretrain, 'params.pkl'))
        initial_optim_state_dict = checkpoint['optimizer_state_dict']
        initial_model_state_dict = checkpoint['agent_state_dict']

    # ----------------------------------------------------- POLICY ----------------------------------------------------- #
    model_args = dict(curiosity_kwargs=dict(curiosity_alg=args.curiosity_alg), curiosity_step_kwargs=dict())
    if args.curiosity_alg =='icm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['feature_space'] = args.feature_space
    elif args.curiosity_alg == 'micm':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
        model_args['curiosity_kwargs']['ensemble_mode'] = args.ensemble_mode
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'disagreement':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['ensemble_size'] = args.ensemble_size
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['forward_loss_wt'] = args.forward_loss_wt
        model_args['curiosity_kwargs']['device'] = args.sample_mode
        model_args['curiosity_kwargs']['forward_model'] = args.forward_model
    elif args.curiosity_alg == 'ndigo':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['pred_horizon'] = args.pred_horizon
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['batch_norm'] = args.batch_norm
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    elif args.curiosity_alg == 'rnd':
        model_args['curiosity_kwargs']['feature_encoding'] = args.feature_encoding
        model_args['curiosity_kwargs']['prediction_beta'] = args.prediction_beta
        model_args['curiosity_kwargs']['drop_probability'] = args.drop_probability
        model_args['curiosity_kwargs']['gamma'] = args.discount
        model_args['curiosity_kwargs']['device'] = args.sample_mode
    
    if args.curiosity_alg != 'none':
        model_args['curiosity_step_kwargs']['curiosity_step_minibatches'] = args.curiosity_step_minibatches

    if args.env in _MUJOCO_ENVS:
        if args.lstm:
            agent = MujocoLstmAgent(initial_model_state_dict=initial_model_state_dict)
        else:
            agent = MujocoFfAgent(initial_model_state_dict=initial_model_state_dict)
    else:
        if args.lstm:            
            agent = AtariLstmAgent(
                        initial_model_state_dict=initial_model_state_dict,
                        model_kwargs=model_args,
                        no_extrinsic=args.no_extrinsic,
                        dual_model=args.dual_model,
                        )
        else:
            agent = AtariFfAgent(initial_model_state_dict=initial_model_state_dict,
                model_kwargs=model_args,
                no_extrinsic=args.no_extrinsic,
                dual_model=args.dual_model)

    # ----------------------------------------------------- LEARNING ALG ----------------------------------------------------- #
    if args.alg == 'ppo':
        algo = PPO(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict, # is None if not reloading a checkpoint
                gae_lambda=args.gae_lambda,
                minibatches=args.minibatches, # if recurrent: batch_B needs to be at least equal, if not recurrent: batch_B*batch_T needs to be at least equal to this
                epochs=args.epochs,
                ratio_clip=args.ratio_clip,
                linear_lr_schedule=args.linear_lr,
                normalize_advantage=args.normalize_advantage,
                normalize_reward=args.normalize_reward,
                curiosity_type=args.curiosity_alg,
                policy_loss_type=args.policy_loss_type
                )
    elif args.alg == 'a2c':
        algo = A2C(
                discount=args.discount,
                learning_rate=args.lr,
                value_loss_coeff=args.v_loss_coeff,
                entropy_loss_coeff=args.entropy_loss_coeff,
                OptimCls=torch.optim.Adam,
                optim_kwargs=None,
                clip_grad_norm=args.grad_norm_bound,
                initial_optim_state_dict=initial_optim_state_dict,
                gae_lambda=args.gae_lambda,
                normalize_advantage=args.normalize_advantage
                )

    # ----------------------------------------------------- SAMPLER ----------------------------------------------------- #

    # environment setup
    traj_info_cl = TrajInfo # environment specific - potentially overridden below
    if 'mario' in args.env.lower():
        env_cl = mario_make
        env_args = dict(
            game=args.env,  
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000
            )
    elif args.env in _PYCOLAB_ENVS:
        env_cl = deepmind_make
        traj_info_cl = PycolabTrajInfo
        env_args = dict(
            game=args.env,
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            log_heatmaps=args.log_heatmaps,
            logdir=args.log_dir,
            obs_type=args.obs_type,
            grayscale=args.grayscale,
            max_steps_per_episode=args.max_episode_steps
            )
    elif args.env in _MUJOCO_ENVS:
        env_cl = gym_make
        env_args = dict(
            id=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=False,
            normalize_obs_steps=10000
            )
    elif args.env in _ATARI_ENVS:
        env_cl = AtariEnv
        traj_info_cl = AtariTrajInfo
        env_args = dict(
            game=args.env, 
            no_extrinsic=args.no_extrinsic,
            no_negative_reward=args.no_negative_reward,
            normalize_obs=args.normalize_obs,
            normalize_obs_steps=10000,
            downsampling_scheme='classical',
            record_freq=args.record_freq,
            record_dir=args.log_dir,
            horizon=args.max_episode_steps,
            score_multiplier=args.score_multiplier,
            repeat_action_probability=args.repeat_action_probability,
            fire_on_reset=args.fire_on_reset
            )

    if args.sample_mode == 'gpu':
        if args.lstm:
            collector_class = GpuWaitResetCollector
        else:
            collector_class = GpuResetCollector
        sampler = GpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit,
            batch_B=args.num_envs,
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
        )
    else:
        if args.lstm:
            collector_class = CpuWaitResetCollector
        else:
            collector_class = CpuResetCollector
        sampler = CpuSampler(
            EnvCls=env_cl,
            env_kwargs=env_args,
            eval_env_kwargs=env_args,
            batch_T=args.timestep_limit, # timesteps in a trajectory episode
            batch_B=args.num_envs, # environments distributed across workers
            max_decorrelation_steps=0,
            TrajInfoCls=traj_info_cl,
            eval_n_envs=args.eval_envs,
            eval_max_steps=args.eval_max_steps,
            eval_max_trajectories=args.eval_max_traj,
            record_freq=args.record_freq,
            log_dir=args.log_dir,
            CollectorCls=collector_class
            )

    # ----------------------------------------------------- RUNNER ----------------------------------------------------- #     
    if args.eval_envs > 0:
        runner = (MinibatchRlEval if args.num_gpus <= 1 else SyncRlEval)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )
    else:
        runner = (MinibatchRl if args.num_gpus <= 1 else SyncRl)(
            algo=algo,
            agent=agent,
            sampler=sampler,
            n_steps=args.iterations,
            affinity=affinity,
            log_interval_steps=args.log_interval,
            log_dir=args.log_dir,
            pretrain=args.pretrain
            )

    with logger_context(args.log_dir, config, snapshot_mode="last", use_summary_writer=True):
        runner.train()
Example No. 11
def build_and_train(id="SurfaceCode-v0", name='run', log_dir='./logs'):
    # Change these inputs to match local machine and desired parallelism.
    # affinity = make_affinity(
    #     n_cpu_core=24,  # Use 16 cores across all experiments.
    #     n_gpu=1,  # Use 8 gpus across all experiments.
    #     async_sample=True,
    #     set_affinity=True
    # )
    # affinity['optimizer'][0]['cuda_idx'] = 1
    num_cpus = multiprocessing.cpu_count()
    affinity = make_affinity(n_cpu_core=num_cpus // 2,
                             cpu_per_run=num_cpus // 2,
                             n_gpu=0,
                             async_sample=False,
                             set_affinity=True)
    affinity['workers_cpus'] = tuple(range(num_cpus))
    affinity['master_torch_threads'] = 28
    # env_kwargs = dict(id='SurfaceCode-v0', error_model='X', volume_depth=5)
    state_dict = None # torch.load('./logs/run_29/params.pkl', map_location='cpu')
    agent_state_dict = None #state_dict['agent_state_dict']['model']
    optim_state_dict = None #state_dict['optimizer_state_dict']

    # sampler = AsyncCpuSampler(
    sampler = CpuSampler(
        # sampler=SerialSampler(
        EnvCls=make_qec_env,
        # TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(error_rate=0.005, error_model='DP'),
        batch_T=10,
        batch_B=num_cpus * 10,
        max_decorrelation_steps=100,
        eval_env_kwargs=dict(error_rate=0.005, error_model='DP', fixed_episode_length=5000),
        eval_n_envs=num_cpus,
        eval_max_steps=int(1e6),
        eval_max_trajectories=num_cpus,
        TrajInfoCls=EnvInfoTrajInfo
    )
    algo = DQN(
        replay_ratio=8,
        learning_rate=1e-5,
        min_steps_learn=1e4,
        replay_size=int(5e4),
        batch_size=32,
        double_dqn=True,
        # target_update_tau=0.002,
        target_update_interval=5000,
        ReplayBufferCls=UniformReplayBuffer,
        initial_optim_state_dict=optim_state_dict,
        eps_steps=2e6,
    )
    agent = AtariDqnAgent(model_kwargs=dict(channels=[32, 64, 64],
                                            kernel_sizes=[3, 2, 2],
                                            strides=[2, 1, 1],
                                            paddings=[0, 0, 0],
                                            fc_sizes=[512, ],
                                            dueling=True),
                          ModelCls=QECModel,
                          eps_init=1,
                          eps_final=0.02,
                          eps_itr_max=int(5e6),
                          eps_eval=0,
                          initial_model_state_dict=agent_state_dict)
    # agent = DqnAgent(ModelCls=FfModel)
    runner = QECSynchronousRunner(
        # runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e9,
        log_interval_steps=3e5,
        affinity=affinity,
    )
    config = dict(game=id)
    config_logger(log_dir, name=name, snapshot_mode='last', log_params=config)
    # with logger_context(log_dir, run_ID, name, config):
    runner.train()