def test_basic(self):
    """Round-trip save/load of a Parameters object and verify copy_from.

    Mutates a config value and a numpy parameter after saving, checks that
    load() restores both, then checks copy_from() copies required config
    keys (by identity for non-array values) and parameter entries (by value).
    """
    a, locals = self.create_parameters()
    a.save(save_path=GlobalConfig().DEFAULT_LOG_PATH + '/param_path',
           name=a.name,
           global_step=0)
    # Snapshot original values before mutating them.
    or_val = a._source_config.config_dict['var1']
    or_param = a('param3').copy()
    a._source_config.config_dict['var1'] = 100
    a._parameters['param3'] = 1000
    self.assertNotEqual(a._source_config.config_dict['var1'], or_val)
    self.assertFalse(np.equal(a._parameters['param3'], or_param).all())
    # load() must restore the snapshot saved at global_step=0.
    a.load(load_path=GlobalConfig().DEFAULT_LOG_PATH + '/param_path',
           name=a.name,
           global_step=0)
    self.assertEqual(a._source_config.config_dict['var1'], or_val)
    self.assertTrue(np.equal(a._parameters['param3'], or_param).all())
    b, _ = self.create_parameters()
    b.copy_from(a)
    # Required config keys: arrays compared element-wise; other values are
    # expected to be shared by reference (same id) after copy_from.
    for key in a._source_config.required_key_dict.keys():
        if isinstance(a[key], np.ndarray):
            self.assertTrue(np.equal(a[key], b[key]).all())
        else:
            self.assertEqual(id(a[key]), id(b[key]))
            self.assertEqual(id(a(key)), id(b(key)))
    # Parameter entries are compared by value.
    for key in a._parameters.keys():
        if isinstance(a[key], np.ndarray):
            self.assertTrue(np.equal(a[key], b[key]).all())
        else:
            self.assertEqual(a[key], b[key])
            self.assertEqual(a(key), b(key))
def func(algo=algo, locals=locals):
    """Run a model-free experiment end-to-end with a scheduled learning rate
    and check the agent/env sample-count bookkeeping agrees.

    `algo` and `locals` are captured from the enclosing scope as defaults.
    """
    GlobalConfig().set(
        'DEFAULT_EXPERIMENT_END_POINT',
        dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
             TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
             TOTAL_AGENT_UPDATE_COUNT=None))
    if not algo:
        algo, locals = self.create_dqn()
    env_spec = locals['env_spec']
    env = locals['env']
    agent = self.create_agent(env=locals['env'],
                              algo=algo,
                              name='agent',
                              eps=self.create_eps(env_spec)[0],
                              env_spec=env_spec)[0]
    exp = self.create_exp(name='model_free', env=env, agent=agent)
    # Anneal LEARNING_RATE linearly over the configured train-sample budget.
    algo.parameters.set_scheduler(
        param_key='LEARNING_RATE',
        to_tf_ph_flag=True,
        scheduler=LinearSchedule(
            t_fn=exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT,
            schedule_timesteps=GlobalConfig().DEFAULT_EXPERIMENT_END_POINT[
                'TOTAL_AGENT_TRAIN_SAMPLE_COUNT'],
            final_p=0.0001,
            initial_p=0.01))
    exp.run()
    self.assertEqual(exp.TOTAL_AGENT_TEST_SAMPLE_COUNT(),
                     exp.TOTAL_ENV_STEP_TEST_SAMPLE_COUNT())
    # NOTE(review): the third argument (500) is assertEqual's `msg` parameter,
    # so 500 is NOT actually compared against anything here — confirm whether
    # a separate `assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(), 500)`
    # was intended.
    self.assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(),
                     exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT(), 500)
def _is_ended(self):
    """
    :return: True if an experiment is ended
    :rtype: bool
    """
    # key_founded_flag: at least one end-point key has a concrete (non-None)
    # threshold configured; used below to warn when no end condition exists.
    key_founded_flag = False
    finished_flag = False
    for key in GlobalConfig().DEFAULT_EXPERIMENT_END_POINT:
        if GlobalConfig().DEFAULT_EXPERIMENT_END_POINT[key] is not None:
            key_founded_flag = True
            # Compare the collected global status value against the threshold.
            if get_global_status_collect()(key) >= GlobalConfig().DEFAULT_EXPERIMENT_END_POINT[key]:
                ConsoleLogger().print(
                    'info',
                    'pipeline ended because {}: {} >= end point value {}'.format(
                        key,
                        get_global_status_collect()(key),
                        GlobalConfig().DEFAULT_EXPERIMENT_END_POINT[key]))
                finished_flag = True
    if key_founded_flag is False:
        # NOTE(review): this branch fires when no end-point key has a non-None
        # value, but the message talks about registration with the global
        # status collector — confirm which meaning was intended.
        ConsoleLogger().print(
            'warning',
            '{} in experiment_end_point is not registered with global status collector: {}, experiment may not end'
            .format(GlobalConfig().DEFAULT_EXPERIMENT_END_POINT,
                    list(get_global_status_collect()().keys())))
    return finished_flag
