Пример #1
0
 def init(self):
     """Initialise the tracked TF variables and, optionally, their snapshots.

     Calls ``Parameters.init`` and then runs the initializer for every
     variable in ``self._tf_var_list`` in the default session.

     When ``self.require_snapshot is True``:

     * if ``self.snapshot_var`` is empty, one shadow ``tf.Variable`` per
       tracked variable is created under the ``snapshot`` variable scope,
       seeded with the tracked variable's current (just-initialised) value,
       together with an assign op that saves into it
       (``self.save_snapshot_op``) and one that restores from it
       (``self.load_snapshot_op``);
     * on every call the snapshot variables are (re)initialised and the save
       ops are run, so the snapshot holds the freshly initialised values.

     Finally a new ``tf.train.Saver`` over ``self._tf_var_list`` is stored in
     ``self.saver``.
     """
     Parameters.init(self)
     sess = tf.get_default_session()
     sess.run(tf.variables_initializer(var_list=self._tf_var_list))
     if self.require_snapshot is True:
         if not self.snapshot_var:
             # Build the snapshot variables/ops only while none exist yet, so
             # repeated init() calls do not add duplicate graph nodes. The
             # session fetched above is reused here.
             with tf.variable_scope('snapshot'):
                 for var in self._tf_var_list:
                     snap_var = tf.Variable(initial_value=sess.run(var),
                                            expected_shape=var.get_shape().as_list(),
                                            # strip the ":<output index>" suffix
                                            name=str(var.name).split(':')[0])
                     self.snapshot_var.append(snap_var)
                     self.save_snapshot_op.append(tf.assign(snap_var, var))
                     self.load_snapshot_op.append(tf.assign(var, snap_var))
         sess.run(tf.variables_initializer(var_list=self.snapshot_var))
         sess.run(self.save_snapshot_op)
     self.saver = tf.train.Saver(max_to_keep=self.max_to_keep,
                                 var_list=self._tf_var_list)
