Esempio n. 1
0
    def test_basic(self):
        """Round-trip a Parameters object through save/load and copy_from.

        Saves, mutates both the source config and a stored parameter, reloads,
        and checks the original values are restored; then copies into a fresh
        instance and checks required-config keys and parameters match.
        """
        a, locals = self.create_parameters()
        a.save(save_path=GlobalConfig().DEFAULT_LOG_PATH + '/param_path',
               name=a.name,
               global_step=0)
        # keep the pre-mutation values so the post-load restore can be verified
        or_val = a._source_config.config_dict['var1']
        or_param = a('param3').copy()
        # mutate both the source config and a stored parameter
        a._source_config.config_dict['var1'] = 100
        a._parameters['param3'] = 1000
        self.assertNotEqual(a._source_config.config_dict['var1'], or_val)
        self.assertFalse(np.equal(a._parameters['param3'], or_param).all())
        a.load(load_path=GlobalConfig().DEFAULT_LOG_PATH + '/param_path',
               name=a.name,
               global_step=0)
        # load must restore the exact pre-mutation values
        self.assertEqual(a._source_config.config_dict['var1'], or_val)
        self.assertTrue(np.equal(a._parameters['param3'], or_param).all())
        b, _ = self.create_parameters()
        b.copy_from(a)

        for key in a._source_config.required_key_dict.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                # NOTE(review): id() equality means copy_from shares these
                # non-array objects rather than deep-copying them — confirm
                # that sharing is the intended contract
                self.assertEqual(id(a[key]), id(b[key]))
                self.assertEqual(id(a(key)), id(b(key)))
        for key in a._parameters.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                self.assertEqual(a[key], b[key])
                self.assertEqual(a(key), b(key))
Esempio n. 2
0
            def func(algo=algo, locals=locals):
                """Run a DQN experiment end-to-end with a linear LR schedule.

                `algo` and `locals` are bound as default arguments to capture
                the enclosing values at definition time.
                """
                GlobalConfig().set(
                    'DEFAULT_EXPERIMENT_END_POINT',
                    dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
                         TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
                         TOTAL_AGENT_UPDATE_COUNT=None))
                if not algo:
                    algo, locals = self.create_dqn()
                env_spec = locals['env_spec']
                env = locals['env']
                agent = self.create_agent(env=locals['env'],
                                          algo=algo,
                                          name='agent',
                                          eps=self.create_eps(env_spec)[0],
                                          env_spec=env_spec)[0]

                exp = self.create_exp(name='model_free', env=env, agent=agent)
                # anneal LEARNING_RATE from 0.01 down to 0.0001 over the
                # configured number of training samples
                algo.parameters.set_scheduler(
                    param_key='LEARNING_RATE',
                    to_tf_ph_flag=True,
                    scheduler=LinearSchedule(
                        t_fn=exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT,
                        schedule_timesteps=GlobalConfig(
                        ).DEFAULT_EXPERIMENT_END_POINT[
                            'TOTAL_AGENT_TRAIN_SAMPLE_COUNT'],
                        final_p=0.0001,
                        initial_p=0.01))
                exp.run()
                self.assertEqual(exp.TOTAL_AGENT_TEST_SAMPLE_COUNT(),
                                 exp.TOTAL_ENV_STEP_TEST_SAMPLE_COUNT())
                # NOTE(review): the third argument (500) is interpreted by
                # assertEqual as the failure message, not a third value to
                # compare — confirm the intent was not assertEqual(x, 500)
                self.assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(),
                                 exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT(), 500)
    def _is_ended(self):
        """

        :return: True if an experiment is ended
        :rtype: bool
        """
        end_point = GlobalConfig().DEFAULT_EXPERIMENT_END_POINT
        any_key_active = False
        ended = False
        for key, limit in end_point.items():
            # a None limit means this end-point criterion is disabled
            if limit is None:
                continue
            any_key_active = True
            current = get_global_status_collect()(key)
            if current >= limit:
                ConsoleLogger().print(
                    'info',
                    'pipeline ended because {}: {} >= end point value {}'.
                    format(key, current, limit))
                ended = True
        if any_key_active is False:
            ConsoleLogger().print(
                'warning',
                '{} in experiment_end_point is not registered with global status collector: {}, experiment may not end'
                .format(end_point,
                        list(get_global_status_collect()().keys())))
        return ended
    def test_save_load(self):
        """Save tf parameters, rebuild the session/graph, and verify values
        restore after loading the last checkpoint."""
        param, _ = self.create_tf_parameters('param')

        param.init()
        # snapshot the initial variable values before the graph is reset
        var_val = [self.sess.run(var) for var in param('tf_var_list')]

        param_other, _ = self.create_tf_parameters(name='other_param')
        param_other.init()

        # write ten checkpoints (global steps 0..9)
        for i in range(10):
            param.save(sess=self.sess,
                       save_path=GlobalConfig().DEFAULT_LOG_PATH + '/model',
                       global_step=i)

        # tear down the current session and graph to simulate a fresh process
        if tf.get_default_session():
            sess = tf.get_default_session()
            sess.__exit__(None, None, None)
        tf.reset_default_graph()
        print('set tf device as {}'.format(self.default_id))
        self.sess = create_new_tf_session(cuda_device=self.default_id)

        param2, _ = self.create_tf_parameters('param')
        param2.init()
        param2.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH + '/model',
                    global_step=9)
        # loaded values must equal those captured before the reset
        for var1, var2 in zip(var_val, param2('tf_var_list')):
            self.assertTrue(np.equal(var1, self.sess.run(var2)).all())
    def test_tf_param(self):
        """Save a placeholder-input parameter set, copy it, re-init, reload,
        and check equality is lost and regained as expected."""
        a, _ = self.create_ph('test')
        # write five checkpoints (global steps 0..4)
        for i in range(5):
            a.save(save_path=GlobalConfig().DEFAULT_LOG_PATH +
                   '/test_placehoder_input',
                   global_step=i,
                   name='a')
        # one .meta file is expected per checkpoint
        file = glob.glob(GlobalConfig().DEFAULT_LOG_PATH +
                         '/test_placehoder_input/a*.meta')
        self.assertTrue(len(file) == 5)
        b, _ = self.create_ph('b')
        b.copy_from(obj=a)
        self.assert_var_list_equal(a.parameters('tf_var_list'),
                                   b.parameters('tf_var_list'))

        # re-initialising `a` should make it diverge from the copy `b`
        a.parameters.init()
        self.assert_var_list_at_least_not_equal(a.parameters('tf_var_list'),
                                                b.parameters('tf_var_list'))

        # loading the last checkpoint should restore equality with `b`
        a.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH +
               '/test_placehoder_input',
               global_step=4,
               model_name='a')

        self.assert_var_list_equal(a.parameters('tf_var_list'),
                                   b.parameters('tf_var_list'))
