Example #1
import torch

# Imports assume the juliusfrost/dreamer-pytorch and rlpyt module layouts;
# adjust the paths to match your fork.
from dreamer.agents.dmc_dreamer_agent import DMCDreamerAgent
from dreamer.algos.dreamer_algo import Dreamer
from dreamer.envs.action_repeat import ActionRepeat
from dreamer.envs.dmc import DeepMindControl
from dreamer.envs.normalize_actions import NormalizeActions
from dreamer.envs.time_limit import TimeLimit
from dreamer.envs.wrapper import make_wapper  # sic: spelled "wapper" upstream
from rlpyt.runners.minibatch_rl import MinibatchRl, MinibatchRlEval
from rlpyt.samplers.collections import TrajInfo
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.utils.logging.context import logger_context
def build_and_train(log_dir,
                    game="cartpole_balance",
                    run_ID=0,
                    cuda_idx=None,
                    eval=False,
                    save_model='last',
                    load_model_path=None,
                    state=False):  # added: the body referenced an undefined `args.state`
    params = torch.load(load_model_path) if load_model_path else {}
    agent_state_dict = params.get('agent_state_dict')
    optimizer_state_dict = params.get('optimizer_state_dict')

    action_repeat = 2
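    # make_wapper (sic, upstream spelling) builds an env factory that applies the
    # wrapper classes in order, each with its paired kwargs dict: ActionRepeat
    # (amount=2), NormalizeActions(), then TimeLimit(duration=1000 / action_repeat).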
    factory_method = make_wapper(DeepMindControl,
                                 [ActionRepeat, NormalizeActions, TimeLimit], [
                                     dict(amount=action_repeat),
                                     dict(),
                                     dict(duration=1000 / action_repeat)
                                 ])
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=TrajInfo,
        env_kwargs=dict(name=game, use_state=state),
        eval_env_kwargs=dict(name=game, use_state=state),
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(
        initial_optim_state_dict=optimizer_state_dict)  # Run with defaults.
    agent = DMCDreamerAgent(train_noise=0.3,
                            eval_noise=0,
                            expl_type="additive_gaussian",
                            expl_min=None,
                            expl_decay=None,
                            initial_model_state_dict=agent_state_dict)
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=5e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dreamer_" + game
    with logger_context(log_dir,
                        run_ID,
                        name,
                        config,
                        snapshot_mode=save_model,
                        override_prefix=True,
                        use_summary_writer=True):
        runner.train()
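For reference, here is a minimal sketch of how this entry point could be wired to the command line, in the argparse style the dreamer-pytorch scripts use; the flag names below are illustrative assumptions, not the repo's exact interface:

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--log-dir', type=str, default='./logs')
    parser.add_argument('--game', type=str, default='cartpole_balance')
    parser.add_argument('--run-ID', type=int, default=0)
    parser.add_argument('--cuda-idx', type=int, default=None)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--load-model-path', type=str, default=None)
    args = parser.parse_args()
    build_and_train(log_dir=args.log_dir,
                    game=args.game,
                    run_ID=args.run_ID,
                    cuda_idx=args.cuda_idx,
                    eval=args.eval,
                    load_model_path=args.load_model_path)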
Example #2
# In addition to the shared rlpyt/dreamer imports from Example #1 (paths assumed):
from dreamer.agents.atari_dreamer_agent import AtariDreamerAgent
from dreamer.envs.atari import AtariEnv
from dreamer.envs.one_hot import OneHotAction
from rlpyt.envs.atari.atari_env import AtariTrajInfo
def build_and_train(game="pong", run_ID=0, cuda_idx=None, eval=False):
    action_repeat = 2
    env_kwargs = dict(
        name=game,
        action_repeat=action_repeat,
        size=(64, 64),
        grayscale=False,
        life_done=True,
        sticky_actions=True,
    )
    factory_method = make_wapper(
        AtariEnv, [OneHotAction, TimeLimit],
        [dict(), dict(duration=1000 / action_repeat)])
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
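    # Note: these Dreamer hyperparameters (batch_size=1, prefill=10, replay_size=100)
    # and the tiny runner budget below (n_steps=20, log_interval_steps=10) look like
    # smoke-test settings rather than a real training configuration.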
    algo = Dreamer(
        batch_size=1,
        batch_length=5,
        train_every=10,
        train_steps=2,
        prefill=10,
        horizon=5,
        replay_size=100,
        log_video=False,
        kl_scale=0.1,
        use_pcont=True,
    )
    agent = AtariDreamerAgent(train_noise=0.4,
                              eval_noise=0,
                              expl_type="epsilon_greedy",
                              expl_min=0.1,
                              expl_decay=2000 / 0.3,
                              model_kwargs=dict(use_pcont=True))
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=20,
        log_interval_steps=10,
        affinity=dict(cuda_idx=cuda_idx),
    )
    runner.train()
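This configuration is small enough to exercise end to end with a direct call; a sketch (the GPU index is illustrative):

# Quick end-to-end check; pass cuda_idx=None to stay on CPU.
build_and_train(game="pong", cuda_idx=0)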
Example #3
# In addition to the shared imports above; the Minigrid env and agent come from a
# dreamer-pytorch fork, so these module paths are guesses:
from dreamer.agents.minigrid_dreamer_agent import MinigridDreamerAgent
from dreamer.envs.minigrid import Minigrid
def build_and_train(log_dir, level="Level_GoToLocalAvoidLava", run_ID=0, cuda_idx=None, eval=False, save_model='last', load_model_path=None):
    params = torch.load(load_model_path) if load_model_path else {}
    agent_state_dict = params.get('agent_state_dict')
    optimizer_state_dict = params.get('optimizer_state_dict')

    env_kwargs = dict(
        level=level,
        slipperiness=0.0,
        one_hot_obs=True,
    )
    factory_method = make_wapper(
        Minigrid,
        [OneHotAction, TimeLimit],
        [dict(), dict(duration=64)])
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=TrajInfo,
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(horizon=10,
                   kl_scale=0.1,
                   use_pcont=True,
                   initial_optim_state_dict=optimizer_state_dict,
                   env=OneHotAction(TimeLimit(Minigrid(**env_kwargs), 64)),
                   save_env_videos=True)
    agent = MinigridDreamerAgent(train_noise=0.4,
                                 eval_noise=0,
                                 expl_type="epsilon_greedy",
                                 expl_min=0.1,
                                 expl_decay=2000 / 0.3,
                                 initial_model_state_dict=agent_state_dict,
                                 model_kwargs=dict(use_pcont=True, stride=1,
                                                   shape=(20, 7, 7), depth=1,
                                                   padding=2, full_conv=False))
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=5e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(level=level)
    name = "dreamer_" + level
    with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model, override_prefix=True,
                        use_summary_writer=True):
        runner.train()
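Since the function accepts load_model_path, a snapshot written by an earlier run can seed both the agent and optimizer state. A sketch, assuming rlpyt's default snapshot filename params.pkl inside the run directory (the path itself is hypothetical):

# Resume from a previous run's snapshot; the keys 'agent_state_dict' and
# 'optimizer_state_dict' are read back at the top of build_and_train.
snapshot_path = './logs/run_0/params.pkl'  # hypothetical location
build_and_train('./logs', level="Level_GoToLocalAvoidLava", run_ID=1,
                load_model_path=snapshot_path)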
Example #4
# In addition to the shared imports above; TrafficEnv is fork-specific, path assumed:
from dreamer.envs.traffic import TrafficEnv
def build_and_train(log_dir, game="traffic", run_ID=0, cuda_idx=None, eval=False, save_model='last', load_model_path=None, action_repeat=1, **kwargs):
    params = torch.load(load_model_path) if load_model_path else {}
    agent_state_dict = params.get('agent_state_dict')
    optimizer_state_dict = params.get('optimizer_state_dict')

    env_kwargs = dict(
        name=game,
        render=False,
        **kwargs
    )
    factory_method = make_wapper(
        TrafficEnv,
        [ActionRepeat, OneHotAction, TimeLimit],
        [dict(amount=action_repeat), dict(), dict(duration=1000 / action_repeat)])
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(horizon=10, kl_scale=0.1, use_pcont=True, initial_optim_state_dict=optimizer_state_dict)
    agent = AtariDreamerAgent(train_noise=0.4, eval_noise=0, expl_type="epsilon_greedy",
                              expl_min=0.1, expl_decay=2000 / 0.3, initial_model_state_dict=agent_state_dict,
                              model_kwargs=dict(use_pcont=True))
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=5e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dreamer_" + game
    with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model, override_prefix=True,
                        use_summary_writer=True):
        runner.train()
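Note that the trailing **kwargs are merged straight into env_kwargs, so environment-specific options reach TrafficEnv without touching this function. A sketch (map_name is a hypothetical TrafficEnv option, named only for illustration):

# Extra keyword arguments flow through env_kwargs into TrafficEnv.
build_and_train('./logs', game="traffic", action_repeat=2,
                map_name="intersection")  # hypothetical env option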