Пример #2
0
class ModelPredictiveControl(ModelBasedAlgo):
    """Model-based controller that plans by simulating rollouts on a model.

    Outside training mode, :meth:`predict` simulates ``SAMPLED_PATH_NUM``
    candidate paths of ``SAMPLED_HORIZON`` steps each. Every path starts from
    the current observation, uses ``policy`` to choose actions and
    ``self.dynamics_env`` to step. It returns the first action of the path
    with the highest cumulative reward. :meth:`train` fits the dynamics model
    on collected transitions.
    """

    # Required configuration keys, loaded from JSON when the class is defined.
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST)

    def __init__(
        self,
        env_spec,
        dynamics_model: DynamicsModel,
        config_or_config_dict: (DictConfig, dict),
        policy: Policy,
        name='mpc',
    ):
        """Create the controller.

        :param env_spec: environment spec; its action space is used for
            sampling and validation in :meth:`predict`.
        :param dynamics_model: dynamics model passed to the ``ModelBasedAlgo``
            base class and fitted by :meth:`train`.
        :param config_or_config_dict: ``DictConfig`` or plain ``dict``; passed
            to ``construct_dict_config`` to build ``self.config``.
        :param policy: policy queried for actions during simulated rollouts.
        :param name: algorithm name; the parameter set is named
            ``<name>_mpc_param``.
        """
        super().__init__(env_spec, dynamics_model, name)
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = policy
        self.parameters = Parameters(parameters=dict(),
                                     source_config=self.config,
                                     name=name + '_' + 'mpc_param')
        # Transition buffer used by train() when no batch_data is supplied.
        self.memory = TransitionData(env_spec=env_spec)

    def init(self, source_obj=None):
        """Initialise the base algo, parameters, dynamics model and policy.

        :param source_obj: optional ``ModelPredictiveControl`` whose
            parameters and dynamics model are copied in afterwards
            (see :meth:`copy_from`).
        """
        super().init()
        self.parameters.init()
        self._dynamics_model.init()
        self.policy.init()
        if source_obj:
            self.copy_from(source_obj)

    def train(self, *arg, **kwargs) -> dict:
        """Fit the dynamics model and report its training results.

        :param batch_data: optional keyword argument; the ``TransitionData``
            to fit on. Defaults to ``self.memory``.
        :return: the dynamics model's result dict with every key prefixed by
            ``"mlp_dynamics_"``.
        """
        super(ModelPredictiveControl, self).train()
        batch_data = kwargs.get('batch_data', self.memory)
        dynamics_train_res_dict = self._fit_dynamics_model(
            batch_data=batch_data,
            train_iter=self.parameters('dynamics_model_train_iter'))
        return {"mlp_dynamics_{}".format(key): val
                for key, val in dynamics_train_res_dict.items()}

    def test(self, *arg, **kwargs) -> dict:
        """Delegate to the base-class ``test``."""
        return super().test(*arg, **kwargs)

    def _fit_dynamics_model(self,
                            batch_data: TransitionData,
                            train_iter,
                            sess=None) -> dict:
        """Train the dynamics model on ``batch_data`` and return its result dict."""
        return self._dynamics_model.train(batch_data,
                                          sess=sess,
                                          train_iter=train_iter)

    def predict(self, obs, **kwargs):
        """Return the action to take in observation ``obs``.

        While ``self.is_training is True`` a random action sampled from the
        action space is returned. Otherwise ``SAMPLED_PATH_NUM`` paths of
        ``SAMPLED_HORIZON`` simulated steps are rolled out from ``obs``. The
        first action of the path with the largest cumulative reward is
        returned.
        """
        if self.is_training is True:
            return self.env_spec.action_space.sample()
        rollout = TrajectoryData(env_spec=self.env_spec)
        for _ in range(self.parameters('SAMPLED_PATH_NUM')):
            path = TransitionData(env_spec=self.env_spec)
            # Every candidate path must start from the current observation,
            # not from wherever the previous simulated path ended; otherwise
            # the paths being ranked below are not comparable.
            state = obs
            # TODO: terminal (``done``) signal handling is not considered yet.
            for _step in range(self.parameters('SAMPLED_HORIZON')):
                ac = self.policy.forward(obs=state)
                new_state, re, done, _ = self.dynamics_env.step(action=ac,
                                                                state=state)
                path.append(state=state,
                            action=ac,
                            new_state=new_state,
                            reward=re,
                            done=done)
                state = new_state
            rollout.append(path)
        # Stable descending sort: on ties the earliest sampled path wins.
        rollout.trajectories.sort(key=lambda x: x.cumulative_reward,
                                  reverse=True)
        ac = rollout.trajectories[0].action_set[0]
        assert self.env_spec.action_space.contains(ac)
        return ac

    def append_to_memory(self, samples: TransitionData):
        """Merge ``samples`` into the transition buffer ``self.memory``."""
        self.memory.union(samples)

    def copy_from(self, obj) -> bool:
        """Copy parameters and dynamics model from another instance.

        The policy is not copied.

        :raises TypeError: if ``obj`` is not an instance of ``type(self)``.
        :return: ``True`` on success.
        """
        if not isinstance(obj, type(self)):
            raise TypeError(
                'Wrong type of obj %s to be copied, which should be %s' %
                (type(obj), type(self)))
        self.parameters.copy_from(obj.parameters)
        self._dynamics_model.copy_from(obj._dynamics_model)
        ConsoleLogger().print('info',
                              'model: {} copied from {}'.format(self, obj))
        return True

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Save the dynamics model and the policy.

        :param global_step: step number forwarded to both ``save`` calls.
        :param save_path: target directory; defaults to
            ``GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH``.
        :param name: checkpoint name; defaults to ``self.name``.
        :return: dict describing the checkpoint (path, step, name).
        """
        save_path = save_path or GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name or self.name

        # NOTE(review): both objects are saved under the same save_path and
        # name; confirm their save() implementations do not overwrite each
        # other's checkpoint files.
        self._dynamics_model.save(save_path=save_path,
                                  global_step=global_step,
                                  name=name,
                                  **kwargs)
        self.policy.save(save_path=save_path,
                         global_step=global_step,
                         name=name,
                         **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        """Load the dynamics model and the policy from a checkpoint.

        :return: dict describing the loaded checkpoint (path, step, name).
        """
        self._dynamics_model.load(path_to_model, model_name, global_step,
                                  **kwargs)
        self.policy.load(path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)
Пример #3
0
    def test_scheduler_param(self):
        """Scheduled parameters follow their schedules and survive copy_from.

        ``param2`` is driven by a LinearSchedule (1.0 -> 0.0 over 10 steps),
        and ``param4`` by a PiecewiseSchedule. Both read the current time step
        from the module-level global ``x`` through ``func``. Afterwards a
        second Parameters object copies from the first, and values/schedulers
        are compared.
        """
        global x
        # Restart the shared clock so every schedule branch below is exercised
        # regardless of test order or of earlier runs having advanced ``x``.
        x = 0

        def func():
            global x
            return x

        parameters = dict(param1='aaaa',
                          param2=1.0,
                          param4=1.0,
                          param3=np.random.random([4, 2]))
        source_config, _ = self.create_dict_config()
        a = Parameters(
            parameters=parameters,
            source_config=source_config,
            name='test_params',
            to_scheduler_param_tuple=(dict(param_key='param2',
                                           scheduler=LinearSchedule(
                                               t_fn=func,
                                               schedule_timesteps=10,
                                               final_p=0.0)),
                                      dict(param_key='param4',
                                           scheduler=PiecewiseSchedule(
                                               t_fn=func,
                                               endpoints=((2, 0.5), (8, 0.2),
                                                          (10, 0.0)),
                                               outside_value=0.0,
                                           ))))
        a.init()
        for _ in range(20):
            if x < 10:
                self.assertEqual(a('param2'), 1.0 - x * (1.0 - 0.0) / 10)
            else:
                self.assertEqual(a('param2'), 0.0)
            if x == 2:
                self.assertEqual(a('param4'), 0.5)
            if x == 8:
                self.assertEqual(a('param4'), 0.2)
            if x >= 10:
                self.assertEqual(a('param4'), 0.0)
            x += 1
        b, _ = self.create_parameters()
        b.copy_from(a)
        for key in a._source_config.required_key_dict.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                self.assertEqual(id(a[key]), id(b[key]))
                self.assertEqual(id(a(key)), id(b(key)))
        for key in a._parameters.keys():
            if isinstance(a[key], np.ndarray):
                self.assertTrue(np.equal(a[key], b[key]).all())
            else:
                self.assertEqual(a[key], b[key])
                self.assertEqual(a(key), b(key))
        self.assertEqual(len(a.to_scheduler_param_list),
                         len(b.to_scheduler_param_list))
        for a_val, b_val in zip(a.to_scheduler_param_list,
                                b.to_scheduler_param_list):
            self.assertEqual(a_val['param_key'], b_val['param_key'])
            self.assertEqual(a_val['scheduler'].value(),
                             b_val['scheduler'].value())