def get_config():
    """Return the base configuration tree.

    The returned ConfigDict has three top-level sections, in this order:

    * ``agent_config`` -- contains an empty ``network`` sub-config.
    * ``shell_config`` -- ``use_gpu`` is set to ``False``.
    * ``session_config`` -- ``sync_period`` is set to ``100``.
    """
    # Populate each section on its own before attaching it to the root.
    agent_section = ConfigDict()
    agent_section.network = ConfigDict()

    shell_section = ConfigDict()
    shell_section.use_gpu = False

    session_section = ConfigDict()
    session_section.sync_period = 100

    # Attach in the same order as the sections are listed above.
    root = ConfigDict()
    root.agent_config = agent_section
    root.shell_config = shell_section
    root.session_config = session_section
    return root
def get_shell_config():
    """Return the shell configuration, with the agent config embedded.

    The shell is described by its class path/name, its ``agent_scope``
    ('shell') and ``use_gpu=True``. The agent config comes from
    ``get_agent_config()``. Its class is resolved with ``U.import_obj`` and
    stored as ``agent_class``. The agent config itself is attached as
    ``agent_config`` with ``evaluation_mode`` switched on.
    """
    shell_cfg = ConfigDict()
    agent_cfg = get_agent_config()

    # Unless overridden, the shell class is taken from the distributed package.
    shell_cfg.class_path = 'liaison.distributed.shell_for_test'
    shell_cfg.class_name = 'Shell'
    shell_cfg.agent_scope = 'shell'
    shell_cfg.use_gpu = True

    # Resolve the agent class from the (not yet modified) agent config.
    shell_cfg.agent_class = U.import_obj(
        agent_cfg.class_name,
        agent_cfg.class_path,
    )

    # Attach first, then flip evaluation mode on the attached config.
    shell_cfg.agent_config = agent_cfg
    shell_cfg.agent_config.update(evaluation_mode=True)
    return shell_cfg