Esempio n. 6
0
    def test_init(self):
        """End-to-end DQN check: predict/append transitions, copy into a second
        DQN, save/load round trip, train, and repeated target-network updates.

        Fix: the ``import numpy as np`` that previously sat inside the final
        loop is hoisted out — re-executing the import on every iteration was
        redundant.
        """
        dqn, locals = self.create_dqn()
        env = locals['env']
        env_spec = locals['env_spec']
        dqn.init()
        st = env.reset()
        a = TransitionData(env_spec)
        # collect 100 transitions so train() has batch data
        for i in range(100):
            ac = dqn.predict(obs=st, sess=self.sess, batch_flag=False)
            st_new, re, done, _ = env.step(action=ac)
            a.append(state=st, new_state=st_new, action=ac, done=done, reward=re)
            st = st_new
            dqn.append_to_memory(a)
        new_dqn, _ = self.create_dqn(name='new_dqn')
        new_dqn.copy_from(dqn)
        # the copy must hold distinct tf variables...
        self.assert_var_list_id_no_equal(dqn.q_value_func.parameters('tf_var_list'),
                                         new_dqn.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_id_no_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                         new_dqn.target_q_value_func.parameters('tf_var_list'))

        # ...but equal values right after copy_from
        self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                                   new_dqn.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                   new_dqn.target_q_value_func.parameters('tf_var_list'))

        dqn.save(save_path=GlobalConfig().DEFAULT_LOG_PATH + '/dqn_test',
                 global_step=0,
                 name=dqn.name)

        for i in range(10):
            print(dqn.train(batch_data=a, train_iter=10, sess=None, update_target=True))
            print(dqn.train(batch_data=None, train_iter=10, sess=None, update_target=True))

        # training must have moved dqn away from the untouched copy
        self.assert_var_list_at_least_not_equal(dqn.q_value_func.parameters('tf_var_list'),
                                                new_dqn.q_value_func.parameters('tf_var_list'))

        self.assert_var_list_at_least_not_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                                new_dqn.target_q_value_func.parameters('tf_var_list'))

        # loading the step-0 checkpoint restores the pre-training values
        dqn.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH + '/dqn_test',
                 model_name=dqn.name,
                 global_step=0)

        self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                                   new_dqn.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_equal(dqn.target_q_value_func.parameters('tf_var_list'),
                                   new_dqn.target_q_value_func.parameters('tf_var_list'))
        import numpy as np  # hoisted: was previously re-imported on every loop iteration
        for i in range(10):
            self.sess.run(dqn.update_target_q_value_func_op,
                          feed_dict=dqn.parameters.return_tf_parameter_feed_dict())
            var1 = self.sess.run(dqn.q_value_func.parameters('tf_var_list'))
            var2 = self.sess.run(dqn.target_q_value_func.parameters('tf_var_list'))
            total_diff = 0.0
            for v1, v2 in zip(var1, var2):
                total_diff += np.mean(np.abs(np.array(v1) - np.array(v2)))
            print('update target, difference mean', total_diff)
