Example #1
 def load(self, path_to_model, global_step=None, sess=None, model_name=None, *args, **kwargs):
     if not model_name:
         model_name = self.name
     if self.default_checkpoint_type == 'tf':
         self._load_from_tf(path_to_model=path_to_model,
                            global_step=global_step,
                            sess=sess, model_name=model_name)
     elif self.default_checkpoint_type == 'h5py':
         self._load_from_h5py(*args, **kwargs)
     Parameters.load(self,
                     load_path=path_to_model,
                     global_step=global_step,
                     name=model_name)
Example #2
 def __init__(
     self,
     env_spec,
     dynamics_model: DynamicsModel,
     config_or_config_dict: (DictConfig, dict),
     policy: Policy,
     name='mpc',
 ):
     super().__init__(env_spec, dynamics_model, name)
     self.config = construct_dict_config(config_or_config_dict, self)
     self.policy = policy
     self.parameters = Parameters(parameters=dict(),
                                  source_config=self.config,
                                  name=name + '_' + 'mpc_param')
     self.memory = TransitionData(env_spec=env_spec)
Example #3
 def __init__(self, env_spec,
              dynamics_model: DynamicsModel,
              config_or_config_dict: (DictConfig, dict),
              policy: Policy,
              name='mpc',
              ):
     super().__init__(env_spec, dynamics_model, name)
     self.config = construct_dict_config(config_or_config_dict, self)
     self.parameters = Parameters(parameters=dict(),
                                  source_config=self.config,
                                  name=name + '_' + 'mpc_param')
     self.policy = policy
     # TODO: 9.18 should also make memory served as init parameter in __init__,
     #  and set default value as Transition in init()
     self.memory = TransitionData(env_spec=env_spec)
Example #4
 def __init__(self, env_spec, dynamics_model: ModelEnsemble,
              model_free_algo: ModelFreeAlgo,
              config_or_config_dict: (DictConfig, dict),
              name='sample_with_dynamics'
              ):
     if not isinstance(dynamics_model.model[0], ContinuousMLPGlobalDynamicsModel):
         raise TypeError("Model ensemble elements should be of type ContinuousMLPGlobalDynamicsModel")
     super().__init__(env_spec, dynamics_model, name)
     config = construct_dict_config(config_or_config_dict, self)
     parameters = Parameters(parameters=dict(),
                             name='dyna_param',
                             source_config=config)
     sub_placeholder_input_list = []
     if isinstance(dynamics_model, PlaceholderInput):
         sub_placeholder_input_list.append(dict(obj=dynamics_model,
                                                attr_name='dynamics_model'))
     if isinstance(model_free_algo, PlaceholderInput):
         sub_placeholder_input_list.append(dict(obj=model_free_algo,
                                                attr_name='model_free_algo'))
     self.model_free_algo = model_free_algo
     self.config = config
     self.parameters = parameters
     self.result = list()
     self.validation_result = [0] * len(dynamics_model)
     self._dynamics_model.__class__ = ModelEnsemble
Example #5
 def __init__(self,
              env_spec: EnvSpec,
              batch_data: TransitionData = None,
              epsilon=inf,
              init_sequential=False,
              eigreg=False,
              warmstart=True,
              name_scope='gp_dynamics_model',
              min_samples_per_cluster=40,
              max_clusters=20,
              strength=1,
              name='gp_dynamics_model'):
     parameters = Parameters(
         dict(min_samp=min_samples_per_cluster,
              max_samples=inf,
              max_clusters=max_clusters,
              strength=strength,
              init_sequential=init_sequential,
              eigreg=eigreg,
              warmstart=warmstart))
     super().__init__(env_spec=env_spec, parameters=parameters, name=name)
     self.name_scope = name_scope
     self.batch_data = batch_data
     self.gmm_model = GMM(epsilon=epsilon,
                          init_sequential=init_sequential,
                          eigreg=eigreg,
                          warmstart=warmstart)
     self.X, self.U = None, None
Example #6
    def __init__(
            self,
            train_sample_count_func,
            config_or_config_dict: (DictConfig, dict),
            func_dict: dict,
    ):
        """

        :param train_sample_count_func: a function that returns how many training samples the agent has collected so far.
        :type train_sample_count_func: method
        :param config_or_config_dict: a Config or a dict that should contain the keys: (TEST_EVERY_SAMPLE_COUNT, TRAIN_EVERY_SAMPLE_COUNT, START_TRAIN_AFTER_SAMPLE_COUNT, START_TEST_AFTER_SAMPLE_COUNT)
        :type config_or_config_dict: Config or dict
        :param func_dict: function dict holding the keys 'sample', 'train' and 'test'; each item in the dict should itself be a dict with the keys 'func', 'args' and 'kwargs'
        :type func_dict: dict
        """
        super(DynaFlow, self).__init__(func_dict=func_dict)
        config = construct_dict_config(config_or_config_dict, obj=self)
        self.parameters = Parameters(source_config=config, parameters=dict())
        self.time_step_func = train_sample_count_func
        self._last_train_algo_point = -1
        self._last_train_algo_point_from_dynamics = -1
        self._last_test_algo_point = -1
        self._last_train_dynamics_point = -1
        self._last_test_dynamics_point = -1
        assert callable(train_sample_count_func)
Example #7
    def __init__(self,
                 train_sample_count_func,
                 config_or_config_dict: (DictConfig, dict),
                 func_dict: dict
                 ):
        super(DDPG_TrainTestFlow, self).__init__(func_dict=func_dict)
        config = construct_dict_config(config_or_config_dict, obj=self)
        self.parameters = Parameters(source_config=config, parameters=dict())
        if train_sample_count_func:
            assert callable(train_sample_count_func)

        self.env = self.parameters('env')
        self.env_spec = self.env.env_spec
        self.agent = self.parameters('agent')
        self.cyber = self.parameters('cyber')
        self.total_steps = self.parameters('total_steps')
        self.max_step_per_episode = self.parameters('max_step_per_episode')
        self.train_after_step = self.parameters('train_after_step')
        self.train_every_step = self.parameters('train_every_step')
        self.test_after_step = self.parameters('test_after_step')
        self.test_every_step = self.parameters('test_every_step')
        self.num_test = self.parameters('num_test')
        self.test_reward = []
        self.data_sample = []
        self.step_counter = SinglentonStepCounter(-1)
Example #8
 def __init__(self, env_spec: EnvSpec, config_or_config_dict: (DictConfig, dict), name='policy'):
     config = construct_dict_config(config_or_config_dict, self)
     parameters = Parameters(parameters=dict(),
                             source_config=config)
     assert env_spec.action_space.contains(x=config('ACTION_VALUE'))
     super().__init__(env_spec, parameters, name)
     self.config = config
Example #9
 def save(self, save_path, global_step, sess=None, name=None, *args, **kwargs):
     if self.default_checkpoint_type == 'tf':
         self._save_to_tf(save_path=save_path,
                          global_step=global_step,
                          sess=sess,
                          name=name)
     elif self.default_checkpoint_type == 'h5py':
         raise NotImplementedError
     if self.save_rest_param_flag is False:
         to_save_dict = dict(_source_config=self._source_config.config_dict)
     else:
         to_save_dict = dict(_parameters=self._parameters, _source_config=self._source_config.config_dict)
     Parameters.save(self,
                     save_path=save_path,
                     global_step=global_step,
                     default_save_param=to_save_dict,
                     name=name)
Example #10
 def create_parameters(self):
     parameters = dict(param1='aaaa',
                       param2=12312,
                       param3=np.random.random([4, 2]))
     source_config, _ = self.create_dict_config()
     a = Parameters(parameters=parameters,
                    source_config=source_config,
                    name='test_params')
     return a, locals()
Example #11
 def __init__(self, env_spec: EnvSpec, T: int, cost_fn: CostFunc,
              dynamics: LinearDynamicsModel):
     param = Parameters(parameters=dict(T=T))
     super().__init__(env_spec, param)
     self.dynamics = dynamics
     self.Lqr_instance = LQR(env_spec=env_spec,
                             T=self.parameters('T'),
                             dyna_model=dynamics,
                             cost_fn=cost_fn)
Example #12
    def __init__(self, action_space: Space, init_random_prob: float, prob_scheduler: Schedule = None):
        super(ExplorationStrategy, self).__init__()

        self.action_space = action_space
        self.random_prob_func = lambda: init_random_prob
        if prob_scheduler:
            self.random_prob_func = prob_scheduler.value

        self.parameters = Parameters(parameters=dict(random_prob_func=self.random_prob_func),
                                     name='eps_greedy_params')
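A hedged usage sketch for the exploration strategy above: the helper below is an assumption about how random_prob_func would typically be consumed (ε-greedy action selection); neither the helper nor the policy argument comes from this listing.

import numpy as np

def eps_greedy_predict(strategy, policy, obs):
    # Assumed semantics: with probability random_prob_func() draw a random
    # action from the stored action space, otherwise defer to the policy.
    if np.random.rand() < strategy.random_prob_func():
        return strategy.action_space.sample()
    return policy.forward(obs=obs)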
Example #13
    def __init__(
            self,
            name,
            # config_or_config_dict: (DictConfig, dict),
            env: (Env, Wrapper),
            algo: Algo,
            env_spec: EnvSpec,
            sampler: Sampler = None,
            noise_adder: AgentActionNoiseWrapper = None,
            reset_noise_every_terminal_state=False,
            reset_state_every_sample=False,
            exploration_strategy: ExplorationStrategy = None,
            algo_saving_scheduler: EventScheduler = None):
        """

        :param name: the name of the agent instance
        :type name: str
        :param env: environment that interacts with agent
        :type env: Env
        :param algo: algorithm of the agent
        :type algo: Algo
        :param env_spec: environment specifications: action space and observation space
        :type env_spec: EnvSpec
        :param sampler: sampler
        :type sampler: Sampler
        :param reset_noise_every_terminal_state: whether to reset the action noise after every sampled trajectory (i.e., at each terminal state)
        :type reset_noise_every_terminal_state: bool
        :param reset_state_every_sample: whether to reset the environment state every time a sample/rollout is performed
        :type reset_state_every_sample: bool
        :param noise_adder: add action noise for exploration in action space
        :type noise_adder: AgentActionNoiseWrapper
        :param exploration_strategy: exploration strategy in action space
        :type exploration_strategy: ExplorationStrategy
        :param algo_saving_scheduler: an event scheduler that controls when to save the algorithm during the training process
        :type algo_saving_scheduler: EventScheduler
        """
        super(Agent, self).__init__(name=name, status=StatusWithSubInfo(self))
        self.parameters = Parameters(parameters=dict(
            reset_noise_every_terminal_state=reset_noise_every_terminal_state,
            reset_state_every_sample=reset_state_every_sample))
        self.env = env
        self.algo = algo
        self._env_step_count = 0
        if sampler is None:
            sampler = Sampler()
        self.sampler = sampler
        self.recorder = Recorder(default_obj=self)
        self.env_spec = env_spec
        if exploration_strategy:
            assert isinstance(exploration_strategy, ExplorationStrategy)
            self.explorations_strategy = exploration_strategy
        else:
            self.explorations_strategy = None
        self.noise_adder = noise_adder
        self.algo_saving_scheduler = algo_saving_scheduler
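A minimal construction sketch for the constructor above; my_env and ddpg_algo are hypothetical placeholders, not objects defined in this listing, and the optional arguments are left at values the constructor itself handles.

agent = Agent(
    name='demo_agent',
    env=my_env,                              # placeholder Env/Wrapper instance
    algo=ddpg_algo,                          # placeholder Algo instance
    env_spec=my_env.env_spec,
    sampler=None,                            # constructor falls back to Sampler()
    noise_adder=None,                        # no action noise
    reset_noise_every_terminal_state=False,
    reset_state_every_sample=False,
    exploration_strategy=None,
    algo_saving_scheduler=None,              # never auto-save the algorithm
)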
Example #14
 def init(self):
     Parameters.init(self)
     sess = tf.get_default_session()
     sess.run(tf.variables_initializer(var_list=self._tf_var_list))
     if self.require_snapshot is True:
         if len(self.snapshot_var) == 0:
             # add the snapshot op after the init
             sess = tf.get_default_session()
             with tf.variable_scope('snapshot'):
                 for var in self._tf_var_list:
                     snap_var = tf.Variable(initial_value=sess.run(var),
                                            expected_shape=var.get_shape().as_list(),
                                            name=str(var.name).split(':')[0])
                     self.snapshot_var.append(snap_var)
                     self.save_snapshot_op.append(tf.assign(snap_var, var))
                     self.load_snapshot_op.append(tf.assign(var, snap_var))
         sess.run(tf.variables_initializer(var_list=self.snapshot_var))
         sess.run(self.save_snapshot_op)
     self.saver = tf.train.Saver(max_to_keep=self.max_to_keep,
                                 var_list=self._tf_var_list)
Example #15
 def __init__(self,
              env_spec: EnvSpec,
              state_transition_matrix: np.array,
              bias: np.array,
              init_state=None,
              name='dynamics_model'):
     parameters = Parameters(
         parameters=dict(F=state_transition_matrix, f=bias))
     super().__init__(env_spec, parameters, init_state, name)
     assert self.parameters('F').shape == \
            (env_spec.obs_space.flat_dim, env_spec.obs_space.flat_dim + env_spec.action_space.flat_dim)
     assert self.parameters('f').shape[0] == env_spec.obs_space.flat_dim
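The shape assertions above pin down the usual linear form; the helper below is only a sketch of what a single transition would compute under that assumption (the function name and signature are illustrative, not taken from LinearDynamicsModel).

import numpy as np

def linear_step(F: np.ndarray, f: np.ndarray, obs: np.ndarray, action: np.ndarray) -> np.ndarray:
    # next_obs = F @ [obs; action] + f, matching
    # F.shape == (obs_dim, obs_dim + action_dim) and f.shape == (obs_dim,)
    xu = np.concatenate([np.ravel(obs), np.ravel(action)])
    return F @ xu + f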
Example #16
 def __init__(
         self,
         train_sample_count_func,
         config_or_config_dict: (DictConfig, dict),
         func_dict: dict,
 ):
     super(TrainTestFlow, self).__init__(func_dict=func_dict)
     config = construct_dict_config(config_or_config_dict, obj=self)
     self.parameters = Parameters(source_config=config, parameters=dict())
     self.time_step_func = train_sample_count_func
     self.last_train_point = -1
     self.last_test_point = -1
     assert callable(train_sample_count_func)
Example #17
    def __init__(self, env_spec: EnvSpec, name: str = 'algo'):
        """
        Constructor

        :param env_spec: environment specifications
        :type env_spec: EnvSpec
        :param name: name of the algorithm
        :type name: str
        """

        super().__init__(status=StatusWithSubInfo(obj=self), name=name)
        self.env_spec = env_spec
        self.parameters = Parameters(dict())
        self.recorder = Recorder()
Example #18
    def __init__(
            self,
            name,
            # config_or_config_dict: (DictConfig, dict),
            env: Env,
            algo: Algo,
            env_spec: EnvSpec,
            sampler: Sampler = None,
            noise_adder: AgentActionNoiseWrapper = None,
            exploration_strategy: ExplorationStrategy = None,
            algo_saving_scheduler: EventScheduler = None):
        """

        :param name: the name of the agent instance
        :type name: str
        :param env: environment that interacts with agent
        :type env: Env
        :param algo: algorithm of the agent
        :type algo: Algo
        :param env_spec: environment specifications: action space and observation space
        :type env_spec: EnvSpec
        :param sampler: sampler
        :type sampler: Sampler
        :param noise_adder: add action noise for exploration in action space
        :type noise_adder: AgentActionNoiseWrapper
        :param exploration_strategy: exploration strategy in action space
        :type exploration_strategy: ExplorationStrategy
        :param algo_saving_scheduler: an event scheduler that controls when to save the algorithm during the training process
        :type algo_saving_scheduler: EventScheduler
        """
        super(Agent, self).__init__(name=name, status=StatusWithSubInfo(self))
        self.parameters = Parameters(parameters=dict())
        self.total_test_samples = 0
        self.total_train_samples = 0
        self.env = env
        self.algo = algo
        self._env_step_count = 0
        self.sampler = sampler
        self.recorder = Recorder()
        self.env_spec = env_spec
        if exploration_strategy:
            assert isinstance(exploration_strategy, ExplorationStrategy)
            self.explorations_strategy = exploration_strategy
        else:
            self.explorations_strategy = None
        self.sampler = sampler if sampler else Sampler(
            env_spec=env_spec, name='{}_sampler'.format(name))
        self.noise_adder = noise_adder
        self.algo_saving_scheduler = algo_saving_scheduler
Example #19
 def __init__(self, env_spec: EnvSpec, T: int, delta: float, iteration: int, cost_fn: CostFunc,
              dynamics_model_train_iter: int,
              dynamics: DynamicsEnvWrapper):
     param = Parameters(parameters=dict(T=T, delta=delta,
                                        iteration=iteration,
                                        dynamics_model_train_iter=dynamics_model_train_iter))
     super().__init__(env_spec, param)
     self.dynamics = dynamics
     self.U_hat = None
     self.X_hat = None
     self.iLqr_instance = iLQR(env_spec=env_spec,
                               delta=self.parameters('delta'),
                               T=self.parameters('T'),
                               dyn_model=dynamics._dynamics,
                               cost_fn=cost_fn)
Example #20
    def __init__(self,
                 env_spec: EnvSpec,
                 name: str = 'algo',
                 warm_up_trajectories_number=0):
        """
        Constructor

        :param env_spec: environment specifications
        :type env_spec: EnvSpec
        :param name: name of the algorithm
        :type name: str
        :param warm_up_trajectories_number: how many trajectories are used to warm up the training
        :type warm_up_trajectories_number: int
        """

        super().__init__(status=StatusWithSubInfo(obj=self), name=name)
        self.env_spec = env_spec
        self.parameters = Parameters(dict())
        self.recorder = Recorder(default_obj=self)
        self.warm_up_trajectories_number = warm_up_trajectories_number
Example #21
 def __init__(self, env_spec, dynamics_model: DynamicsModel,
              model_free_algo: ModelFreeAlgo,
              config_or_config_dict: (DictConfig, dict),
              name='sample_with_dynamics'
              ):
     super().__init__(env_spec, dynamics_model, name)
     config = construct_dict_config(config_or_config_dict, self)
     parameters = Parameters(parameters=dict(),
                             name='dyna_param',
                             source_config=config)
     sub_placeholder_input_list = []
     if isinstance(dynamics_model, PlaceholderInput):
         sub_placeholder_input_list.append(dict(obj=dynamics_model,
                                                attr_name='dynamics_model'))
     if isinstance(model_free_algo, PlaceholderInput):
         sub_placeholder_input_list.append(dict(obj=model_free_algo,
                                                attr_name='model_free_algo'))
     self.model_free_algo = model_free_algo
     self.config = config
     self.parameters = parameters
Example #22
 def __init__(
         self,
         train_sample_count_func,
         config_or_config_dict: (DictConfig, dict),
         func_dict: dict,
 ):
     super(MEPPO_Flow, self).__init__(func_dict=func_dict)
     config = construct_dict_config(config_or_config_dict, obj=self)
     self.parameters = Parameters(source_config=config, parameters=dict())
     self.time_step_func = train_sample_count_func
     self._last_train_algo_point = -1
     self._start_train_algo_point_from_dynamics = -1
     self._last_test_algo_point = -1
     self._start_train_dynamics_point = -1
     self._last_test_dynamics_point = -1
     self._last_performance = 0
     self._last_chance = 0
     self._fictitious_set_count = 0
     assert callable(train_sample_count_func)
Example #23
    def __init__(
            self,
            name,
            # config_or_config_dict: (DictConfig, dict),
            env: Env,
            algo: Algo,
            env_spec: EnvSpec,
            sampler: Sampler = None,
            noise_adder: AgentActionNoiseWrapper = None,
            exploration_strategy: ExplorationStrategy = None,
            algo_saving_scheduler: EventSchedule = None):
        """

        :param name: the name of the agent instance
        :param env: environment that interacts with the agent
        :param algo: algorithm of the agent
        :param env_spec: environment specifications: action space and observation space
        :param sampler: sampler
        :param noise_adder: add action noise for exploration in action space
        :param exploration_strategy: exploration strategy in action space
        :param algo_saving_scheduler: an event scheduler that controls when to save the algorithm during the training process
        """
        super(Agent, self).__init__(name=name, status=StatusWithSubInfo(self))
        self.parameters = Parameters(parameters=dict())
        self.total_test_samples = 0
        self.total_train_samples = 0
        self.env = env
        self.algo = algo
        self._env_step_count = 0
        self.sampler = sampler
        self.recorder = Recorder()
        self.env_spec = env_spec
        if exploration_strategy:
            assert isinstance(exploration_strategy, ExplorationStrategy)
            self.explorations_strategy = exploration_strategy
        else:
            self.explorations_strategy = None
        self.sampler = sampler if sampler else Sampler(
            env_spec=env_spec, name='{}_sampler'.format(name))
        self.noise_adder = noise_adder
        self.algo_saving_scheduler = algo_saving_scheduler
Example #24
    def __init__(self,
                 train_sample_count_func,
                 config_or_config_dict: (DictConfig, dict),
                 func_dict: dict,
                 ):
        """
        Constructor of TrainTestFlow

        :param train_sample_count_func: a function that returns how many training samples the agent has collected so far; when it reaches the preset value, the program stops training.
        :type train_sample_count_func: method
        :param config_or_config_dict: a Config or a dict that should contain the keys: (TEST_EVERY_SAMPLE_COUNT, TRAIN_EVERY_SAMPLE_COUNT, START_TRAIN_AFTER_SAMPLE_COUNT, START_TEST_AFTER_SAMPLE_COUNT)
        :type config_or_config_dict: Config or dict
        :param func_dict: function dict holding the keys 'sample', 'train' and 'test'; each item in the dict should itself be a dict with the keys 'func', 'args' and 'kwargs'
        :type func_dict: dict
        """
        super(TrainTestFlow, self).__init__(func_dict=func_dict)
        config = construct_dict_config(config_or_config_dict, obj=self)
        self.parameters = Parameters(source_config=config, parameters=dict())  # hyper parameter instance
        self.time_step_func = train_sample_count_func
        self.last_train_point = -1
        self.last_test_point = -1
        assert callable(train_sample_count_func)    # return TOTAL_AGENT_TRAIN_SAMPLE_COUNT
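A hedged sketch of the two arguments the docstring above describes; my_agent and the concrete sample counts are hypothetical placeholders, and only the keys named in the docstring are shown.

config_dict = dict(TEST_EVERY_SAMPLE_COUNT=1000,
                   TRAIN_EVERY_SAMPLE_COUNT=10,
                   START_TRAIN_AFTER_SAMPLE_COUNT=100,
                   START_TEST_AFTER_SAMPLE_COUNT=100)
# Each func_dict entry bundles a callable with its positional and keyword arguments.
func_dict = dict(sample=dict(func=my_agent.sample, args=(), kwargs=dict()),
                 train=dict(func=my_agent.train, args=(), kwargs=dict()),
                 test=dict(func=my_agent.test, args=(), kwargs=dict()))
flow = TrainTestFlow(train_sample_count_func=lambda: my_agent.total_train_samples,
                     config_or_config_dict=config_dict,
                     func_dict=func_dict)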
Example #25
    def __init__(self,
                 train_sample_count_func,
                 config_or_config_dict: (DictConfig, dict),
                 func_dict: dict
                 ):
        super(MBMPC_TrainFlow, self).__init__(func_dict=func_dict)
        config = construct_dict_config(config_or_config_dict, obj=self)
        self.parameters = Parameters(source_config=config, parameters=dict())  # hyper parameter instance
        if train_sample_count_func:
            assert callable(train_sample_count_func)    # return TOTAL_AGENT_TRAIN_SAMPLE_COUNT

        from baconian.common.sampler.sample_data import MPC_TransitionData
        self.env = self.parameters('env')
        self.env_spec = self.env.env_spec
        env_spec = self.env_spec
        self.random_buffer = MPC_TransitionData(env_spec=env_spec,
                                                obs_shape=env_spec.obs_shape,
                                                action_shape=env_spec.action_shape,
                                                size=self.parameters('random_size'))
        self.rl_buffer = MPC_TransitionData(env_spec=env_spec,
                                            obs_shape=env_spec.obs_shape,
                                            action_shape=env_spec.action_shape,
                                            size=self.parameters('rl_size'))
Example #26
class ModelPredictiveControl(ModelBasedAlgo):
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST)

    def __init__(
        self,
        env_spec,
        dynamics_model: DynamicsModel,
        config_or_config_dict: (DictConfig, dict),
        policy: Policy,
        name='mpc',
    ):
        super().__init__(env_spec, dynamics_model, name)
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = policy
        self.parameters = Parameters(parameters=dict(),
                                     source_config=self.config,
                                     name=name + '_' + 'mpc_param')
        self.memory = TransitionData(env_spec=env_spec)

    def init(self, source_obj=None):
        super().init()
        self.parameters.init()
        self._dynamics_model.init()
        self.policy.init()
        if source_obj:
            self.copy_from(source_obj)

    def train(self, *arg, **kwargs) -> dict:
        super(ModelPredictiveControl, self).train()
        res_dict = {}
        batch_data = kwargs[
            'batch_data'] if 'batch_data' in kwargs else self.memory

        dynamics_train_res_dict = self._fit_dynamics_model(
            batch_data=batch_data,
            train_iter=self.parameters('dynamics_model_train_iter'))
        for key, val in dynamics_train_res_dict.items():
            res_dict["mlp_dynamics_{}".format(key)] = val
        return res_dict

    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def _fit_dynamics_model(self,
                            batch_data: TransitionData,
                            train_iter,
                            sess=None) -> dict:
        res_dict = self._dynamics_model.train(
            batch_data, **dict(sess=sess, train_iter=train_iter))
        return res_dict

    def predict(self, obs, **kwargs):
        if self.is_training is True:
            return self.env_spec.action_space.sample()
        rollout = TrajectoryData(env_spec=self.env_spec)
        state = obs
        for i in range(self.parameters('SAMPLED_PATH_NUM')):
            path = TransitionData(env_spec=self.env_spec)
            # TODO: terminal_func signal problem to be considered?
            for _ in range(self.parameters('SAMPLED_HORIZON')):
                ac = self.policy.forward(obs=state)
                new_state, re, done, _ = self.dynamics_env.step(action=ac,
                                                                state=state)
                path.append(state=state,
                            action=ac,
                            new_state=new_state,
                            reward=re,
                            done=done)
                state = new_state
            rollout.append(path)
        rollout.trajectories.sort(key=lambda x: x.cumulative_reward,
                                  reverse=True)
        ac = rollout.trajectories[0].action_set[0]
        assert self.env_spec.action_space.contains(ac)
        return ac

    def append_to_memory(self, samples: TransitionData):
        self.memory.union(samples)

    def copy_from(self, obj) -> bool:
        if not isinstance(obj, type(self)):
            raise TypeError(
                'Wrong type of obj %s to be copied, which should be %s' %
                (type(obj), type(self)))
        self.parameters.copy_from(obj.parameters)
        self._dynamics_model.copy_from(obj._dynamics_model)
        ConsoleLogger().print('info',
                              'model: {} copied from {}'.format(self, obj))
        return True

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig(
        ).DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name

        self._dynamics_model.save(save_path=save_path,
                                  global_step=global_step,
                                  name=name,
                                  **kwargs)
        self.policy.save(save_path=save_path,
                         global_step=global_step,
                         name=name,
                         **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        self._dynamics_model.load(path_to_model, model_name, global_step,
                                  **kwargs)
        self.policy.load(path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)
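A hedged usage sketch for ModelPredictiveControl; my_env_spec, my_dynamics_model and my_policy are placeholders, and the config keys shown are the ones train() and predict() actually read (dynamics_model_train_iter, SAMPLED_PATH_NUM, SAMPLED_HORIZON). The full required key list lives in GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST.

mpc = ModelPredictiveControl(env_spec=my_env_spec,
                             dynamics_model=my_dynamics_model,
                             policy=my_policy,
                             config_or_config_dict=dict(dynamics_model_train_iter=10,
                                                        SAMPLED_PATH_NUM=50,   # candidate paths per predict()
                                                        SAMPLED_HORIZON=10),   # steps rolled out per path
                             name='demo_mpc')
mpc.init()
action = mpc.predict(obs=my_env_spec.obs_space.sample())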
Example #27
    def test_scheduler_param(self):
        def func():
            global x
            return x

        parameters = dict(param1='aaaa',
                          param2=1.0,
                          param4=1.0,
                          param3=np.random.random([4, 2]))
        source_config, _ = self.create_dict_config()
        a = Parameters(
            parameters=parameters,
            source_config=source_config,
            name='test_params',
            to_scheduler_param_tuple=(dict(param_key='param2',
                                           scheduler=LinearSchedule(
                                               t_fn=func,
                                               schedule_timesteps=10,
                                               final_p=0.0)),
                                      dict(param_key='param4',
                                           scheduler=PiecewiseSchedule(
                                               t_fn=func,
                                               endpoints=((2, 0.5), (8, 0.2),
                                                          (10, 0.0)),
                                               outside_value=0.0,
                                           ))))
        a.init()
        for i in range(20):
            global x
            if x < 10:
                self.assertEqual(a('param2'), 1.0 - x * (1.0 - 0.0) / 10)
            else:
                self.assertEqual(a('param2'), 0.0)
            if x == 2:
                self.assertEqual(a('param4'), 0.5)
            if x == 8:
                self.assertEqual(a('param4'), 0.2)
            if x >= 10:
                self.assertEqual(a('param4'), 0.0)
            x += 1
        b, _ = self.create_parameters()
        b.copy_from(a)
        for key in a._source_config.required_key_dict.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                self.assertEqual(id(a[key]), id(b[key]))
                self.assertEqual(id(a(key)), id(b(key)))
        for key in a._parameters.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                self.assertEqual(a[key], b[key])
                self.assertEqual(a(key), b(key))
        self.assertEqual(a.to_scheduler_param_list.__len__(),
                         b.to_scheduler_param_list.__len__())
        for a_val, b_val in zip(a.to_scheduler_param_list,
                                b.to_scheduler_param_list):
            self.assertEqual(a_val['param_key'], b_val['param_key'])
            self.assertEqual(a_val['scheduler'].value(),
                             b_val['scheduler'].value())
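The assertions above fix the interpolation rule that LinearSchedule is expected to follow for param2; the stand-alone function below restates that rule (initial_p corresponds to the parameter's starting value, 1.0 in this test).

def linear_schedule_value(t, schedule_timesteps, initial_p, final_p):
    # Linear interpolation from initial_p to final_p over schedule_timesteps,
    # then clamped at final_p -- matching the param2 assertions above with
    # initial_p=1.0, final_p=0.0, schedule_timesteps=10.
    fraction = min(float(t) / schedule_timesteps, 1.0)
    return initial_p + fraction * (final_p - initial_p)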