def init(self):
    Parameters.init(self)
    sess = tf.get_default_session()
    sess.run(tf.variables_initializer(var_list=self._tf_var_list))
    if self.require_snapshot is True:
        if len(self.snapshot_var) == 0:
            # Build the snapshot ops once, right after the variables are
            # initialized: one shadow variable per parameter, plus paired
            # assign ops to save into and load back from the shadow copy.
            sess = tf.get_default_session()
            with tf.variable_scope('snapshot'):
                for var in self._tf_var_list:
                    snap_var = tf.Variable(initial_value=sess.run(var),
                                           expected_shape=var.get_shape().as_list(),
                                           name=str(var.name).split(':')[0])
                    self.snapshot_var.append(snap_var)
                    self.save_snapshot_op.append(tf.assign(snap_var, var))
                    self.load_snapshot_op.append(tf.assign(var, snap_var))
                sess.run(tf.variables_initializer(var_list=self.snapshot_var))
        # Take an initial snapshot so load_snapshot_op is always valid.
        sess.run(self.save_snapshot_op)
    self.saver = tf.train.Saver(max_to_keep=self.max_to_keep,
                                var_list=self._tf_var_list)
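
# --- Illustrative sketch (not part of the library): the snapshot pattern ---
# A minimal, self-contained TensorFlow 1.x example of what init() wires up
# above: a shadow variable plus paired assign ops that save the live value
# and later roll it back. All names here (demo_var, snap_var, ...) are
# hypothetical.
import tensorflow as tf

demo_var = tf.Variable(1.0, name='demo_var')
snap_var = tf.Variable(0.0, name='demo_var_snapshot')
save_snapshot = tf.assign(snap_var, demo_var)  # copy live value into the shadow
load_snapshot = tf.assign(demo_var, snap_var)  # restore live value from the shadow

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(save_snapshot)                # snapshot taken: shadow == 1.0
    sess.run(tf.assign(demo_var, 5.0))     # the live variable drifts
    sess.run(load_snapshot)                # roll back to the snapshot
    assert sess.run(demo_var) == 1.0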
class ModelPredictiveControl(ModelBasedAlgo):
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST)

    def __init__(self,
                 env_spec,
                 dynamics_model: DynamicsModel,
                 config_or_config_dict: (DictConfig, dict),
                 policy: Policy,
                 name='mpc'):
        super().__init__(env_spec, dynamics_model, name)
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = policy
        self.parameters = Parameters(parameters=dict(),
                                     source_config=self.config,
                                     name=name + '_' + 'mpc_param')
        self.memory = TransitionData(env_spec=env_spec)

    def init(self, source_obj=None):
        super().init()
        self.parameters.init()
        self._dynamics_model.init()
        self.policy.init()
        if source_obj:
            self.copy_from(source_obj)

    def train(self, *arg, **kwargs) -> dict:
        super().train()
        res_dict = {}
        batch_data = kwargs.get('batch_data', self.memory)
        dynamics_train_res_dict = self._fit_dynamics_model(
            batch_data=batch_data,
            train_iter=self.parameters('dynamics_model_train_iter'))
        for key, val in dynamics_train_res_dict.items():
            res_dict["mlp_dynamics_{}".format(key)] = val
        return res_dict

    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def _fit_dynamics_model(self, batch_data: TransitionData,
                            train_iter, sess=None) -> dict:
        res_dict = self._dynamics_model.train(batch_data,
                                              sess=sess,
                                              train_iter=train_iter)
        return res_dict

    def predict(self, obs, **kwargs):
        if self.is_training is True:
            return self.env_spec.action_space.sample()
        # Random-shooting MPC: roll out SAMPLED_PATH_NUM candidate paths of
        # length SAMPLED_HORIZON through the learned dynamics, then return
        # the first action of the highest-return path.
        rollout = TrajectoryData(env_spec=self.env_spec)
        for i in range(self.parameters('SAMPLED_PATH_NUM')):
            # Each candidate path must start from the current observation.
            state = obs
            path = TransitionData(env_spec=self.env_spec)
            # todo the terminal signal from terminal_func is still to be considered
            for _ in range(self.parameters('SAMPLED_HORIZON')):
                ac = self.policy.forward(obs=state)
                new_state, re, done, _ = self.dynamics_env.step(action=ac, state=state)
                path.append(state=state,
                            action=ac,
                            new_state=new_state,
                            reward=re,
                            done=done)
                state = new_state
            rollout.append(path)
        rollout.trajectories.sort(key=lambda x: x.cumulative_reward, reverse=True)
        ac = rollout.trajectories[0].action_set[0]
        assert self.env_spec.action_space.contains(ac)
        return ac

    def append_to_memory(self, samples: TransitionData):
        self.memory.union(samples)

    def copy_from(self, obj) -> bool:
        if not isinstance(obj, type(self)):
            raise TypeError('Wrong type of obj %s to be copied, which should be %s'
                            % (type(obj), type(self)))
        self.parameters.copy_from(obj.parameters)
        self._dynamics_model.copy_from(obj._dynamics_model)
        ConsoleLogger().print('info', 'model: {} copied from {}'.format(self, obj))
        return True

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        self._dynamics_model.save(save_path=save_path,
                                  global_step=global_step,
                                  name=name,
                                  **kwargs)
        self.policy.save(save_path=save_path,
                         global_step=global_step,
                         name=name,
                         **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        self._dynamics_model.load(path_to_model, model_name, global_step, **kwargs)
        self.policy.load(path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)
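
# --- Illustrative sketch (not part of the library): random-shooting MPC ---
# predict() above scores SAMPLED_PATH_NUM rollouts of SAMPLED_HORIZON steps
# through the learned dynamics and returns the first action of the best
# rollout. The stand-alone version below makes that loop explicit;
# dynamics_fn, reward_fn and sample_action are hypothetical stand-ins for
# the dynamics model, reward function, and sampling policy.
import numpy as np

def random_shooting_mpc(obs, dynamics_fn, reward_fn, sample_action,
                        n_paths=10, horizon=5):
    best_return, best_first_action = -np.inf, None
    for _ in range(n_paths):
        state, path_return, first_action = obs, 0.0, None
        for t in range(horizon):
            action = sample_action(state)
            if t == 0:
                first_action = action      # only the first action is executed
            next_state = dynamics_fn(state, action)
            path_return += reward_fn(state, action, next_state)
            state = next_state
        if path_return > best_return:      # keep the highest-return path
            best_return, best_first_action = path_return, first_action
    return best_first_action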
def test_scheduler_param(self):
    def func():
        global x
        return x

    parameters = dict(param1='aaaa',
                      param2=1.0,
                      param4=1.0,
                      param3=np.random.random([4, 2]))
    source_config, _ = self.create_dict_config()
    a = Parameters(parameters=parameters,
                   source_config=source_config,
                   name='test_params',
                   to_scheduler_param_tuple=(
                       dict(param_key='param2',
                            scheduler=LinearSchedule(t_fn=func,
                                                     schedule_timesteps=10,
                                                     final_p=0.0)),
                       dict(param_key='param4',
                            scheduler=PiecewiseSchedule(t_fn=func,
                                                        endpoints=((2, 0.5), (8, 0.2), (10, 0.0)),
                                                        outside_value=0.0))))
    a.init()
    global x
    for i in range(20):
        if x < 10:
            self.assertEqual(a('param2'), 1.0 - x * (1.0 - 0.0) / 10)
        else:
            self.assertEqual(a('param2'), 0.0)
        if x == 2:
            self.assertEqual(a('param4'), 0.5)
        if x == 8:
            self.assertEqual(a('param4'), 0.2)
        if x >= 10:
            self.assertEqual(a('param4'), 0.0)
        x += 1

    b, _ = self.create_parameters()
    b.copy_from(a)
    for key in a._source_config.required_key_dict.keys():
        if isinstance(a[key], np.ndarray):
            self.assertTrue(np.equal(a[key], b[key]).all())
        else:
            self.assertEqual(id(a[key]), id(b[key]))
            self.assertEqual(id(a(key)), id(b(key)))
    for key in a._parameters.keys():
        if isinstance(a[key], np.ndarray):
            self.assertTrue(np.equal(a[key], b[key]).all())
        else:
            self.assertEqual(a[key], b[key])
            self.assertEqual(a(key), b(key))
    self.assertEqual(len(a.to_scheduler_param_list), len(b.to_scheduler_param_list))
    for a_val, b_val in zip(a.to_scheduler_param_list, b.to_scheduler_param_list):
        self.assertEqual(a_val['param_key'], b_val['param_key'])
        self.assertEqual(a_val['scheduler'].value(), b_val['scheduler'].value())
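
# --- Illustrative sketch (not part of the test): the expected schedule values ---
# The assertions above follow from linear interpolation: LinearSchedule moves
# from initial_p to final_p over schedule_timesteps and then stays at final_p,
# so with initial_p=1.0 and final_p=0.0 the value at step t is 1.0 - t / 10
# for t < 10 and 0.0 afterwards. A minimal reimplementation, assuming that
# interpolation rule; linear_schedule_value is a hypothetical helper.
def linear_schedule_value(t, schedule_timesteps=10, initial_p=1.0, final_p=0.0):
    fraction = min(float(t) / schedule_timesteps, 1.0)
    return initial_p + fraction * (final_p - initial_p)

assert linear_schedule_value(0) == 1.0
assert linear_schedule_value(5) == 0.5    # halfway through the schedule
assert linear_schedule_value(15) == 0.0   # clamped at final_p past t=10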