Esempio n. 7
0
 def setUp(self):
     """Reset the default log directory and check loggers start uninitialised."""
     BaseTestCase.setUp(self)
     log_path = GlobalConfig().DEFAULT_LOG_PATH
     try:
         shutil.rmtree(log_path)
     except FileNotFoundError:
         # nothing to clean up on the first run
         pass
     os.makedirs(log_path)
     self.assertFalse(ConsoleLogger().inited_flag)
     self.assertFalse(Logger().inited_flag)
Esempio n. 8
0
class Basic(object):
    """ Basic class within the whole framework"""
    STATUS_LIST = GlobalConfig().DEFAULT_BASIC_STATUS_LIST
    INIT_STATUS = GlobalConfig().DEFAULT_BASIC_INIT_STATUS
    required_key_dict = ()
    allow_duplicate_name = False

    def __init__(self, name: str, status=None):
        """
        Create a Basic instance and register it globally under ``name``.

        :param name: name of the object, can be determined to generate log path, handle tensorflow name scope etc.
        :type name: str
        :param status: A status instance :py:class:`~baconian.core.status.Status` to indicate the status of the object
        :type status: Status
        """
        # fall back to a fresh Status tracker when none is supplied
        self._status = status if status else Status(self)
        self._name = name
        # every instance inheriting Basic is recorded in the global name table
        register_name_globally(name=name, obj=self)

    def init(self, *args, **kwargs):
        """Initialize the object"""
        raise NotImplementedError

    def get_status(self) -> dict:
        """ Return the object's status, a dictionary."""
        return self._status.get_status()

    def set_status(self, val):
        """ Set the object's status."""
        self._status.set_status(val)

    @property
    def name(self):
        """ The name(id) of object, a string."""
        return self._name

    @property
    def status_list(self):
        """ Status list of the object, ('TRAIN', 'TEST')."""
        return self.STATUS_LIST

    def save(self, *args, **kwargs):
        """ Save the parameters in training checkpoints."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """ Load the parameters from training checkpoints."""
        raise NotImplementedError
Esempio n. 9
0
    def LogSetup(self):
        """Initialise Logger and ConsoleLogger and verify both report inited."""
        cfg = GlobalConfig()
        Logger().init(config_or_config_dict=cfg.DEFAULT_LOG_CONFIG_DICT,
                      log_path=cfg.DEFAULT_LOG_PATH,
                      log_level=cfg.DEFAULT_LOG_LEVEL)
        console_log_file = os.path.join(Logger().log_dir, 'console.log')
        ConsoleLogger().init(logger_name='console_logger',
                             to_file_flag=True,
                             level=cfg.DEFAULT_LOG_LEVEL,
                             to_file_name=console_log_file)

        self.assertTrue(ConsoleLogger().inited_flag)
        self.assertTrue(Logger().inited_flag)
Esempio n. 10
0
def single_exp_runner(task_fn,
                      auto_choose_gpu_flag=False,
                      gpu_id: int = 0,
                      seed=None,
                      del_if_log_path_existed=False,
                      keep_session=False,
                      **task_fn_kwargs):
    """
    Run a single experiment task function with seeding, GPU selection and
    logging set up beforehand.

    :param task_fn: task function defined by users
    :type task_fn: method
    :param auto_choose_gpu_flag: auto choose gpu, default False
    :type auto_choose_gpu_flag: bool
    :param gpu_id: gpu id, default 0
    :type gpu_id: int
    :param seed: random seed; if None, one is generated from system time
    :type seed: int
    :param del_if_log_path_existed: delete obsolete log file path if existed, by default False
    :type del_if_log_path_existed: bool
    :param keep_session: Whether to keep default session & graph
    :type keep_session: bool
    :param task_fn_kwargs: keyword arguments forwarded to task_fn
    :return: None
    """
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    if auto_choose_gpu_flag is True:
        DEVICE_ID_LIST = Gpu.getFirstAvailable()
        DEVICE_ID = DEVICE_ID_LIST[0]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_ID)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # bug fix: `if not seed` treated the valid seed 0 as "unset" and replaced
    # it with a time-based seed; only generate one when no seed was passed.
    if seed is None:
        seed = int(round(time.time() * 1000)) % (2**32 - 1)
    _reset_global_seed(seed, keep_session)
    print("create log path at {}".format(GlobalConfig().DEFAULT_LOG_PATH),
          flush=True)

    file.create_path(path=GlobalConfig().DEFAULT_LOG_PATH,
                     del_if_existed=del_if_log_path_existed)
    Logger().init(config_or_config_dict=dict(),
                  log_path=GlobalConfig().DEFAULT_LOG_PATH,
                  log_level=GlobalConfig().DEFAULT_LOG_LEVEL)
    ConsoleLogger().init(
        to_file_flag=GlobalConfig().DEFAULT_WRITE_CONSOLE_LOG_TO_FILE_FLAG,
        to_file_name=os.path.join(
            GlobalConfig().DEFAULT_LOG_PATH,
            GlobalConfig().DEFAULT_CONSOLE_LOG_FILE_NAME),
        level=GlobalConfig().DEFAULT_LOG_LEVEL,
        logger_name=GlobalConfig().DEFAULT_CONSOLE_LOGGER_NAME)

    task_fn(**task_fn_kwargs)
Esempio n. 11
0
 def _save_all_obj_final_status(self):
     """Collect the final status of every globally registered object and dump
     final_status.json plus global_config.json into the record log dir."""
     final_status = dict()
     for obj_name, obj in get_all()['_global_name_dict'].items():
         # only objects exposing a callable get_status are recorded
         if hasattr(obj, 'get_status') and callable(
                 getattr(obj, 'get_status')):
             tmp_dict = dict()
             tmp_dict[obj_name] = dict()
             # NOTE(review): set_status is invoked for each entry of
             # STATUS_LIST and never restored, so every object ends up in the
             # last status of its STATUS_LIST — confirm this is intended
             for st in obj.STATUS_LIST:
                 obj.set_status(st)
                 tmp_dict[obj_name][st] = obj.get_status()
             final_status = {**final_status, **tmp_dict}
     ConsoleLogger().print(
         'info', 'save final_status into {}'.format(
             os.path.join(self._record_file_log_dir)))
     self.out_to_file(file_path=os.path.join(self._record_file_log_dir),
                      content=final_status,
                      force_new=True,
                      file_name='final_status.json')
     ConsoleLogger().print(
         'info', 'save global_config into {}'.format(
             os.path.join(self._record_file_log_dir)))
     self.out_to_file(file_path=os.path.join(self._record_file_log_dir),
                      content=GlobalConfig().return_all_as_dict(),
                      force_new=True,
                      file_name='global_config.json')
Esempio n. 12
0
    def test_duplicate_exp(self):
        """Run the same experiment function twice via duplicate_exp_runner and
        check each run got its own exp_i directory with record/ and console.log."""
        def func():
            GlobalConfig().set(
                'DEFAULT_EXPERIMENT_END_POINT',
                dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
                     TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
                     TOTAL_AGENT_UPDATE_COUNT=None))
            dqn, locals = self.create_dqn()
            env_spec = locals['env_spec']
            env = locals['env']
            agent = self.create_agent(env=locals['env'],
                                      algo=dqn,
                                      name='agent',
                                      eps=self.create_eps(env_spec)[0],
                                      env_spec=env_spec)[0]
            exp = self.create_exp(name='model_fre', env=env, agent=agent)
            exp.run()

        base_path = GlobalConfig().DEFAULT_LOG_PATH
        # run the experiment function twice (exp_0 and exp_1)
        duplicate_exp_runner(2, func, auto_choose_gpu_flag=False, gpu_id=0)
        # each run must have produced its own directory tree and console log
        self.assertTrue(os.path.isdir(base_path))
        self.assertTrue(os.path.isdir(os.path.join(base_path, 'exp_0')))
        self.assertTrue(os.path.isdir(os.path.join(base_path, 'exp_1')))
        self.assertTrue(
            os.path.isdir(os.path.join(base_path, 'exp_0', 'record')))
        self.assertTrue(
            os.path.isdir(os.path.join(base_path, 'exp_1', 'record')))
        self.assertTrue(
            os.path.isfile(os.path.join(base_path, 'exp_0', 'console.log')))
        self.assertTrue(
            os.path.isfile(os.path.join(base_path, 'exp_1', 'console.log')))
Esempio n. 13
0
 def func(self, creat_func=None):
     """Run a short experiment end-to-end and check sample-count bookkeeping.

     :param creat_func: optional factory returning (algo, locals); defaults to
         self.create_dqn when not given
     """
     GlobalConfig().set(
         'DEFAULT_EXPERIMENT_END_POINT',
         dict(TOTAL_AGENT_TRAIN_SAMPLE_COUNT=500,
              TOTAL_AGENT_TEST_SAMPLE_COUNT=None,
              TOTAL_AGENT_UPDATE_COUNT=None))
     factory = creat_func if creat_func else self.create_dqn
     algo, local_vars = factory()
     env_spec = local_vars['env_spec']
     env = local_vars['env']
     agent = self.create_agent(env=env,
                               algo=algo,
                               name='agent',
                               eps=self.create_eps(env_spec)[0],
                               env_spec=env_spec)[0]
     from baconian.algo.dyna import Dyna
     # Dyna-style algorithms need an explicit dyna flow; others use the default
     flow = self.create_dyna_flow(agent=agent, env=env)[0] \
         if isinstance(algo, Dyna) else None
     exp = self.create_exp(name='model_free',
                           env=env,
                           agent=agent,
                           flow=flow)
     exp.run()
     self.assertEqual(exp.TOTAL_AGENT_TEST_SAMPLE_COUNT(),
                      exp.TOTAL_ENV_STEP_TEST_SAMPLE_COUNT())
     # the trailing 500 acts as the assertion message, mirroring the original
     self.assertEqual(exp.TOTAL_AGENT_TRAIN_SAMPLE_COUNT(),
                      exp.TOTAL_ENV_STEP_TRAIN_SAMPLE_COUNT(), 500)
Esempio n. 14
0
    def init(self,
             to_file_flag,
             level: str,
             to_file_name: str = None,
             logger_name: str = 'console_logger'):
        """Configure the console logger once; subsequent calls are no-ops.

        Installs a stream handler (plus an optional file handler) after
        clearing any pre-existing handlers on this logger and the root logger.
        """
        if self.inited_flag is True:
            return
        self.name = logger_name
        if level not in self.ALLOWED_LOG_LEVEL:
            raise ValueError('Wrong log level use {} instead'.format(
                self.ALLOWED_LOG_LEVEL))
        logger = logging.getLogger(self.name)
        logger.setLevel(getattr(logging, level))
        self.logger = logger

        # detach every handler already attached to this logger or to root
        stale_handlers = logger.root.handlers[:] + logger.handlers[:]
        for handler in stale_handlers:
            logger.removeHandler(handler)
            logger.root.removeHandler(handler)

        logger.addHandler(logging.StreamHandler())
        if to_file_flag is True:
            logger.addHandler(logging.FileHandler(filename=to_file_name))

        formatter = logging.Formatter(fmt=GlobalConfig().DEFAULT_LOGGING_FORMAT)
        for handler in logger.root.handlers[:] + logger.handlers[:]:
            handler.setFormatter(fmt=formatter)
            handler.setLevel(getattr(logging, level))

        self.inited_flag = True
Esempio n. 15
0
class ConstantActionPolicy(DeterministicPolicy):
    """A deterministic policy that always returns the configured constant action."""
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_CONSTANT_ACTION_POLICY_REQUIRED_KEY_LIST)

    def __init__(self, env_spec: EnvSpec, config_or_config_dict: (DictConfig, dict), name='policy'):
        """
        :param env_spec: environment specification; its action space must contain ACTION_VALUE
        :param config_or_config_dict: config (or plain dict) holding the constant ACTION_VALUE
        :param name: name of the policy
        """
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(),
                                source_config=config)
        # the configured constant action must be valid for this action space
        assert env_spec.action_space.contains(x=config('ACTION_VALUE'))
        super().__init__(env_spec, parameters, name)
        self.config = config

    def forward(self, *args, **kwargs):
        """Return the constant action; all inputs are ignored."""
        action = self.env_spec.action_space.unflatten(self.parameters('ACTION_VALUE'))
        assert self.env_spec.action_space.contains(x=action)
        return action

    def copy_from(self, obj) -> bool:
        """Copy parameters from another policy of the same type; returns True."""
        super().copy_from(obj)
        self.parameters.copy_from(obj.parameters)
        return True

    def make_copy(self, *args, **kwargs):
        """Build a new policy sharing a deep copy of this config.

        NOTE(review): positional *args placed after the keyword arguments will
        collide with env_spec at call time — callers should pass keyword
        arguments only; confirm no caller relies on positional args here.
        """
        return ConstantActionPolicy(env_spec=self.env_spec,
                                    config_or_config_dict=deepcopy(self.config),
                                    *args, **kwargs)
Esempio n. 16
0
 def __init__(self,
              tf_var_list: list,
              rest_parameters: dict,
              name: str,
              max_to_keep=GlobalConfig().DEFAULT_MAX_TF_SAVER_KEEP,
              default_save_type='tf',
              source_config=None,
              to_scheduler_param_tuple: list = None,
              save_rest_param_flag=True,
              to_ph_parameter_dict: dict = None,
              require_snapshot=False):
     """
     Parameters container backed by tensorflow variables.

     :param tf_var_list: tensorflow variables managed by this object
     :param rest_parameters: non-tensorflow parameters, forwarded to the base class
     :param name: name of the parameter collection
     :param max_to_keep: max number of checkpoints the tf saver keeps
         (NOTE(review): this default is evaluated once at import time, so
         later changes to GlobalConfig are not reflected — confirm intended)
     :param default_save_type: checkpoint format; only 'tf' is supported
     :param source_config: optional source config forwarded to the base class
     :param to_scheduler_param_tuple: parameters driven by schedulers, forwarded to the base class
     :param save_rest_param_flag: whether the non-tf parameters are saved too
     :param to_ph_parameter_dict: parameters to expose as tf placeholders
     :param require_snapshot: whether snapshot save/load ops should be maintained
     """
     super(ParametersWithTensorflowVariable,
           self).__init__(parameters=rest_parameters,
                          name=name,
                          to_scheduler_param_tuple=to_scheduler_param_tuple,
                          source_config=source_config)
     self._tf_var_list = tf_var_list
     self.snapshot_var = []
     self.save_snapshot_op = []
     self.load_snapshot_op = []
     self.saver = None
     self.max_to_keep = max_to_keep
     self.require_snapshot = require_snapshot
     self.default_checkpoint_type = default_save_type
     self.save_rest_param_flag = save_rest_param_flag
     if default_save_type != 'tf':
         raise NotImplementedError('only support saving tf')
     self._registered_tf_ph_dict = dict()
     # eagerly convert any requested parameters into tf placeholders
     if to_ph_parameter_dict:
         for key, val in to_ph_parameter_dict.items():
             self.to_tf_ph(key=key, ph=val)
Esempio n. 17
0
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Checkpoint both the dynamics model and the policy under one path.

        :param global_step: checkpoint step number
        :param save_path: destination dir; falls back to the global default
        :param name: checkpoint name; falls back to this object's name
        :return: dict describing where and under what name the save happened
        """
        save_path = save_path or GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name or self.name

        for component in (self._dynamics_model, self.policy):
            component.save(save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)
