def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Assemble and train an A2C Atari experiment (feed-forward agent).

    Decodes the hardware affinity from ``slot_affinity_code``, looks up the
    experiment preset in the module-level ``configs`` dict via ``config_key``,
    and logs results under ``log_dir``/``run_ID``.
    """
    affinity = get_affinity(slot_affinity_code)
    cfg = configs[config_key]
    # Variant loading left disabled, as in the original launcher template:
    # variant = load_variant(log_dir)
    # cfg = update_config(cfg, variant)
    sampler = CpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=cfg["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        TrajInfoCls=AtariTrajInfo,
        **cfg["sampler"],
    )
    algo = A2C(optim_kwargs=cfg["optim"], **cfg["algo"])
    agent = AtariFfAgent(model_kwargs=cfg["model"], **cfg["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **cfg["runner"],
    )
    run_name = cfg["env"]["game"]
    # NOTE(review): original comment said the config might need flattening
    # before logging — confirm against logger_context's expectations.
    with logger_context(log_dir, run_ID, run_name, cfg):
        runner.train()
def build_and_train(slot_affinity_code, log_dir, run_ID, config_key):
    """Assemble and train an A2C Atari experiment (LSTM agent).

    Decodes the hardware affinity from ``slot_affinity_code``, looks up the
    experiment preset in the module-level ``configs`` dict via ``config_key``,
    and logs results under ``log_dir``/``run_ID``. The run name embeds the
    entropy-loss coefficient so sweeps over it are distinguishable in logs.
    """
    affinity = affinity_from_code(slot_affinity_code)
    cfg = configs[config_key]
    # Variant loading left disabled, as in the original launcher template:
    # variant = load_variant(log_dir)
    # cfg = update_config(cfg, variant)
    sampler = CpuParallelSampler(
        EnvCls=AtariEnv,
        env_kwargs=cfg["env"],
        CollectorCls=EpisodicLivesWaitResetCollector,
        **cfg["sampler"],
    )
    algo = A2C(optim_kwargs=cfg["optim"], **cfg["algo"])
    agent = AtariLstmAgent(model_kwargs=cfg["model"], **cfg["agent"])
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        affinity=affinity,
        **cfg["runner"],
    )
    run_name = f"{cfg['env']['game']}{cfg['algo']['entropy_loss_coeff']}"
    with logger_context(log_dir, run_ID, run_name, cfg):
        runner.train()