def build_and_train(log_dir, game="cartpole_balance", run_ID=0, cuda_idx=None, eval=False, save_model='last', load_model_path=None): params = torch.load(load_model_path) if load_model_path else {} agent_state_dict = params.get('agent_state_dict') optimizer_state_dict = params.get('optimizer_state_dict') action_repeat = 2 factory_method = make_wapper(DeepMindControl, [ActionRepeat, NormalizeActions, TimeLimit], [ dict(amount=action_repeat), dict(), dict(duration=1000 / action_repeat) ]) sampler = SerialSampler( EnvCls=factory_method, TrajInfoCls=TrajInfo, env_kwargs=dict(name=game, use_state=args.state), eval_env_kwargs=dict(name=game, use_state=args.state), batch_T=1, batch_B=1, max_decorrelation_steps=0, eval_n_envs=10, eval_max_steps=int(10e3), eval_max_trajectories=5, ) algo = Dreamer( initial_optim_state_dict=optimizer_state_dict) # Run with defaults. agent = DMCDreamerAgent(train_noise=0.3, eval_noise=0, expl_type="additive_gaussian", expl_min=None, expl_decay=None, initial_model_state_dict=agent_state_dict) runner_cls = MinibatchRlEval if eval else MinibatchRl runner = runner_cls( algo=algo, agent=agent, sampler=sampler, n_steps=5e6, log_interval_steps=1e3, affinity=dict(cuda_idx=cuda_idx), ) config = dict(game=game) name = "dreamer_" + game with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model, override_prefix=True, use_summary_writer=True): runner.train()
def build_and_train(game="pong", run_ID=0, cuda_idx=None, eval=False): action_repeat = 2 env_kwargs = dict( name=game, action_repeat=action_repeat, size=(64, 64), grayscale=False, life_done=True, sticky_actions=True, ) factory_method = make_wapper( AtariEnv, [OneHotAction, TimeLimit], [dict(), dict(duration=1000 / action_repeat)]) sampler = SerialSampler( EnvCls=factory_method, TrajInfoCls=AtariTrajInfo, # default traj info + GameScore env_kwargs=env_kwargs, eval_env_kwargs=env_kwargs, batch_T=1, batch_B=1, max_decorrelation_steps=0, eval_n_envs=10, eval_max_steps=int(10e3), eval_max_trajectories=5, ) algo = Dreamer( batch_size=1, batch_length=5, train_every=10, train_steps=2, prefill=10, horizon=5, replay_size=100, log_video=False, kl_scale=0.1, use_pcont=True, ) agent = AtariDreamerAgent(train_noise=0.4, eval_noise=0, expl_type="epsilon_greedy", expl_min=0.1, expl_decay=2000 / 0.3, model_kwargs=dict(use_pcont=True)) runner_cls = MinibatchRlEval if eval else MinibatchRl runner = runner_cls( algo=algo, agent=agent, sampler=sampler, n_steps=20, log_interval_steps=10, affinity=dict(cuda_idx=cuda_idx), ) runner.train()
def build_and_train(log_dir, level="Level_GoToLocalAvoidLava", run_ID=0, cuda_idx=None, eval=False,
                    save_model='last', load_model_path=None):
    params = torch.load(load_model_path) if load_model_path else {}
    agent_state_dict = params.get('agent_state_dict')
    optimizer_state_dict = params.get('optimizer_state_dict')

    env_kwargs = dict(
        level=level,
        slipperiness=0.0,
        one_hot_obs=True,
    )
    factory_method = make_wapper(
        Minigrid,
        [OneHotAction, TimeLimit],
        [dict(), dict(duration=64)])
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=TrajInfo,
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(horizon=10, kl_scale=0.1, use_pcont=True,
                   initial_optim_state_dict=optimizer_state_dict,
                   env=OneHotAction(TimeLimit(Minigrid(**env_kwargs), 64)),
                   save_env_videos=True)
    agent = MinigridDreamerAgent(train_noise=0.4, eval_noise=0, expl_type="epsilon_greedy",
                                 expl_min=0.1, expl_decay=2000 / 0.3,
                                 initial_model_state_dict=agent_state_dict,
                                 model_kwargs=dict(use_pcont=True, stride=1, shape=(20, 7, 7),
                                                   depth=1, padding=2, full_conv=False))
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=5e6,
        log_interval_steps=1e3,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(level=level)
    name = "dreamer_" + level
    with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model,
                        override_prefix=True, use_summary_writer=True):
        runner.train()
def build_and_train(log_dir, game="traffic", run_ID=0, cuda_idx=None, eval=False, save_model='last', load_model_path=None, action_repeat=1, **kwargs): params = torch.load(load_model_path) if load_model_path else {} agent_state_dict = params.get('agent_state_dict') optimizer_state_dict = params.get('optimizer_state_dict') env_kwargs = dict( name=game, render=False, **kwargs ) factory_method = make_wapper( TrafficEnv, [ActionRepeat, OneHotAction, TimeLimit], [dict(amount=action_repeat), dict(), dict(duration=1000 / action_repeat)]) sampler = SerialSampler( EnvCls=factory_method, TrajInfoCls=AtariTrajInfo, # default traj info + GameScore env_kwargs=env_kwargs, eval_env_kwargs=env_kwargs, batch_T=1, batch_B=1, max_decorrelation_steps=0, eval_n_envs=10, eval_max_steps=int(10e3), eval_max_trajectories=5, ) algo = Dreamer(horizon=10, kl_scale=0.1, use_pcont=True, initial_optim_state_dict=optimizer_state_dict) agent = AtariDreamerAgent(train_noise=0.4, eval_noise=0, expl_type="epsilon_greedy", expl_min=0.1, expl_decay=2000 / 0.3, initial_model_state_dict=agent_state_dict, model_kwargs=dict(use_pcont=True)) runner_cls = MinibatchRlEval if eval else MinibatchRl runner = runner_cls( algo=algo, agent=agent, sampler=sampler, n_steps=5e6, log_interval_steps=1e3, affinity=dict(cuda_idx=cuda_idx), ) config = dict(game=game) name = "dreamer_" + game with logger_context(log_dir, run_ID, name, config, snapshot_mode=save_model, override_prefix=True, use_summary_writer=True): runner.train()