Esempio n. 18
0
    def test_global_config(self):
        """A frozen GlobalConfig must reject every mutation path."""
        GlobalConfig().set_new_config(config_dict=dict(DEFAULT_BASIC_INIT_STATUS='test'))
        assert GlobalConfig().DEFAULT_BASIC_INIT_STATUS == 'test'

        def expect_frozen_error(mutate):
            # a frozen config must raise; falling through to else is a failure
            try:
                mutate()
            except AttemptToChangeFreezeGlobalConfigError:
                pass
            else:
                raise TypeError

        GlobalConfig().freeze_flag = True

        expect_frozen_error(lambda: GlobalConfig().set_new_config(
            config_dict=dict(DEFAULT_BASIC_INIT_STATUS='test')))
        expect_frozen_error(lambda: GlobalConfig().set('DEFAULT_LOG_PATH', 'tmp'))

        def assign_attr():
            GlobalConfig().DEFAULT_LOG_PATH = 'tmp'

        expect_frozen_error(assign_attr)
        GlobalConfig().unfreeze()
Esempio n. 19
0
 def _exit(self):
     """Close the default tf session, reset the graph and all global state."""
     current_sess = tf.get_default_session()
     if current_sess is not None:
         current_sess.__exit__(None, None, None)
     tf.reset_default_graph()
     reset_global_status_collect()
     reset_logging()
     reset_global_var()
     GlobalConfig().unfreeze()
