# --- Atari A2C (feed-forward agent), CPU parallel sampling ---
# NOTE: import paths are best-effort assumptions for the rlpyt snapshot these
# scripts target; adjust them to your installed rlpyt version.
from rlpyt.utils.launching.affinity import affinity_from_code
from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler
from rlpyt.samplers.cpu.collectors import EpisodicLivesWaitResetCollector
from rlpyt.envs.atari.atari_env import AtariEnv, AtariTrajInfo
from rlpyt.algos.pg.a2c import A2C
from rlpyt.agents.pg.atari import AtariFfAgent
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.utils.logging.context import logger_context
# `configs` comes from this experiment's config module, e.g. (hypothetical path):
# from rlpyt.experiments.configs.atari.pg.atari_a2c import configs


def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    # variant = load_variant(log_dir)
    # config = update_config(config, variant)
    sampler = CpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,  # Lost lives end episodes; resets wait for the batch boundary.
        TrajInfoCls=AtariTrajInfo,  # Records the raw game score alongside reward.
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariFfAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):  # Might have to flatten config.
        runner.train()
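# These launcher scripts are invoked as subprocesses by rlpyt's launch
# utilities, which pass the four arguments on the command line. A minimal
# entry point in that style (a sketch; each script below would end the same way):
if __name__ == "__main__":
    import sys
    build_and_train(*sys.argv[1:])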
# --- Atari DQN with periodic offline evaluation, CPU parallel sampling ---
# NOTE: import paths are best-effort assumptions, as above.
from rlpyt.utils.launching.affinity import affinity_from_code
from rlpyt.utils.launching.variant import load_variant, update_config
from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler
from rlpyt.samplers.cpu.collectors import WaitResetCollector
from rlpyt.envs.atari.atari_env import AtariEnv, AtariTrajInfo
from rlpyt.algos.dqn.dqn import DQN
from rlpyt.agents.dqn.atari.atari_dqn_agent import AtariDqnAgent
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging.context import logger_context
# `configs` comes from this experiment's config module, e.g. (hypothetical path):
# from rlpyt.experiments.configs.atari.dqn.atari_dqn import configs


def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    config["eval_env"]["game"] = config["env"]["game"]  # Evaluate on the same game.
    # Pop CollectorCls so it isn't passed twice via **config["sampler"];
    # fall back to WaitResetCollector when the config doesn't set one.
    CollectorCls = config["sampler"].pop("CollectorCls", None)
    sampler = CpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=CollectorCls or WaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        eval_env_kwargs=config["eval_env"],
        **config["sampler"]
    )
    algo = DQN(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariDqnAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRlEval(  # Runs offline evaluation phases between training iterations.
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["game"]
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
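# The nested keys read above imply a config of roughly the following shape;
# every key path is taken from the code, but the values are illustrative
# placeholders only, not the experiment's actual settings:
#
# configs["example"] = dict(
#     env=dict(game="pong"),
#     eval_env=dict(game="pong"),  # "game" gets overwritten from env above.
#     sampler=dict(batch_T=1, batch_B=32),
#     algo=dict(learning_rate=2.5e-4),
#     optim=dict(eps=1e-4),
#     model=dict(dueling=False),
#     agent=dict(),
#     runner=dict(n_steps=50e6),
# )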
# --- Continuous-control DDPG on a Gym environment, CPU parallel sampling ---
# NOTE: import paths are best-effort assumptions, as above.
from rlpyt.utils.launching.affinity import affinity_from_code
from rlpyt.utils.launching.variant import load_variant, update_config
from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler
from rlpyt.samplers.cpu.collectors import ResetCollector
from rlpyt.envs.gym import make as gym_make
from rlpyt.algos.qpg.ddpg import DDPG
from rlpyt.agents.qpg.ddpg_agent import DdpgAgent
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.utils.logging.context import logger_context
# `configs` comes from this experiment's config module, e.g. (hypothetical path):
# from rlpyt.experiments.configs.mujoco.qpg.mujoco_ddpg import configs


def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    variant = load_variant(log_dir)
    config = update_config(config, variant)
    sampler = CpuParallelSampler(
        EnvCls=gym_make,
        env_kwargs=config["env"],
        CollectorCls=ResetCollector,  # Resets mid-batch, without waiting.
        **config["sampler"]
    )
    algo = DDPG(optim_kwargs=config["optim"], **config["algo"])
    agent = DdpgAgent(**config["agent"])  # No model_kwargs: agent uses its model defaults.
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    name = config["env"]["id"]  # The Gym environment id names the run.
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
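# This script, unlike the A2C ones, actually loads a variant: the launcher
# writes a variant file into each run's log_dir, load_variant reads it back,
# and update_config merges it recursively over the base config, so a sweep
# only specifies the keys it changes. A hypothetical variant such as
#
#     {"algo": {"learning_rate": 1e-4}, "env": {"id": "Hopper-v2"}}
#
# would override just those two leaves and leave the rest of
# configs[config_key] intact.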
# --- Atari A2C (LSTM agent), CPU parallel sampling ---
# NOTE: import paths are best-effort assumptions, as above.
from rlpyt.utils.launching.affinity import affinity_from_code
from rlpyt.samplers.cpu.parallel_sampler import CpuParallelSampler
from rlpyt.samplers.cpu.collectors import EpisodicLivesWaitResetCollector
from rlpyt.envs.atari.atari_env import AtariEnv
from rlpyt.algos.pg.a2c import A2C
from rlpyt.agents.pg.atari import AtariLstmAgent
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.utils.logging.context import logger_context
# `configs` comes from this experiment's config module, e.g. (hypothetical path):
# from rlpyt.experiments.configs.atari.pg.atari_lstm_a2c import configs


def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    affinity = affinity_from_code(slot_affinity_code)
    config = configs[config_key]
    # variant = load_variant(log_dir)
    # config = update_config(config, variant)
    sampler = CpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=config["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        **config["sampler"]
    )
    algo = A2C(optim_kwargs=config["optim"], **config["algo"])
    agent = AtariLstmAgent(model_kwargs=config["model"], **config["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **config["runner"]
    )
    # Tag the run name with the entropy coefficient being studied.
    name = config["env"]["game"] + str(config["algo"]["entropy_loss_coeff"])
    with logger_context(log_dir, run_ID, name, config):
        runner.train()
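# Since the run name embeds entropy_loss_coeff, a companion launch file would
# presumably sweep it (with the variant-loading lines above uncommented). A
# hypothetical sweep using rlpyt's variant utilities, values illustrative only:
#
# from rlpyt.utils.launching.variant import make_variants, VariantLevel
#
# values = [(0.01,), (0.04,), (0.1,)]
# dir_names = ["ent_{}".format(v[0]) for v in values]
# variant_levels = [VariantLevel(
#     keys=[("algo", "entropy_loss_coeff")],
#     values=values,
#     dir_names=dir_names,
# )]
# variants, log_dirs = make_variants(*variant_levels)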