Example #1
def build_and_train(game="pong", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        CollectorCls=ResetCollector,
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3)  # Run with defaults.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    log_dir = "example_1"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
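These excerpts omit their imports and command-line entry points. Below is a minimal sketch of how Example #1 is typically wired up, assuming rlpyt's standard module paths and an argparse front end in the style of rlpyt's own example scripts; only the core classes are shown, and the argument names are illustrative.

# Sketch only: assumed imports and CLI entry point for Example #1 (not part of the excerpt above).
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.envs.atari.atari_env import AtariEnv
from rlpyt.algos.dqn.dqn import DQN
from rlpyt.agents.dqn.atari.atari_dqn_agent import AtariDqnAgent
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging.context import logger_context

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--game", help="Atari game", default="pong")
    parser.add_argument("--run_ID", help="run identifier (logging)", type=int, default=0)
    parser.add_argument("--cuda_idx", help="GPU to use", type=int, default=None)
    args = parser.parse_args()
    build_and_train(game=args.game, run_ID=args.run_ID, cuda_idx=args.cuda_idx)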
Example #2
def build_and_train(game="academy_empty_goal_close", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=create_single_football_env,
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3)  # Run with defaults.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    log_dir = "example_1"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
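Example #2 calls a create_single_football_env factory that is not shown in the excerpt. A rough sketch of such a helper, assuming the gfootball package and rlpyt's GymEnvWrapper; the representation settings are placeholders, not taken from the original.

# Hypothetical sketch of the missing env factory (assumes gfootball and rlpyt's Gym wrapper).
import gfootball.env as football_env
from rlpyt.envs.gym import GymEnvWrapper

def create_single_football_env(game="academy_empty_goal_close"):
    # Build one Google Research Football scenario and wrap it for rlpyt's sampler.
    env = football_env.create_environment(
        env_name=game,
        representation="extracted",  # placeholder observation type
        stacked=True,
    )
    return GymEnvWrapper(env)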
Example #3
def build_and_train(game="pong", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration: how many env steps each collector loop takes.
        batch_B=1,  # Number of parallel environment instances.
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3)  # Run with defaults.
    agent = AtariDqnAgent()  # Initialized inside the sampler.
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,  # Total number of environment steps.
        log_interval_steps=1e3,  # Log every this many steps.
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    log_dir = "example_1"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #4
def build_and_train(game="pong", run_ID=0, cuda_idx=None, n_parallel=2):
    config = dict(
        env=dict(game=game),
        algo=dict(batch_size=128),
        sampler=dict(batch_T=2, batch_B=32),
    )
    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        CollectorCls=GpuWaitResetCollector,
        eval_env_kwargs=dict(game=game),
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
        # batch_T=4,  # Get from config.
        # batch_B=1,
        **config["sampler"]  # More parallel environments for batched forward-pass.
    )
    algo = DQN(**config["algo"])  # Batch size taken from config; other settings default.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel))),
    )
    name = "dqn_" + game
    log_dir = "example_5"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #5
def test_rlpyt_simple():
    """ partially copied from example 1 """
    game = "pong"
    run_ID = 0
    cuda_idx = None
    n_steps = 1
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3, replay_size=1e3)  # Small replay buffer to avoid memory issues.
    agent = AtariDqnAgent()
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=n_steps,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    log_dir = "test_example_1"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #6
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["eval_env"]["game"] = config["env"]["game"]

    sampler = GpuSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        eval_env_kwargs=config["eval_env"],
        **config["sampler"]
    )
    algo = DQN(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariDqnAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
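Example #6 expects its slot_affinity_code, log_dir, run_ID, and config_key arguments to be supplied by rlpyt's experiment launcher. A minimal sketch of producing them by hand, assuming rlpyt's encode_affinity/prepend_run_slot helpers; the resource counts, paths, and config key below are placeholders.

# Sketch: hand-building the arguments that rlpyt's launcher would normally generate.
from rlpyt.utils.launching.affinity import encode_affinity, prepend_run_slot

affinity_code = encode_affinity(n_cpu_core=8, n_gpu=1)   # placeholder machine resources
slot_affinity_code = prepend_run_slot(0, affinity_code)  # run slot 0
# config_key must name an entry in the imported `configs` dict, and log_dir must
# contain the variant file that load_variant() reads.
build_and_train(slot_affinity_code, log_dir="data/example_run", run_ID=0, config_key="0")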
Example #7
def build_and_train(game="pong", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3)  # Run with defaults.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    #log_dir = "example_1"
    log_dir = get_outputs_path()
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #8
def build_and_train(cfg, game="ftwc", run_ID=0):
    # GVS NOTE: for ftwc/qait, maybe use CpuWaitResetCollector (or CpuResetCollector)?
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e2),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e2)  # Run with defaults.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cfg.cuda_idx),
    )
    config = dict(game=game)
    name = "dqn_" + game
    log_dir = "ftwc"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #9
def build_and_train(level="nav_maze_random_goal_01", run_ID=0, cuda_idx=None):
    sampler = SerialSampler(
        EnvCls=DeepmindLabEnv,
        env_kwargs=dict(level=level),
        eval_env_kwargs=dict(level=level),
        batch_T=4,  # Four time-steps per sampler iteration.
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=5,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = DQN(min_steps_learn=1e3)  # Run with defaults.
    agent = AtariDqnAgent()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=50e6,
        log_interval_steps=1e5,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(level=level)
    name = "lab_dqn"
    log_dir = "lab_example_1"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()
Example #10
def build_and_train(game="pong", run_ID=0):
    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=8,  # Use 8 CPU cores across all experiments.
        n_gpu=2,  # Use 2 GPUs across all experiments.
        gpu_per_run=1,
        sample_gpu_per_run=1,
        async_sample=True,
        optim_sample_share_gpu=False,
        # hyperthread_offset=24,  # If machine has 24 cores.
        # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        # gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )

    sampler = AsyncGpuSampler(
        EnvCls=AtariEnv,
        TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(game=game),
        batch_T=5,
        batch_B=36,
        max_decorrelation_steps=100,
        eval_env_kwargs=dict(game=game),
        eval_n_envs=2,
        eval_max_steps=int(10e3),
        eval_max_trajectories=4,
    )
    algo = DQN(
        replay_ratio=8,
        min_steps_learn=1e4,
        replay_size=int(1e5)
    )
    agent = AtariDqnAgent()
    runner = AsyncRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=2e6,
        log_interval_steps=1e4,
        affinity=affinity,
    )
    config = dict(game=game)
    name = "async_dqn_" + game
    log_dir = "async_dqn"
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
Example #11
def build_and_train(bsuite_id,
                    gym_id,
                    run_ID=0,
                    cuda_idx=None,
                    results_dir='./bsuite_baseline',
                    n_parallel=8):
    env_id = bsuite_id if not gym_id else gym_id
    logger._tf_summary_dir = f'./runs/{env_id.replace("/", "_")}_{run_ID}_baseline_{datetime.now().strftime("%D-%T").replace("/", "_")}'
    logger._tf_summary_writer = SummaryWriter(logger._tf_summary_dir)

    def get_env(*args, **kwargs):
        # Either a bsuite environment (recorded to CSV) or a plain gym environment.
        base_env = (gym_wrapper.GymFromDMEnv(
            bsuite.load_and_record_to_csv(
                bsuite_id=bsuite_id,
                results_dir=results_dir,
                overwrite=True,
            )) if not gym_id else gym.make(gym_id))
        # Stack 4 frames and flatten the stacked observation into a 1-D array.
        return GymEnvWrapper(
            TransformObservation(
                env=FrameStack(env=base_env, num_stack=4),
                f=lambda lazy_frames: np.reshape(np.stack(lazy_frames._frames), -1),
            ))

    sampler = SerialSampler(  # TODO (Async)GpuSampler
        EnvCls=get_env,
        env_kwargs=dict(game=bsuite_id),
        eval_env_kwargs=dict(game=bsuite_id),
        batch_T=1,  # One time-step per sampler iteration (only influences the step count).
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )

    n_steps = 3e4
    algo = DQN(
        discount=0.995,
        min_steps_learn=1e3,
        eps_steps=n_steps,
        # delta_clip=None,
        # learning_rate=1e-4,
        # target_update_tau=500,
        # target_update_tau=0.01,
        # target_update_interval=100,
        double_dqn=True,
        prioritized_replay=True,
        # clip_grad_norm=1,  # FIXME arbitrary
        # n_step_return=2,  # FIXME arbitrary
        # clip_grad_norm=1000000,
    )  # Double DQN with prioritized replay; remaining settings left at defaults.
    # agent = MlpDqnAgent(ModelCls=lambda *args, **kwargs: MlpDqnModel(*args, **kwargs, dueling=True))
    agent = MlpDqnAgent(ModelCls=MlpDqnModel)

    p = psutil.Process()
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=n_steps,  # orig 50e6
        log_interval_steps=1e2,  # orig 1e3,
        affinity=dict(cuda_idx=cuda_idx),
        # affinity=dict(cuda_idx=cuda_idx, workers_cpus=p.cpu_affinity()[:n_parallel]),
    )
    runner.train()
def build_and_train(game="aaai_multi", run_ID=0):
    # Change these inputs to match local machine and desired parallelism.
    affinity = make_affinity(
        run_slot=0,
        n_cpu_core=8,  # Use 8 CPU cores across all experiments.
        n_gpu=1,  # Use 1 GPU across all experiments.
        sample_gpu_per_run=1,
        async_sample=True,
        optim_sample_share_gpu=True
        # hyperthread_offset=24,  # If machine has 24 cores.
        # n_socket=2,  # Presume CPU socket affinity to lower/upper half GPUs.
        # gpu_per_run=2,  # How many GPUs to parallelize one run across.
        # cpu_per_run=1,
    )

    train_conf = PytConfig([
        Path(JSONS_FOLDER, 'configs', '2v2', 'all_equal.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_horizontally.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_vertically.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_west.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_east.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_north.json'),
        Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_south.json'),
    ])

    eval_conf = PytConfig({
        'all_equal': Path(JSONS_FOLDER, 'configs', '2v2', 'all_equal.json'),
        'more_horizontally': Path(JSONS_FOLDER, 'configs', '2v2', 'more_horizontally.json'),
        'more_vertically': Path(JSONS_FOLDER, 'configs', '2v2', 'more_vertically.json'),
        'more_south': Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_south.json'),
        'more_east': Path(JSONS_FOLDER, 'configs', '2v2', 'more_from_east.json')
    })

    sampler = AsyncGpuSampler(
        EnvCls=Rlpyt_env,
        TrajInfoCls=AaaiTrajInfo,
        env_kwargs={
            'pyt_conf': train_conf,
            'max_steps': 3000
        },
        batch_T=8,
        batch_B=8,
        max_decorrelation_steps=100,
        eval_env_kwargs={
            'pyt_conf': eval_conf,
            'max_steps': 3000
        },
        eval_max_steps=24100,
        eval_n_envs=2,
    )
    algo = DQN(
        replay_ratio=1024,
        double_dqn=True,
        prioritized_replay=True,
        min_steps_learn=5000,
        learning_rate=0.0001,
        target_update_tau=1.0,
        target_update_interval=1000,
        eps_steps=5e4,
        batch_size=512,
        pri_alpha=0.6,
        pri_beta_init=0.4,
        pri_beta_final=1.,
        pri_beta_steps=int(7e4),
        replay_size=int(1e6),
        clip_grad_norm=1.0,
        updates_per_sync=6
    )
    agent = DqnAgent(ModelCls=Frap)
    runner = AsyncRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        log_interval_steps=1000,
        affinity=affinity,
        n_steps=6e5
    )

    config = dict(game=game)
    name = "frap_" + game
    log_dir = Path(PROJECT_ROOT, "saved", "rlpyt", "multi", "frap")

    save_path = Path(log_dir, 'run_{}'.format(run_ID))
    for f in save_path.glob('**/*'):
        print(f)
        f.unlink()

    with logger_context(str(log_dir), run_ID, name, config,
                        snapshot_mode='last', use_summary_writer=True, override_prefix=True):
        runner.train()
def build_and_train(bsuite_id,
                    gym_id,
                    run_ID=0,
                    cuda_idx=None,
                    results_dir='./bsuite_shaping',
                    n_parallel=4):
    env_id = bsuite_id if not gym_id else gym_id
    logger._tf_summary_dir = f'./runs/{env_id.replace("/", "_")}_{run_ID}_model_shaping_{datetime.now().strftime("%D-%T").replace("/", "_")}'
    logger._tf_summary_writer = SummaryWriter(logger._tf_summary_dir)

    def get_env(*args, **kwargs):
        # Either a bsuite environment (recorded to CSV) or a plain gym environment.
        base_env = (gym_wrapper.GymFromDMEnv(
            bsuite.load_and_record_to_csv(
                bsuite_id=bsuite_id,
                results_dir=results_dir,
                overwrite=True,
            )) if not gym_id else gym.make(gym_id))
        # Stack 4 frames and flatten the stacked observation into a 1-D array.
        return GymEnvWrapper(
            TransformObservation(
                env=FrameStack(env=base_env, num_stack=4),
                f=lambda lazy_frames: np.reshape(np.stack(lazy_frames._frames), -1),
            ))

    env_info = get_env()
    obs_ndim = len(env_info.observation_space.shape)
    obs_size = reduce(lambda x, y: x * y, env_info.observation_space.shape)

    def mlp_factory(input_size=obs_size,
                    output_size=env_info.action_space.n,
                    hidden_sizes=None,
                    dueling=False):
        if hidden_sizes is None: hidden_sizes = [64, 64]
        return lambda *args, **kwargs: MlpDqnModel(
            input_size=input_size,
            fc_sizes=hidden_sizes,
            output_size=output_size,
            dueling=dueling,
        )

    latent_step_model = mlp_factory(input_size=obs_size + 1,
                                    output_size=obs_size)()
    reward_model = mlp_factory(input_size=2 * obs_size + 1, output_size=1)()
    termination_model = mlp_factory(input_size=2 * obs_size + 1,
                                    output_size=1)()

    def get_modeled_env(*args, **kwargs):
        return ModeledEnv(
            latent_step_model=latent_step_model,
            reward_model=reward_model,
            termination_model=termination_model,
            env_cls=get_env,
        )

    model_info = get_modeled_env()

    sampler = SerialSampler(
        EnvCls=get_env,
        batch_T=1,
        batch_B=1,
        env_kwargs=dict(game=bsuite_id),
        eval_env_kwargs=dict(game=bsuite_id),
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )

    weak_sampler = SerialSampler(
        EnvCls=get_modeled_env,
        batch_T=1,
        batch_B=1,
        env_kwargs=dict(game=bsuite_id),
        eval_env_kwargs=dict(game=bsuite_id),
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )

    def shaping(samples):
        # TODO eval/train mode here and in other places
        if logger._iteration <= 1e3:  # FIXME(1e3)
            return 0

        with torch.no_grad():
            obs = samples.agent_inputs.observation.to(device)  # TODO: check whether keeping this on CPU is better.
            obsprim = samples.target_inputs.observation.to(device)
            qs = weak_agent(obs, samples.agent_inputs.prev_action.to(device),
                            samples.agent_inputs.prev_reward.to(device))
            qsprim = weak_agent(obsprim,
                                samples.target_inputs.prev_action.to(device),
                                samples.target_inputs.prev_reward.to(device))
            vals = (0.995 * torch.max(qsprim, dim=1).values
                    - torch.max(qs, dim=1).values)

            if logger._iteration % 1e1 == 0:  # FIXME(1e1)
                with logger.tabular_prefix("Shaping"):
                    logger.record_tabular_misc_stat(
                        'ShapedReward',
                        vals.detach().cpu().numpy())
            return vals

    n_steps = 3e4
    algo = ShapedDQN(
        # target_update_tau=0.01,
        # target_update_interval=16,
        shaping_function=shaping,
        # target_update_interval=312,
        discount=0.995,
        min_steps_learn=1e3,  # 1e3
        eps_steps=n_steps,
        # pri_beta_steps=n_steps,
        double_dqn=True,
        prioritized_replay=True,
        # clip_grad_norm=1,  # FIXME arbitrary
        # n_step_return=4,  # FIXME arbitrary
    )
    weak_algo = DQN(
        # target_update_tau=0.01,
        # target_update_interval=16,
        # target_update_interval=312,
        discount=0.995,
        min_steps_learn=1e3,  # 1e3
        eps_steps=n_steps,
        # pri_beta_steps=n_steps,
        double_dqn=True,
        prioritized_replay=True,
        # clip_grad_norm=1,  # FIXME arbitrary
        # n_step_return=4,  # FIXME arbi
    )

    agent = DqnAgent(ModelCls=mlp_factory(hidden_sizes=[512], dueling=False))
    weak_agent = DqnAgent(ModelCls=mlp_factory(
        input_size=obs_size, hidden_sizes=[512], dueling=False))

    p = psutil.Process()
    runner = WeakAgentModelBasedRunner(
        algo=algo,
        agent=agent,
        sampler=sampler,
        weak_algo=weak_algo,
        weak_agent=weak_agent,
        weak_sampler=weak_sampler,
        env_model=get_modeled_env(),
        # n_steps=num_episodes,
        n_steps=n_steps,  # orig 50e6
        log_interval_steps=1e2,  # orig 1e3
        affinity=dict(cuda_idx=cuda_idx),
        # affinity=dict(cuda_idx=cuda_idx, workers_cpus=p.cpu_affinity()),
    )

    env_info.close()
    model_info.close()
    runner.train()
def build_and_train(id="SurfaceCode-v0", name='run', log_dir='./logs'):
    # Change these inputs to match local machine and desired parallelism.
    # affinity = make_affinity(
    #     n_cpu_core=24,  # Use 16 cores across all experiments.
    #     n_gpu=1,  # Use 8 gpus across all experiments.
    #     async_sample=True,
    #     set_affinity=True
    # )
    # affinity['optimizer'][0]['cuda_idx'] = 1
    num_cpus = multiprocessing.cpu_count()
    affinity = make_affinity(n_cpu_core=num_cpus//2, cpu_per_run=num_cpus//2, n_gpu=0, async_sample=False,
                                 set_affinity=True)
    affinity['workers_cpus'] = tuple(range(num_cpus))
    affinity['master_torch_threads'] = 28
    # env_kwargs = dict(id='SurfaceCode-v0', error_model='X', volume_depth=5)
    state_dict = None # torch.load('./logs/run_29/params.pkl', map_location='cpu')
    agent_state_dict = None #state_dict['agent_state_dict']['model']
    optim_state_dict = None #state_dict['optimizer_state_dict']

    # sampler = AsyncCpuSampler(
    sampler = CpuSampler(
        # sampler=SerialSampler(
        EnvCls=make_qec_env,
        # TrajInfoCls=AtariTrajInfo,
        env_kwargs=dict(error_rate=0.005, error_model='DP'),
        batch_T=10,
        batch_B=num_cpus * 10,
        max_decorrelation_steps=100,
        eval_env_kwargs=dict(error_rate=0.005, error_model='DP', fixed_episode_length=5000),
        eval_n_envs=num_cpus,
        eval_max_steps=int(1e6),
        eval_max_trajectories=num_cpus,
        TrajInfoCls=EnvInfoTrajInfo
    )
    algo = DQN(
        replay_ratio=8,
        learning_rate=1e-5,
        min_steps_learn=1e4,
        replay_size=int(5e4),
        batch_size=32,
        double_dqn=True,
        # target_update_tau=0.002,
        target_update_interval=5000,
        ReplayBufferCls=UniformReplayBuffer,
        initial_optim_state_dict=optim_state_dict,
        eps_steps=2e6,
    )
    agent = AtariDqnAgent(model_kwargs=dict(channels=[32, 64, 64],
                                            kernel_sizes=[3, 2, 2],
                                            strides=[2, 1, 1],
                                            paddings=[0, 0, 0],
                                            fc_sizes=[512, ],
                                            dueling=True),
                          ModelCls=QECModel,
                          eps_init=1,
                          eps_final=0.02,
                          eps_itr_max=int(5e6),
                          eps_eval=0,
                          initial_model_state_dict=agent_state_dict)
    # agent = DqnAgent(ModelCls=FfModel)
    runner = QECSynchronousRunner(
        # runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=1e9,
        log_interval_steps=3e5,
        affinity=affinity,
    )
    config = dict(game=id)
    config_logger(log_dir, name=name, snapshot_mode='last', log_params=config)
    # with logger_context(log_dir, run_ID, name, config):
    runner.train()