def test_save_load(self):
    """Save tf parameters over several steps, rebuild the graph in a fresh
    session, load the last checkpoint and verify variable values survive."""
    param, _ = self.create_tf_parameters('param')
    param.init()
    # Capture the initialized variable values before any save/reset.
    var_val = [self.sess.run(var) for var in param('tf_var_list')]
    param_other, _ = self.create_tf_parameters(name='other_param')
    param_other.init()
    for i in range(10):
        param.save(sess=self.sess,
                   save_path=GlobalConfig().DEFAULT_LOG_PATH + '/model',
                   global_step=i)
    # Tear down the current session/graph to simulate a fresh process.
    if tf.get_default_session():
        sess = tf.get_default_session()
        sess.__exit__(None, None, None)
    tf.reset_default_graph()
    print('set tf device as {}'.format(self.default_id))
    self.sess = create_new_tf_session(cuda_device=self.default_id)
    # Recreate the same parameter set and restore the last checkpoint (step 9).
    param2, _ = self.create_tf_parameters('param')
    param2.init()
    param2.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH + '/model',
                global_step=9)
    for var1, var2 in zip(var_val, param2('tf_var_list')):
        self.assertTrue(np.equal(var1, self.sess.run(var2)).all())
def test_tf_param(self):
    """Save/copy/load round trip for a placeholder-input based module."""
    a, _ = self.create_ph('test')
    for i in range(5):
        a.save(save_path=GlobalConfig().DEFAULT_LOG_PATH + '/test_placehoder_input',
               global_step=i,
               name='a')
    # One .meta file per saved global step is expected.
    file = glob.glob(GlobalConfig().DEFAULT_LOG_PATH + '/test_placehoder_input/a*.meta')
    self.assertTrue(len(file) == 5)
    b, _ = self.create_ph('b')
    b.copy_from(obj=a)
    self.assert_var_list_equal(a.parameters('tf_var_list'),
                               b.parameters('tf_var_list'))
    # Re-initializing `a` should diverge it from the values copied into `b`.
    a.parameters.init()
    self.assert_var_list_at_least_not_equal(a.parameters('tf_var_list'),
                                            b.parameters('tf_var_list'))
    # Loading the last checkpoint restores `a` back to the values `b` holds.
    a.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH + '/test_placehoder_input',
           global_step=4,
           model_name='a')
    self.assert_var_list_equal(a.parameters('tf_var_list'),
                               b.parameters('tf_var_list'))
def test_init(self):
    """End-to-end DQN smoke test.

    Collects transitions, verifies copy_from produces equal-but-distinct
    variables, checks that training diverges the copy and that loading the
    pre-training checkpoint restores equality, then runs the target-network
    update op repeatedly and reports the mean parameter gap.
    """
    dqn, locals = self.create_dqn()
    env = locals['env']
    env_spec = locals['env_spec']
    dqn.init()
    st = env.reset()
    a = TransitionData(env_spec)
    # Collect 100 transitions with the current (untrained) Q function.
    for i in range(100):
        ac = dqn.predict(obs=st, sess=self.sess, batch_flag=False)
        st_new, re, done, _ = env.step(action=ac)
        a.append(state=st, new_state=st_new, action=ac, done=done, reward=re)
        st = st_new
    dqn.append_to_memory(a)
    new_dqn, _ = self.create_dqn(name='new_dqn')
    new_dqn.copy_from(dqn)
    # copy_from must copy values into distinct variables (no aliasing)...
    self.assert_var_list_id_no_equal(dqn.q_value_func.parameters('tf_var_list'),
                                     new_dqn.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_id_no_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                     new_dqn.target_q_value_func.parameters('tf_var_list'))
    # ...while the copied values themselves are equal.
    self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                               new_dqn.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                               new_dqn.target_q_value_func.parameters('tf_var_list'))
    dqn.save(save_path=GlobalConfig().DEFAULT_LOG_PATH + '/dqn_test',
             global_step=0,
             name=dqn.name)
    # Training must move dqn away from the snapshot held by new_dqn.
    for i in range(10):
        print(dqn.train(batch_data=a, train_iter=10, sess=None, update_target=True))
        print(dqn.train(batch_data=None, train_iter=10, sess=None, update_target=True))
    self.assert_var_list_at_least_not_equal(dqn.q_value_func.parameters('tf_var_list'),
                                            new_dqn.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_at_least_not_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                            new_dqn.target_q_value_func.parameters('tf_var_list'))
    # Loading the pre-training checkpoint restores equality with new_dqn.
    dqn.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH + '/dqn_test',
             model_name=dqn.name,
             global_step=0)
    self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                               new_dqn.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                               new_dqn.target_q_value_func.parameters('tf_var_list'))
    for i in range(10):
        self.sess.run(dqn.update_target_q_value_func_op,
                      feed_dict=dqn.parameters.return_tf_parameter_feed_dict())
    var1 = self.sess.run(dqn.q_value_func.parameters('tf_var_list'))
    var2 = self.sess.run(dqn.target_q_value_func.parameters('tf_var_list'))
    import numpy as np
    total_diff = 0.0
    for v1, v2 in zip(var1, var2):
        total_diff += np.mean(np.abs(np.array(v1) - np.array(v2)))
    print('update target, difference mean', total_diff)