Esempio n. 20
0
def register_name_globally(name: str, obj):
    """Register obj in the global name dict, raising when the name is taken.

    Registration still overwrites when the same object re-registers, when
    either party allows duplicate names, or when the global duplicate-name
    check is turned off.
    """
    name_dict = get_all()['_global_name_dict']
    if name in name_dict:
        existing = name_dict[name]
        check_enabled = GlobalConfig().DEFAULT_TURN_OFF_GLOBAL_NAME_FLAG is False
        both_forbid = (obj.allow_duplicate_name is False
                       and existing.allow_duplicate_name is False)
        if id(obj) != id(existing) and both_forbid and check_enabled:
            raise GlobalNameExistedError(
                'name : {} is existed with object: {}'.format(name, existing))
    name_dict[name] = obj
Esempio n. 21
0
 def _exit(self):
     """ Exit the experiment, reset global configurations and logging module."""
     default_sess = tf.get_default_session()
     if default_sess:
         default_sess.__exit__(None, None, None)
     tf.reset_default_graph()
     # reset every piece of process-global state in turn
     for reset in (reset_global_status_collect, reset_logging, reset_global_var):
         reset()
     GlobalConfig().unfreeze()
Esempio n. 22
0
class Basic(object):
    """ Basic class within the whole framework"""
    # status values every subclass cycles through, e.g. ('TRAIN', 'TEST')
    STATUS_LIST = GlobalConfig().DEFAULT_BASIC_STATUS_LIST
    # status assigned on creation
    INIT_STATUS = GlobalConfig().DEFAULT_BASIC_INIT_STATUS
    required_key_dict = ()
    allow_duplicate_name = False

    def __init__(self, name: str, status=None):
        """
        Init a new Basic instance.

        :param name: a string for the name of the object, can be determined to generate log path, handle tensorflow name scope etc.
        :param status: A status instance :py:class:`~baconian.core.status.Status` to indicate the status of the
        """
        if not status:
            self._status = Status(self)
        else:
            self._status = status
        self._name = name
        # all instances that inherit Basic are registered globally by name
        register_name_globally(name=name, obj=self)

    def init(self, *args, **kwargs):
        """Initialize the object."""
        raise NotImplementedError

    def get_status(self) -> dict:
        """Return the object's status, a dictionary."""
        return self._status.get_status()

    def set_status(self, val):
        """Set the object's status."""
        self._status.set_status(val)

    @property
    def name(self):
        """The name (id) of the object, a string."""
        return self._name

    @property
    def status_list(self):
        """Status list of the object, e.g. ('TRAIN', 'TEST')."""
        return self.STATUS_LIST

    def save(self, *args, **kwargs):
        """Save the parameters in training checkpoints."""
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """Load the parameters from training checkpoints."""
        raise NotImplementedError
Esempio n. 23
0
 def run(self):
     """Freeze config, initialise, launch the flow and record the outcome."""
     GlobalConfig().freeze()
     self.init()
     self.set_status('RUNNING')
     res = self.flow.launch()
     # the flow reports failure by returning False
     self.set_status('CORRUPTED' if res is False else 'FINISHED')
     self._exit()
Esempio n. 24
0
def mountaincar_task_fn():
    """Benchmark task: train a DQN agent with epsilon-greedy exploration on
    MountainCar-v0, wiring together the Q function, agent, train/test flow and
    experiment from MOUNTAINCAR_BENCHMARK_CONFIG_DICT."""
    exp_config = MOUNTAINCAR_BENCHMARK_CONFIG_DICT
    GlobalConfig().set('DEFAULT_EXPERIMENT_END_POINT',
                       exp_config['DEFAULT_EXPERIMENT_END_POINT'])

    env = make('MountainCar-v0')
    name = 'benchmark'
    env_spec = EnvSpec(obs_space=env.observation_space,
                       action_space=env.action_space)

    # MLP Q-function backing the DQN
    mlp_q = MLPQValueFunction(env_spec=env_spec,
                              name_scope=name + '_mlp_q',
                              name=name + '_mlp_q',
                              **exp_config['MLPQValueFunction'])
    dqn = DQN(env_spec=env_spec,
              name=name + '_dqn',
              value_func=mlp_q,
              **exp_config['DQN'])
    # epsilon is annealed linearly as a function of the training sample count
    agent = Agent(env=env, env_spec=env_spec,
                  algo=dqn,
                  name=name + '_agent',
                  exploration_strategy=EpsilonGreedy(action_space=env_spec.action_space,
                                                     prob_scheduler=LinearScheduler(
                                                         t_fn=lambda: get_global_status_collect()(
                                                             'TOTAL_AGENT_TRAIN_SAMPLE_COUNT'),
                                                         **exp_config['EpsilonGreedy']['LinearScheduler']),
                                                     **exp_config['EpsilonGreedy']['config_or_config_dict']))
    # flow alternates sample/train/test steps driven by the global sample count
    flow = TrainTestFlow(train_sample_count_func=lambda: get_global_status_collect()('TOTAL_AGENT_TRAIN_SAMPLE_COUNT'),
                         config_or_config_dict=exp_config['TrainTestFlow']['config_or_config_dict'],
                         func_dict={
                             'test': {'func': agent.test,
                                      'args': list(),
                                      'kwargs': dict(sample_count=exp_config['TrainTestFlow']['TEST_SAMPLES_COUNT']),
                                      },
                             'train': {'func': agent.train,
                                       'args': list(),
                                       'kwargs': dict(),
                                       },
                             'sample': {'func': agent.sample,
                                        'args': list(),
                                        'kwargs': dict(sample_count=exp_config['TrainTestFlow']['TRAIN_SAMPLES_COUNT'],
                                                       env=agent.env,
                                                       in_which_status='TRAIN',
                                                       store_flag=True),
                                        },
                         })

    experiment = Experiment(
        tuner=None,
        env=env,
        agent=agent,
        flow=flow,
        name=name
    )
    experiment.run()