def setUp(self):
    """Reset the log directory to a clean, empty state before each test and
    check that neither logger has been initialized yet."""
    BaseTestCase.setUp(self)
    log_path = GlobalConfig().DEFAULT_LOG_PATH
    try:
        shutil.rmtree(log_path)
    except FileNotFoundError:
        # Nothing to clean on the first run.
        pass
    os.makedirs(log_path)
    self.assertFalse(ConsoleLogger().inited_flag)
    self.assertFalse(Logger().inited_flag)
class Basic(object):
    """ Basic class within the whole framework"""
    STATUS_LIST = GlobalConfig().DEFAULT_BASIC_STATUS_LIST
    INIT_STATUS = GlobalConfig().DEFAULT_BASIC_INIT_STATUS
    required_key_dict = ()
    allow_duplicate_name = False

    def __init__(self, name: str, status=None):
        """ Init a new Basic instance.

        :param name: name of the object, can be determined to generate log path,
            handle tensorflow name scope etc.
        :type name: str
        :param status: A status instance :py:class:`~baconian.core.status.Status`
            to indicate the status of the object
        :type status: Status
        """
        self._status = status if status else Status(self)
        self._name = name
        # Every instance that inherits Basic is registered under a global name.
        register_name_globally(name=name, obj=self)

    def init(self, *args, **kwargs):
        """Initialize the object"""
        raise NotImplementedError

    def get_status(self) -> dict:
        """ Return the object's status, a dictionary."""
        return self._status.get_status()

    def set_status(self, val):
        """ Set the object's status."""
        self._status.set_status(val)

    @property
    def name(self):
        """ The name(id) of object, a string."""
        return self._name

    @property
    def status_list(self):
        """ Status list of the object, ('TRAIN', 'TEST')."""
        return self.STATUS_LIST

    def save(self, *args, **kwargs):
        """ Save the parameters in training checkpoints."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """ Load the parameters from training checkpoints."""
        raise NotImplementedError
def LogSetup(self):
    """Initialize the file logger and console logger from GlobalConfig
    defaults, then assert both report themselves as initialized."""
    conf = GlobalConfig()
    Logger().init(config_or_config_dict=conf.DEFAULT_LOG_CONFIG_DICT,
                  log_path=conf.DEFAULT_LOG_PATH,
                  log_level=conf.DEFAULT_LOG_LEVEL)
    console_log_file = os.path.join(Logger().log_dir, 'console.log')
    ConsoleLogger().init(logger_name='console_logger',
                         to_file_flag=True,
                         level=conf.DEFAULT_LOG_LEVEL,
                         to_file_name=console_log_file)
    self.assertTrue(ConsoleLogger().inited_flag)
    self.assertTrue(Logger().inited_flag)
def single_exp_runner(task_fn,
                      auto_choose_gpu_flag=False,
                      gpu_id: int = 0,
                      seed=None,
                      del_if_log_path_existed=False,
                      keep_session=False,
                      **task_fn_kwargs):
    """
    Set up GPU visibility, global seed and logging, then run one experiment task.

    :param task_fn: task function defined by users
    :type task_fn: method
    :param auto_choose_gpu_flag: auto choose gpu, default False
    :type auto_choose_gpu_flag: bool
    :param gpu_id: gpu id, default 0
    :type gpu_id: int
    :param seed: random seed; defaults to one generated from system time
    :type seed: int
    :param del_if_log_path_existed: delete obsolete log file path if existed, by default False
    :type del_if_log_path_existed: bool
    :param keep_session: whether to keep the default session & graph
    :type keep_session: bool
    :param task_fn_kwargs: extra keyword arguments forwarded to task_fn
    :return: None
    :rtype: None
    """
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    if auto_choose_gpu_flag is True:
        # Pick the first available GPU reported by the GPU utility.
        DEVICE_ID_LIST = Gpu.getFirstAvailable()
        DEVICE_ID = DEVICE_ID_LIST[0]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_ID)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    if not seed:
        # Keep the time-derived seed within a 32-bit range.
        seed = int(round(time.time() * 1000)) % (2**32 - 1)
    _reset_global_seed(seed, keep_session)
    print("create log path at {}".format(GlobalConfig().DEFAULT_LOG_PATH), flush=True)
    file.create_path(path=GlobalConfig().DEFAULT_LOG_PATH,
                     del_if_existed=del_if_log_path_existed)
    Logger().init(config_or_config_dict=dict(),
                  log_path=GlobalConfig().DEFAULT_LOG_PATH,
                  log_level=GlobalConfig().DEFAULT_LOG_LEVEL)
    ConsoleLogger().init(
        to_file_flag=GlobalConfig().DEFAULT_WRITE_CONSOLE_LOG_TO_FILE_FLAG,
        to_file_name=os.path.join(
            GlobalConfig().DEFAULT_LOG_PATH,
            GlobalConfig().DEFAULT_CONSOLE_LOG_FILE_NAME),
        level=GlobalConfig().DEFAULT_LOG_LEVEL,
        logger_name=GlobalConfig().DEFAULT_CONSOLE_LOGGER_NAME)
    task_fn(**task_fn_kwargs)
def _save_all_obj_final_status(self):
    """Collect the status of every globally-registered object (one entry per
    status in its STATUS_LIST) and dump them plus the global config as json
    files under the record log directory."""
    final_status = dict()
    for obj_name, obj in get_all()['_global_name_dict'].items():
        if hasattr(obj, 'get_status') and callable(
                getattr(obj, 'get_status')):
            tmp_dict = dict()
            tmp_dict[obj_name] = dict()
            # NOTE(review): set_status is invoked for every entry of
            # STATUS_LIST and the object is left in the last one — confirm
            # this side effect is acceptable at shutdown time.
            for st in obj.STATUS_LIST:
                obj.set_status(st)
                tmp_dict[obj_name][st] = obj.get_status()
            final_status = {**final_status, **tmp_dict}
    ConsoleLogger().print(
        'info',
        'save final_status into {}'.format(
            os.path.join(self._record_file_log_dir)))
    self.out_to_file(file_path=os.path.join(self._record_file_log_dir),
                     content=final_status,
                     force_new=True,
                     file_name='final_status.json')
    ConsoleLogger().print(
        'info',
        'save global_config into {}'.format(
            os.path.join(self._record_file_log_dir)))
    self.out_to_file(file_path=os.path.join(self._record_file_log_dir),
                     content=GlobalConfig().return_all_as_dict(),
                     force_new=True,
                     file_name='global_config.json')
def test_duplicate_exp(self):
    """Run the same experiment function twice via duplicate_exp_runner and
    verify the per-run directory layout (exp_<i>/record, exp_<i>/console.log)."""

    def func():
        GlobalConfig().set(
            'DEFAULT_EXPERIMENT_END_POINT',
            dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
                 TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
                 TOTAL_AGENT_UPDATE_COUNT=None))
        dqn, locals = self.create_dqn()
        env_spec = locals['env_spec']
        env = locals['env']
        agent = self.create_agent(env=locals['env'],
                                  algo=dqn,
                                  name='agent',
                                  eps=self.create_eps(env_spec)[0],
                                  env_spec=env_spec)[0]
        exp = self.create_exp(name='model_fre', env=env, agent=agent)
        exp.run()

    base_path = GlobalConfig().DEFAULT_LOG_PATH
    duplicate_exp_runner(2, func, auto_choose_gpu_flag=False, gpu_id=0)
    self.assertTrue(os.path.isdir(base_path))
    # Each duplicated run must get its own exp_<i> directory tree.
    self.assertTrue(os.path.isdir(os.path.join(base_path, 'exp_0')))
    self.assertTrue(os.path.isdir(os.path.join(base_path, 'exp_1')))
    self.assertTrue(
        os.path.isdir(os.path.join(base_path, 'exp_0', 'record')))
    self.assertTrue(
        os.path.isdir(os.path.join(base_path, 'exp_1', 'record')))
    self.assertTrue(
        os.path.isfile(os.path.join(base_path, 'exp_0', 'console.log')))
    self.assertTrue(
        os.path.isfile(os.path.join(base_path, 'exp_1', 'console.log')))
def func(self, creat_func=None):
    """Shared driver: build an algorithm (DQN by default), run a full
    experiment and verify agent/env sample counters agree.

    :param creat_func: optional factory returning (algo, locals); falls back
        to self.create_dqn()
    """
    GlobalConfig().set(
        'DEFAULT_EXPERIMENT_END_POINT',
        dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
             TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
             TOTAL_AGENT_UPDATE_COUNT=None))
    if not creat_func:
        algo, locals = self.create_dqn()
    else:
        algo, locals = creat_func()
    env_spec = locals['env_spec']
    env = locals['env']
    agent = self.create_agent(env=locals['env'],
                              algo=algo,
                              name='agent',
                              eps=self.create_eps(env_spec)[0],
                              env_spec=env_spec)[0]
    flow = None
    from baconian.algo.dyna import Dyna
    # Dyna-style algorithms need a dedicated flow; others use the default.
    if isinstance(algo, Dyna):
        flow = self.create_dyna_flow(agent=agent, env=env)[0]
    exp = self.create_exp(name='model_free', env=env, agent=agent, flow=flow)
    exp.run()
    self.assertEqual(exp.TOTAL_AGENT_TEST_SAMPLE_COUNT(),
                     exp.TOTAL_ENV_STEP_TEST_SAMPLE_COUNT())
    # NOTE(review): the third argument (500) is assertEqual's `msg` parameter,
    # so 500 is NOT actually compared — confirm whether a separate
    # `assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(), 500)` was intended.
    self.assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(),
                     exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT(), 500)
def init(self, to_file_flag, level: str, to_file_name: str = None,
         logger_name: str = 'console_logger'):
    """Configure the console logger; no-op if already initialized.

    :param to_file_flag: if True, also write records to `to_file_name`
    :param level: log level name; must be one of self.ALLOWED_LOG_LEVEL
    :param to_file_name: target file used when to_file_flag is True
    :param logger_name: name for the underlying logging.Logger
    :raises ValueError: when `level` is not in ALLOWED_LOG_LEVEL
    """
    if self.inited_flag is True:
        return
    self.name = logger_name
    if level not in self.ALLOWED_LOG_LEVEL:
        raise ValueError('Wrong log level use {} instead'.format(
            self.ALLOWED_LOG_LEVEL))
    self.logger = logging.getLogger(self.name)
    self.logger.setLevel(getattr(logging, level))
    # Remove any pre-existing handlers (own and root) so records are not
    # duplicated; iterate over copies since we mutate the handler lists.
    for handler in self.logger.root.handlers[:] + self.logger.handlers[:]:
        self.logger.removeHandler(handler)
        self.logger.root.removeHandler(handler)
    self.logger.addHandler(logging.StreamHandler())
    if to_file_flag is True:
        self.logger.addHandler(logging.FileHandler(filename=to_file_name))
    # Apply a uniform format and level to every remaining handler.
    for handler in self.logger.root.handlers[:] + self.logger.handlers[:]:
        handler.setFormatter(fmt=logging.Formatter(
            fmt=GlobalConfig().DEFAULT_LOGGING_FORMAT))
        handler.setLevel(getattr(logging, level))
    self.inited_flag = True
class ConstantActionPolicy(DeterministicPolicy):
    """A deterministic policy that always outputs the constant action
    configured under the ``ACTION_VALUE`` key."""
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_CONSTANT_ACTION_POLICY_REQUIRED_KEY_LIST)

    def __init__(self, env_spec: EnvSpec, config_or_config_dict: (DictConfig, dict), name='policy'):
        """
        :param env_spec: environment spec; its action space must contain the
            configured ACTION_VALUE
        :param config_or_config_dict: DictConfig or plain dict holding the
            required keys
        :param name: policy name
        """
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(), source_config=config)
        assert env_spec.action_space.contains(x=config('ACTION_VALUE'))
        super().__init__(env_spec, parameters, name)
        self.config = config

    def forward(self, *args, **kwargs):
        """Return the constant action, unflattened into the action space."""
        action = self.env_spec.action_space.unflatten(self.parameters('ACTION_VALUE'))
        assert self.env_spec.action_space.contains(x=action)
        return action

    def copy_from(self, obj) -> bool:
        """Copy parameters from another policy instance of the same type."""
        super().copy_from(obj)
        self.parameters.copy_from(obj.parameters)
        return True

    def make_copy(self, *args, **kwargs):
        """Create a new ConstantActionPolicy sharing env_spec, with a
        deep-copied config so the copy is independent."""
        return ConstantActionPolicy(env_spec=self.env_spec,
                                    config_or_config_dict=deepcopy(self.config),
                                    *args, **kwargs)
def __init__(self,
             tf_var_list: list,
             rest_parameters: dict,
             name: str,
             max_to_keep=GlobalConfig().DEFAULT_MAX_TF_SAVER_KEEP,
             default_save_type='tf',
             source_config=None,
             to_scheduler_param_tuple: list = None,
             save_rest_param_flag=True,
             to_ph_parameter_dict: dict = None,
             require_snapshot=False):
    """Parameters container backed by tensorflow variables.

    :param tf_var_list: tf variables managed by this container
    :param rest_parameters: non-tf parameters handled by the base class
    :param name: container name
    :param max_to_keep: max checkpoints kept by the tf saver
    :param default_save_type: only 'tf' is supported
    :param source_config: config object the parameters derive from
    :param to_scheduler_param_tuple: parameters driven by schedulers
    :param save_rest_param_flag: whether to persist the rest parameters too
    :param to_ph_parameter_dict: parameters to expose as tf placeholders
    :param require_snapshot: if True, snapshot vars/ops will be maintained
    :raises NotImplementedError: when default_save_type is not 'tf'
    """
    super(ParametersWithTensorflowVariable, self).__init__(
        parameters=rest_parameters,
        name=name,
        to_scheduler_param_tuple=to_scheduler_param_tuple,
        source_config=source_config)
    self._tf_var_list = tf_var_list
    # Snapshot machinery is populated lazily when require_snapshot is used.
    self.snapshot_var = []
    self.save_snapshot_op = []
    self.load_snapshot_op = []
    self.saver = None
    self.max_to_keep = max_to_keep
    self.require_snapshot = require_snapshot
    self.default_checkpoint_type = default_save_type
    self.save_rest_param_flag = save_rest_param_flag
    if default_save_type != 'tf':
        raise NotImplementedError('only support saving tf')
    self._registered_tf_ph_dict = dict()
    # Register any placeholder-backed parameters up front.
    if to_ph_parameter_dict:
        for key, val in to_ph_parameter_dict.items():
            self.to_tf_ph(key=key, ph=val)
def save(self, global_step, save_path=None, name=None, **kwargs):
    """Checkpoint the dynamics model and the policy at `global_step`.

    :param global_step: checkpoint step number
    :param save_path: target directory; defaults to the global checkpoint path
    :param name: checkpoint name; defaults to this object's name
    :return: dict describing where and under what name the checkpoint was saved
    """
    if not save_path:
        save_path = GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
    if not name:
        name = self.name
    # Save the dynamics model first, then the policy, under the same name/step.
    for component in (self._dynamics_model, self.policy):
        component.save(save_path=save_path,
                       global_step=global_step,
                       name=name,
                       **kwargs)
    return dict(check_point_save_path=save_path,
                check_point_save_global_step=global_step,
                check_point_save_name=name)
def test_global_config(self):
    """Verify GlobalConfig freezing: once frozen, every mutation path must
    raise AttemptToChangeFreezeGlobalConfigError."""
    GlobalConfig().set_new_config(config_dict=dict(DEFAULT_BASIC_INIT_STATUS='test'))
    assert GlobalConfig().DEFAULT_BASIC_INIT_STATUS == 'test'
    GlobalConfig().freeze_flag = True

    def expect_frozen_error(mutate):
        # Each mutation attempt must raise; anything else is a test failure.
        try:
            mutate()
        except AttemptToChangeFreezeGlobalConfigError:
            pass
        else:
            raise TypeError

    expect_frozen_error(lambda: GlobalConfig().set_new_config(
        config_dict=dict(DEFAULT_BASIC_INIT_STATUS='test')))
    expect_frozen_error(lambda: GlobalConfig().set('DEFAULT_LOG_PATH', 'tmp'))
    expect_frozen_error(lambda: setattr(GlobalConfig(), 'DEFAULT_LOG_PATH', 'tmp'))
    GlobalConfig().unfreeze()
def _exit(self):
    """Close any active tf session, reset the default graph, and clear the
    global singletons (status collector, logging, global vars, frozen config)."""
    active_sess = tf.get_default_session()
    if active_sess:
        active_sess.__exit__(None, None, None)
    tf.reset_default_graph()
    reset_global_status_collect()
    reset_logging()
    reset_global_var()
    GlobalConfig().unfreeze()
def register_name_globally(name: str, obj):
    """Register `obj` under `name` in the global name dictionary.

    :raises GlobalNameExistedError: when the name is already taken by a
        different object, neither object allows duplicate names, and the
        global duplicate-name check is not turned off.
    """
    name_dict = get_all()['_global_name_dict']
    if name in name_dict:
        existing = name_dict[name]
        clash = (not id(obj) == id(existing)
                 and obj.allow_duplicate_name is False
                 and existing.allow_duplicate_name is False
                 and GlobalConfig().DEFAULT_TURN_OFF_GLOBAL_NAME_FLAG is False)
        if clash:
            raise GlobalNameExistedError(
                'name : {} is existed with object: {}'.format(name, existing))
    name_dict[name] = obj
def _exit(self):
    """ Exit the experiment, reset global configurations and logging module."""
    current = tf.get_default_session()
    if current:
        current.__exit__(None, None, None)
    tf.reset_default_graph()
    # Restore global singletons so a subsequent experiment starts clean.
    reset_global_status_collect()
    reset_logging()
    reset_global_var()
    GlobalConfig().unfreeze()
class Basic(object):
    """ Basic class within the whole framework"""
    # Default status machinery shared by all framework objects.
    STATUS_LIST = GlobalConfig().DEFAULT_BASIC_STATUS_LIST
    INIT_STATUS = GlobalConfig().DEFAULT_BASIC_INIT_STATUS
    required_key_dict = ()
    allow_duplicate_name = False

    def __init__(self, name: str, status=None):
        """ Init a new Basic instance.

        :param name: a string for the name of the object, can be determined to
            generate log path, handle tensorflow name scope etc.
        :param status: A status instance
            :py:class:`~baconian.core.status.Status` to indicate the status of
            the object
        """
        if not status:
            self._status = Status(self)
        else:
            self._status = status
        self._name = name
        # All instances that inherit Basic are registered globally by name.
        register_name_globally(name=name, obj=self)

    def init(self, *args, **kwargs):
        """Initialize the object."""
        raise NotImplementedError

    def get_status(self) -> dict:
        """Return the object's status, a dictionary."""
        return self._status.get_status()

    def set_status(self, val):
        """Set the object's status."""
        self._status.set_status(val)

    @property
    def name(self):
        """The name (id) of the object, a string."""
        return self._name

    @property
    def status_list(self):
        """Status list of the object, e.g. ('TRAIN', 'TEST')."""
        return self.STATUS_LIST

    def save(self, *args, **kwargs):
        """Save the parameters in training checkpoints."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """Load the parameters from training checkpoints."""
        raise NotImplementedError
def run(self):
    """Freeze the global config, launch the flow, and record the terminal
    status ('FINISHED' on success, 'CORRUPTED' on failure) before exiting."""
    GlobalConfig().freeze()
    self.init()
    self.set_status('RUNNING')
    launch_ok = self.flow.launch()
    self.set_status('CORRUPTED' if launch_ok is False else 'FINISHED')
    self._exit()
def mountaincar_task_fn():
    """Benchmark task: train a DQN agent with epsilon-greedy exploration on
    MountainCar-v0, wired together from the shared benchmark config dict."""
    exp_config = MOUNTAINCAR_BENCHMARK_CONFIG_DICT
    GlobalConfig().set('DEFAULT_EXPERIMENT_END_POINT',
                       exp_config['DEFAULT_EXPERIMENT_END_POINT'])
    env = make('MountainCar-v0')
    name = 'benchmark'
    env_spec = EnvSpec(obs_space=env.observation_space,
                       action_space=env.action_space)
    mlp_q = MLPQValueFunction(env_spec=env_spec,
                              name_scope=name + '_mlp_q',
                              name=name + '_mlp_q',
                              **exp_config['MLPQValueFunction'])
    dqn = DQN(env_spec=env_spec,
              name=name + '_dqn',
              value_func=mlp_q,
              **exp_config['DQN'])
    # Exploration probability is annealed based on the train-sample count.
    agent = Agent(env=env, env_spec=env_spec,
                  algo=dqn,
                  name=name + '_agent',
                  exploration_strategy=EpsilonGreedy(
                      action_space=env_spec.action_space,
                      prob_scheduler=LinearScheduler(
                          t_fn=lambda: get_global_status_collect()(
                              'TOTAL_AGENT_TRAIN_SAMPLE_COUNT'),
                          **exp_config['EpsilonGreedy']['LinearScheduler']),
                      **exp_config['EpsilonGreedy']['config_or_config_dict']))
    # Wire the flow's test/train/sample phases to the agent's methods.
    flow = TrainTestFlow(
        train_sample_count_func=lambda: get_global_status_collect()('TOTAL_AGENT_TRAIN_SAMPLE_COUNT'),
        config_or_config_dict=exp_config['TrainTestFlow']['config_or_config_dict'],
        func_dict={
            'test': {'func': agent.test,
                     'args': list(),
                     'kwargs': dict(sample_count=exp_config['TrainTestFlow']['TEST_SAMPLES_COUNT']),
                     },
            'train': {'func': agent.train,
                      'args': list(),
                      'kwargs': dict(),
                      },
            'sample': {'func': agent.sample,
                       'args': list(),
                       'kwargs': dict(sample_count=exp_config['TrainTestFlow']['TRAIN_SAMPLES_COUNT'],
                                      env=agent.env,
                                      in_which_status='TRAIN',
                                      store_flag=True),
                       },
        })
    experiment = Experiment(
        tuner=None,
        env=env,
        agent=agent,
        flow=flow,
        name=name
    )
    experiment.run()
def run(self):
    """ Run the experiment, and set status to 'RUNNING'."""
    GlobalConfig().freeze()
    self.init()
    self.set_status('RUNNING')
    # The flow reports False on failure; map that onto the terminal status.
    final_state = 'CORRUPTED' if self.flow.launch() is False else 'FINISHED'
    self.set_status(final_state)
    self._exit()
def launch(self) -> bool:
    """ Launch the flow until it finished or catch a system-allowed errors
    (e.g., out of GPU memory, to ensure the log will be saved safely).

    :return: Boolean, True for the flow correctly executed and finished.
    """
    try:
        return self._launch()
    except GlobalConfig().DEFAULT_ALLOWED_EXCEPTION_OR_ERROR_LIST as e:
        # Only the whitelisted exception types are swallowed here; anything
        # else propagates to the caller.
        ConsoleLogger().print('error', 'error {} occurred'.format(e))
        return False
def save(self, global_step, save_path=None, name=None, **kwargs):
    """Checkpoint the model-free algorithm and the dynamics env, each into
    its own sub-directory, and return a dict describing the checkpoint.

    :param global_step: checkpoint step number
    :param save_path: base directory; defaults to the global checkpoint path
    :param name: checkpoint name used in the returned dict; defaults to self.name
    :return: dict with checkpoint path, step and name
    """
    if not save_path:
        save_path = GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
    if not name:
        name = self.name
    # name=None is passed through so each sub-component falls back to its own
    # name; each one is saved under a sub-directory named after it.
    self.model_free_algo.save(global_step=global_step,
                              name=None,
                              save_path=os.path.join(save_path, self.model_free_algo.name))
    self.dynamics_env.save(global_step=global_step,
                           name=None,
                           save_path=os.path.join(save_path, self.dynamics_env.name))
    return dict(check_point_save_path=save_path,
                check_point_save_global_step=global_step,
                check_point_save_name=name)
def save(self, global_step, save_path=None, name=None, **kwargs):
    """Save the model's parameters as a checkpoint and log the location.

    :param global_step: checkpoint step number
    :param save_path: target directory; defaults to the global checkpoint path
    :param name: checkpoint name; defaults to this model's name
    :param kwargs: may contain `sess`, the tf session used for saving
    :return: dict with checkpoint path, step and name — added for consistency
        with the other `save` implementations in the framework, which return
        this same dict (backward-compatible: previous return was None)
    """
    save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
    name = name if name else self.name
    # kwargs.get('sess') is equivalent to the explicit membership check:
    # the session if provided, otherwise None.
    sess = kwargs.get('sess')
    self.parameters.save(save_path=save_path,
                         global_step=global_step,
                         sess=sess,
                         name=name)
    ConsoleLogger().print('info',
                          'model: {}, global step: {}, saved at {}-{}'.format(name, global_step,
                                                                              save_path, global_step))
    return dict(check_point_save_path=save_path,
                check_point_save_global_step=global_step,
                check_point_save_name=name)
def setUp(self):
    """Clean the log directory, then initialize both the file logger and the
    console logger before each test."""
    BaseTestCase.setUp(self)
    try:
        shutil.rmtree(GlobalConfig().DEFAULT_LOG_PATH)
    except FileNotFoundError:
        # Nothing to clean on the first run.
        pass
    Logger().init(config_or_config_dict=GlobalConfig().DEFAULT_LOG_CONFIG_DICT,
                  log_path=GlobalConfig().DEFAULT_LOG_PATH,
                  log_level=GlobalConfig().DEFAULT_LOG_LEVEL)
    ConsoleLogger().init(logger_name='console_logger',
                        to_file_flag=True,
                        level=GlobalConfig().DEFAULT_LOG_LEVEL,
                        to_file_name=os.path.join(Logger().log_dir, 'console.log'))
    self.assertTrue(ConsoleLogger().inited_flag)
    self.assertTrue(Logger().inited_flag)
def test_save_load_with_dqn(self):
    """Checkpoint a DQN, verify copy_from produces equal variables, re-init
    to diverge, then load the last checkpoint and verify equality returns."""
    dqn, locals = self.create_dqn()
    dqn.init()
    for i in range(5):
        dqn.save(save_path=GlobalConfig().DEFAULT_LOG_PATH +
                 '/test_placehoder_input',
                 global_step=i,
                 name='dqn')
    # One .meta file per saved global step is expected.
    file = glob.glob(GlobalConfig().DEFAULT_LOG_PATH +
                     '/test_placehoder_input/dqn*.meta')
    self.assertTrue(len(file) == 5)
    dqn2, _ = self.create_dqn(name='dqn_2')
    dqn2.copy_from(dqn)
    self.assert_var_list_equal(dqn.parameters('tf_var_list'),
                               dqn2.parameters('tf_var_list'))
    self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                               dqn2.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_equal(
        dqn.target_q_value_func.parameters('tf_var_list'),
        dqn2.target_q_value_func.parameters('tf_var_list'))
    # Re-initialization must diverge dqn from the snapshot held by dqn2.
    dqn.init()
    self.assert_var_list_at_least_not_equal(
        dqn.q_value_func.parameters('tf_var_list'),
        dqn2.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_at_least_not_equal(
        dqn.target_q_value_func.parameters('tf_var_list'),
        dqn2.target_q_value_func.parameters('tf_var_list'))
    # Loading the last saved step restores equality with dqn2's copy.
    dqn.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH +
             '/test_placehoder_input',
             global_step=4,
             model_name='dqn')
    self.assert_var_list_equal(dqn.parameters('tf_var_list'),
                               dqn2.parameters('tf_var_list'))
    self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                               dqn2.q_value_func.parameters('tf_var_list'))
    self.assert_var_list_equal(
        dqn.target_q_value_func.parameters('tf_var_list'),
        dqn2.target_q_value_func.parameters('tf_var_list'))