Esempio n. 25
0
 def run(self):
     """ Run the experiment, and set status to 'RUNNING'."""
     GlobalConfig().freeze()
     self.init()
     self.set_status('RUNNING')
     # flow.launch() returns False on failure
     succeeded = self.flow.launch() is not False
     self.set_status('FINISHED' if succeeded else 'CORRUPTED')
     self._exit()
    def launch(self) -> bool:
        """
        Launch the flow until it finished or catch a system-allowed errors (e.g., out of GPU memory, to ensure the log will be saved safely).

        :return: Boolean, True for the flow correctly executed and finished.
        """
        try:
            return self._launch()
        # only the exception types whitelisted in the global config are
        # swallowed here; anything else propagates to the caller
        except GlobalConfig().DEFAULT_ALLOWED_EXCEPTION_OR_ERROR_LIST as e:
            ConsoleLogger().print('error', 'error {} occurred'.format(e))
            return False
Esempio n. 27
0
 def save(self, global_step, save_path=None, name=None, **kwargs):
     """Checkpoint the model-free algo and the dynamics env, each under its own
     sub-directory of save_path.

     :param global_step: checkpoint step number
     :param save_path: destination dir; falls back to the global default
     :param name: checkpoint name; falls back to this object's name
     :return: dict describing where and under what name the save happened
     """
     if not save_path:
         save_path = GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
     if not name:
         name = self.name
     # each sub-module saves under save_path/<module name> with its own name
     for module in (self.model_free_algo, self.dynamics_env):
         module.save(global_step=global_step,
                     name=None,
                     save_path=os.path.join(save_path, module.name))
     return dict(check_point_save_path=save_path,
                 check_point_save_global_step=global_step,
                 check_point_save_name=name)
 def save(self, global_step, save_path=None, name=None, **kwargs):
     """Delegate saving to the parameters object and log the checkpoint.

     :param global_step: checkpoint step number
     :param save_path: destination dir; falls back to the global default
     :param name: checkpoint name; falls back to this object's name
     """
     save_path = save_path or GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
     name = name or self.name
     sess = kwargs.get('sess')
     self.parameters.save(save_path=save_path,
                          global_step=global_step,
                          sess=sess,
                          name=name)
     ConsoleLogger().print('info',
                           'model: {}, global step: {}, saved at {}-{}'.format(name, global_step, save_path,
                                                                               global_step))
Esempio n. 29
0
    def setUp(self):
        """Wipe the default log path and (re)initialise Logger and ConsoleLogger.

        Fix: removed the dead commented-out code (the makedirs/assertFalse
        variant) that had been left behind — commented-out code should not
        linger in the test suite.
        """
        BaseTestCase.setUp(self)
        try:
            shutil.rmtree(GlobalConfig().DEFAULT_LOG_PATH)
        except FileNotFoundError:
            # nothing to clean up on the first run
            pass

        Logger().init(config_or_config_dict=GlobalConfig().DEFAULT_LOG_CONFIG_DICT,
                      log_path=GlobalConfig().DEFAULT_LOG_PATH,
                      log_level=GlobalConfig().DEFAULT_LOG_LEVEL)
        ConsoleLogger().init(logger_name='console_logger',
                             to_file_flag=True,
                             level=GlobalConfig().DEFAULT_LOG_LEVEL,
                             to_file_name=os.path.join(Logger().log_dir, 'console.log'))

        self.assertTrue(ConsoleLogger().inited_flag)
        self.assertTrue(Logger().inited_flag)
    def test_save_load_with_dqn(self):
        """Save DQN checkpoints, copy into a second DQN, re-init and reload,
        verifying variable equality at each stage."""
        dqn, locals = self.create_dqn()
        dqn.init()
        # write five checkpoints (global steps 0..4)
        for i in range(5):
            dqn.save(save_path=GlobalConfig().DEFAULT_LOG_PATH +
                     '/test_placehoder_input',
                     global_step=i,
                     name='dqn')
        # one .meta file is expected per checkpoint
        file = glob.glob(GlobalConfig().DEFAULT_LOG_PATH +
                         '/test_placehoder_input/dqn*.meta')
        self.assertTrue(len(file) == 5)
        dqn2, _ = self.create_dqn(name='dqn_2')
        dqn2.copy_from(dqn)

        # right after copy_from all variable values should match
        self.assert_var_list_equal(dqn.parameters('tf_var_list'),
                                   dqn2.parameters('tf_var_list'))
        self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                                   dqn2.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_equal(
            dqn.target_q_value_func.parameters('tf_var_list'),
            dqn2.target_q_value_func.parameters('tf_var_list'))

        # re-initialising dqn should make it diverge from the copy
        dqn.init()
        self.assert_var_list_at_least_not_equal(
            dqn.q_value_func.parameters('tf_var_list'),
            dqn2.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_at_least_not_equal(
            dqn.target_q_value_func.parameters('tf_var_list'),
            dqn2.target_q_value_func.parameters('tf_var_list'))
        # loading the last checkpoint should restore equality
        dqn.load(path_to_model=GlobalConfig().DEFAULT_LOG_PATH +
                 '/test_placehoder_input',
                 global_step=4,
                 model_name='dqn')

        self.assert_var_list_equal(dqn.parameters('tf_var_list'),
                                   dqn2.parameters('tf_var_list'))
        self.assert_var_list_equal(dqn.q_value_func.parameters('tf_var_list'),
                                   dqn2.q_value_func.parameters('tf_var_list'))
        self.assert_var_list_equal(
            dqn.target_q_value_func.parameters('tf_var_list'),
            dqn2.target_q_value_func.parameters('tf_var_list'))