class ConstantActionPolicy(DeterministicPolicy):
    """Deterministic policy that always returns the configured ``ACTION_VALUE``.

    The action is looked up in the policy parameters (which are backed by the
    config) on every call to :meth:`forward`; all call arguments are ignored.
    """
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_CONSTANT_ACTION_POLICY_REQUIRED_KEY_LIST)

    def __init__(self, env_spec: EnvSpec, config_or_config_dict: (DictConfig, dict), name='policy'):
        """
        :param env_spec: environment specification; its action space must
            contain the configured ``ACTION_VALUE``.
        :param config_or_config_dict: ``dict`` or ``DictConfig`` providing ``ACTION_VALUE``.
        :param name: policy name.
        """
        cfg = construct_dict_config(config_or_config_dict, self)
        param_store = Parameters(parameters=dict(), source_config=cfg)
        # Reject a constant action that lies outside the action space up front.
        assert env_spec.action_space.contains(x=cfg('ACTION_VALUE'))
        super().__init__(env_spec, param_store, name)
        self.config = cfg

    def forward(self, *args, **kwargs):
        """Return the constant action, unflattened into the action space's format."""
        space = self.env_spec.action_space
        constant_action = space.unflatten(self.parameters('ACTION_VALUE'))
        assert space.contains(x=constant_action)
        return constant_action

    def copy_from(self, obj) -> bool:
        """Copy base-class state and parameters from ``obj``; always returns True."""
        super().copy_from(obj)
        self.parameters.copy_from(obj.parameters)
        return True

    def make_copy(self, *args, **kwargs):
        """Build a new ``ConstantActionPolicy`` sharing ``env_spec`` with a deep-copied config."""
        cloned_config = deepcopy(self.config)
        return ConstantActionPolicy(*args,
                                    env_spec=self.env_spec,
                                    config_or_config_dict=cloned_config,
                                    **kwargs)
def construct_dict_config(config_or_config_dict, obj):
    """Normalise a config argument into a ``DictConfig``.

    :param config_or_config_dict: either a plain ``dict`` of settings or an
        already-constructed ``DictConfig``.
    :param obj: the object the config belongs to; its ``required_key_dict``
        attribute and class name are used when wrapping a plain dict.
    :return: a ``DictConfig`` instance.
    :raises TypeError: if the argument is neither a ``dict`` nor a ``DictConfig``.
    """
    if isinstance(config_or_config_dict, dict):
        return DictConfig(required_key_dict=obj.required_key_dict,
                          config_dict=config_or_config_dict,
                          cls_name=type(obj).__name__)
    # An already-built config is passed through unchanged (this branch used to
    # re-test ``dict``, which made it unreachable and rejected every config object).
    if isinstance(config_or_config_dict, DictConfig):
        return config_or_config_dict
    raise TypeError('Type {} is not supported, use dict or Config'.format(
        type(config_or_config_dict).__name__))
def __init__(self, parameters: dict, source_config: DictConfig = None, name='parameters',
             to_scheduler_param_tuple: tuple = None, default_save_param_key=None):
    """Hold a parameter mapping together with the config it came from.

    :param parameters: mapping of parameter name to value, stored as-is.
    :param source_config: backing config; when falsy, an empty ``DictConfig``
        is created instead.
    :param name: name of this parameter collection.
    :param to_scheduler_param_tuple: scheduler specifications, stored unchanged
        as ``to_scheduler_param_list``.
    :param default_save_param_key: stored unchanged as ``default_save_param_key``.
    """
    self._parameters = parameters
    self.name = name
    if not source_config:
        source_config = DictConfig(required_key_dict=dict(), config_dict=dict())
    self._source_config = source_config
    self.default_save_param_key = default_save_param_key
    self._scheduler_info_dict = {}
    # Note: stored under a *_list name although the argument is a tuple.
    self.to_scheduler_param_list = to_scheduler_param_tuple
class Dyna(ModelBasedAlgo):
    """
    Dyna algorithms, Sutton, R. S. (1991).

    You can replace the dynamics model with any dynamics models you want.
    Training alternates (or combines) fitting the dynamics model and training
    the wrapped model-free algorithm; prediction is delegated to the latter.
    """
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_ALGO_DYNA_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self, env_spec, dynamics_model: DynamicsModel,
                 model_free_algo: ModelFreeAlgo,
                 config_or_config_dict: (DictConfig, dict),
                 name='sample_with_dynamics'):
        """
        :param env_spec: environment specification.
        :param dynamics_model: dynamics model, handed to ``ModelBasedAlgo``.
        :param model_free_algo: algorithm trained alongside the dynamics model
            and used for prediction.
        :param config_or_config_dict: ``dict`` or ``DictConfig`` with the Dyna settings.
        :param name: algorithm name.
        """
        super().__init__(env_spec, dynamics_model, name)
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(), name='dyna_param', source_config=config)
        self.model_free_algo = model_free_algo
        self.config = config
        self.parameters = parameters

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='JUST_INITED')
    def init(self):
        """Initialise parameters, the model-free algorithm, the dynamics env, then the base class."""
        self.parameters.init()
        self.model_free_algo.init()
        self.dynamics_env.init()
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train_counter', under_status='TRAIN')
    def train(self, *args, **kwargs) -> dict:
        """Train the dynamics model and/or the model-free algorithm.

        Keyword args:
            batch_data: forwarded to both training phases (None when absent).
            state: ``'state_dynamics_training'`` or ``'state_agent_training'``
                to run only that phase; when omitted (or None) both phases run.

        :return: dict of results; keys are prefixed with the trained component's name.
        :raises ValueError: if ``state`` is given but is not a known phase.
        """
        super(Dyna, self).train()
        res_dict = {}
        batch_data = kwargs.get('batch_data')
        state = kwargs.pop('state', None)
        # An explicit check instead of ``assert`` so validation survives ``python -O``.
        if state is not None and state not in ('state_dynamics_training', 'state_agent_training'):
            raise ValueError('Unknown training state: {}'.format(state))

        if not state or state == 'state_dynamics_training':
            dynamics_train_res_dict = self._fit_dynamics_model(
                batch_data=batch_data,
                train_iter=self.parameters('dynamics_model_train_iter'))
            for key, val in dynamics_train_res_dict.items():
                res_dict["{}_{}".format(self._dynamics_model.name, key)] = val
        if not state or state == 'state_agent_training':
            model_free_algo_train_res_dict = self._train_model_free_algo(
                batch_data=batch_data,
                train_iter=self.parameters('model_free_algo_train_iter'))
            for key, val in model_free_algo_train_res_dict.items():
                res_dict['{}_{}'.format(self.model_free_algo.name, key)] = val
        return res_dict

    @register_counter_info_to_status_decorator(increment=1, info_key='test_counter', under_status='TEST')
    def test(self, *arg, **kwargs):
        """Delegate to the base-class test and return its result."""
        return super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1, info_key='predict_counter')
    def predict(self, obs, **kwargs):
        """Predict an action for ``obs`` with the model-free algorithm."""
        return self.model_free_algo.predict(obs)

    def append_to_memory(self, *args, **kwargs):
        """Forward ``samples`` (keyword or first positional argument) to the model-free algorithm.

        :raises KeyError: if no samples are supplied at all.
        """
        if args and 'samples' not in kwargs:
            samples = args[0]
        else:
            samples = kwargs['samples']
        self.model_free_algo.append_to_memory(samples)

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Save the model-free algorithm and the dynamics env into sub-directories of ``save_path``."""
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        self.model_free_algo.save(global_step=global_step,
                                  name=None,
                                  save_path=os.path.join(save_path, self.model_free_algo.name))
        self.dynamics_env.save(global_step=global_step,
                               name=None,
                               save_path=os.path.join(save_path, self.dynamics_env.name))
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        """Load the model-free algorithm and the dynamics env from sub-directories of ``path_to_model``."""
        self.model_free_algo.load(path_to_model=os.path.join(path_to_model, self.model_free_algo.name),
                                  model_name=self.model_free_algo.name,
                                  global_step=global_step)
        self.dynamics_env.load(global_step=global_step,
                               path_to_model=os.path.join(path_to_model, self.dynamics_env.name),
                               model_name=self.dynamics_env.name)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    # NOTE: the misspelled info_key strings below are status keys read elsewhere; keep them as-is.
    @register_counter_info_to_status_decorator(
        increment=1, info_key='dyanmics_train_counter', under_status='TRAIN')
    def _fit_dynamics_model(self, batch_data: TransitionData, train_iter, sess=None) -> dict:
        """Train the dynamics model on ``batch_data`` and return its result dict."""
        return self._dynamics_model.train(batch_data, sess=sess, train_iter=train_iter)

    @register_counter_info_to_status_decorator(
        increment=1, info_key='mode_free_algo_dyanmics_train_counter', under_status='TRAIN')
    def _train_model_free_algo(self, batch_data=None, train_iter=None, sess=None):
        """Train the model-free algorithm and return its result dict."""
        return self.model_free_algo.train(batch_data=batch_data, train_iter=train_iter, sess=sess)
class PPO(ModelFreeAlgo, OnPolicyAlgo, MultiPlaceholderInput):
    """Proximal Policy Optimization with a learned state-value baseline.

    Observations are normalised by a ``RunningStandardScaler`` before being
    fed to the policy / value function. The policy objective is the clipped
    surrogate when ``clipping_range`` is configured, otherwise the
    KL-penalised objective with adaptive ``beta`` / ``lr_multiplier``.
    """
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_PPO_REQUIRED_KEY_LIST)

    @typechecked
    def __init__(self, env_spec: EnvSpec, stochastic_policy: StochasticPolicy,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: VValueFunction,
                 warm_up_trajectories_number=5,
                 use_time_index_flag=False,
                 name='ppo'):
        """
        :param env_spec: environment specification.
        :param stochastic_policy: policy being optimised.
        :param config_or_config_dict: ``dict`` or ``DictConfig`` with the PPO hyper-parameters.
        :param value_func: state-value function used as the baseline.
        :param warm_up_trajectories_number: forwarded to ``ModelFreeAlgo``.
        :param use_time_index_flag: when True, the scaler statistics of the last
            observation dimension are pinned (see ``_pin_time_index_scaler_param``).
        :param name: algorithm name, also used as the TF variable scope.
        """
        ModelFreeAlgo.__init__(
            self, env_spec=env_spec, name=name,
            warm_up_trajectories_number=warm_up_trajectories_number)
        self.use_time_index_flag = use_time_index_flag
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = stochastic_policy
        self.value_func = value_func
        to_ph_parameter_dict = dict()
        self.trajectory_memory = TrajectoryData(env_spec=env_spec)
        self.transition_data_for_trajectory = TransitionData(env_spec=env_spec)
        self.value_func_train_data_buffer = None
        self.scaler = RunningStandardScaler(dims=self.env_spec.flat_obs_dim)
        self._pin_time_index_scaler_param()

        with tf.variable_scope(name):
            self.advantages_ph = tf.placeholder(tf.float32, (None, ), 'advantages')
            self.v_func_val_ph = tf.placeholder(tf.float32, (None, ), 'val_val_func')
            dist_info_list = self.policy.get_dist_info()
            # One placeholder per distribution parameter; fed with the current
            # policy's values to form the "old" policy in the loss.
            self.old_dist_tensor = [
                (tf.placeholder(dtype=dist_info['dtype'],
                                shape=dist_info['shape'],
                                name=dist_info['name']), dist_info['name'])
                for dist_info in dist_info_list
            ]
            self.old_policy = self.policy.make_copy(
                reuse=False,
                name_scope='old_{}'.format(self.policy.name),
                name='old_{}'.format(self.policy.name),
                distribution_tensors_tuple=tuple(self.old_dist_tensor))
            to_ph_parameter_dict['beta'] = tf.placeholder(tf.float32, (), 'beta')
            to_ph_parameter_dict['eta'] = tf.placeholder(tf.float32, (), 'eta')
            to_ph_parameter_dict['kl_target'] = tf.placeholder(tf.float32, (), 'kl_target')
            to_ph_parameter_dict['lr_multiplier'] = tf.placeholder(tf.float32, (), 'lr_multiplier')

        self.parameters = ParametersWithTensorflowVariable(
            tf_var_list=[],
            rest_parameters=dict(
                advantages_ph=self.advantages_ph,
                v_func_val_ph=self.v_func_val_ph,
            ),
            to_ph_parameter_dict=to_ph_parameter_dict,
            name='ppo_param',
            save_rest_param_flag=False,
            source_config=self.config,
            require_snapshot=False)
        with tf.variable_scope(name):
            with tf.variable_scope('train'):
                self.kl = tf.reduce_mean(self.old_policy.kl(self.policy))
                self.average_entropy = tf.reduce_mean(self.policy.entropy())
                self.policy_loss, self.policy_optimizer, self.policy_update_op = self._setup_policy_loss()
                self.value_func_loss, self.value_func_optimizer, self.value_func_update_op = \
                    self._setup_value_func_loss()
        var_list = get_tf_collection_var_list('{}/train'.format(name)) \
            + self.policy_optimizer.variables() \
            + self.value_func_optimizer.variables()
        self.parameters.set_tf_var_list(tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))
        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[
                                           dict(obj=self.value_func, attr_name='value_func'),
                                           dict(obj=self.policy, attr_name='policy'),
                                       ],
                                       parameters=self.parameters)

    def _pin_time_index_scaler_param(self):
        """Pin the scaler statistics of the last observation dimension.

        Only acts when ``use_time_index_flag`` is set: the last entry of the
        scaler mean is set to 0 and of the variance to 1000 * 1000 (that
        dimension is presumably a time index).
        """
        if not self.use_time_index_flag:
            return
        scale_last_time_index_mean = self.scaler._mean
        scale_last_time_index_mean[-1] = 0
        scale_last_time_index_var = self.scaler._var
        scale_last_time_index_var[-1] = 1000 * 1000
        self.scaler.set_param(mean=scale_last_time_index_mean,
                              var=scale_last_time_index_var)

    def warm_up(self, trajectory_data: TrajectoryData):
        """Update the observation scaler with the states of every trajectory."""
        for traj in trajectory_data.trajectories:
            self.scaler.update_scaler(data=traj.state_set)
        self._pin_time_index_scaler_param()

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='INITED')
    def init(self, sess=None, source_obj=None):
        """Initialise policy, value function and parameters; optionally copy from ``source_obj``."""
        self.policy.init()
        self.value_func.init()
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train', under_status='TRAIN')
    def train(self, trajectory_data: TrajectoryData = None, train_iter=None, sess=None) -> dict:
        """Run one PPO update on ``trajectory_data`` (default: the internal memory).

        Both ``trajectory_data`` and the internal trajectory memory are reset afterwards.

        :raises MemoryBufferLessThanBatchSizeError: if there is no trajectory data.
        :return: merged result dicts of the policy and value-function updates.
        """
        super(PPO, self).train()
        if trajectory_data is None:
            trajectory_data = self.trajectory_memory
        if len(trajectory_data) == 0:
            raise MemoryBufferLessThanBatchSizeError('not enough trajectory data')
        batch_obs_shape = [-1] + list(self.env_spec.obs_shape)
        for traj in trajectory_data.trajectories:
            # Store the scaled observations back on the trajectory.
            traj.append_new_set(
                name='state_set',
                shape=self.env_spec.obs_shape,
                data_set=np.reshape(np.array(self.scaler.process(np.array(traj.state_set))),
                                    batch_obs_shape))
            traj.append_new_set(
                name='new_state_set',
                shape=self.env_spec.obs_shape,
                data_set=np.reshape(np.array(self.scaler.process(np.array(traj.new_state_set))),
                                    batch_obs_shape))
        tf_sess = sess if sess else tf.get_default_session()
        SampleProcessor.add_estimated_v_value(trajectory_data, value_func=self.value_func)
        SampleProcessor.add_discount_sum_reward(trajectory_data, gamma=self.parameters('gamma'))
        SampleProcessor.add_gae(trajectory_data,
                                gamma=self.parameters('gamma'),
                                name='advantage_set',
                                lam=self.parameters('lam'),
                                value_func=self.value_func)
        trajectory_data = SampleProcessor.normalization(trajectory_data, key='advantage_set')

        def _stack(key):
            # Concatenate one per-step field across all trajectories.
            return np.concatenate([t(key) for t in trajectory_data.trajectories], axis=0)

        state_set = _stack('state_set')
        policy_res_dict = self._update_policy(
            state_set=state_set,
            action_set=_stack('action_set'),
            advantage_set=_stack('advantage_set'),
            train_iter=train_iter if train_iter else self.parameters('policy_train_iter'),
            sess=tf_sess)
        value_func_res_dict = self._update_value_func(
            state_set=state_set,
            discount_set=_stack('discount_set'),
            train_iter=train_iter if train_iter else self.parameters('value_func_train_iter'),
            sess=tf_sess)
        trajectory_data.reset()
        self.trajectory_memory.reset()
        return {**policy_res_dict, **value_func_res_dict}

    @register_counter_info_to_status_decorator(increment=1, info_key='test', under_status='TEST')
    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1, info_key='predict')
    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        """Sample an action from the policy for the scaled observation(s)."""
        tf_sess = sess if sess else tf.get_default_session()
        ac = self.policy.forward(
            obs=self.scaler.process(data=make_batch(obs, original_shape=self.env_spec.obs_shape)),
            sess=tf_sess,
            feed_dict=self.parameters.return_tf_parameter_feed_dict())
        return ac

    def append_to_memory(self, samples: TrajectoryData):
        """Add ``samples`` to the trajectory memory and update the observation scaler.

        Empty ``samples`` (no trajectories) is a no-op.
        """
        # todo how to make sure the data's time sequential
        if not samples.trajectories:
            return
        # One concatenate instead of growing the array trajectory by trajectory.
        obs = np.concatenate([t.state_set for t in samples.trajectories], axis=0)
        self.trajectory_memory.union(samples)
        self.scaler.update_scaler(data=np.array(obs))
        self._pin_time_index_scaler_param()

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self, save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        MultiPlaceholderInput.load(self, path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _setup_policy_loss(self):
        """
        Adapted from Patrick Coady's PPO implementation.

        With ``clipping_range`` set: the clipped surrogate objective.
        Otherwise three loss terms:
            1) standard policy gradient
            2) D_KL(pi_old || pi_new)
            3) Hinge loss on [D_KL - kl_targ]^2
        See: https://arxiv.org/pdf/1707.02286.pdf

        :return: (loss, optimizer, train_op)
        """
        if self.parameters('clipping_range') is not None:
            pg_ratio = tf.exp(self.policy.log_prob() - self.old_policy.log_prob())
            clipped_pg_ratio = tf.clip_by_value(
                pg_ratio,
                1 - self.parameters('clipping_range')[0],
                1 + self.parameters('clipping_range')[1])
            surrogate_loss = tf.minimum(self.advantages_ph * pg_ratio,
                                        self.advantages_ph * clipped_pg_ratio)
            loss = -tf.reduce_mean(surrogate_loss)
        else:
            loss1 = -tf.reduce_mean(
                self.advantages_ph * tf.exp(self.policy.log_prob() - self.old_policy.log_prob()))
            loss2 = tf.reduce_mean(self.parameters('beta') * self.kl)
            loss3 = self.parameters('eta') * tf.square(
                tf.maximum(0.0, self.kl - 2.0 * self.parameters('kl_target')))
            loss = loss1 + loss2 + loss3
            self.loss1 = loss1
            self.loss2 = loss2
            self.loss3 = loss3
        if isinstance(self.policy, PlaceholderInput):
            reg_list = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                         scope=self.policy.name_scope)
            if len(reg_list) > 0:
                reg_loss = tf.reduce_sum(reg_list)
                loss += reg_loss

        optimizer = tf.train.AdamOptimizer(
            learning_rate=self.parameters('policy_lr') * self.parameters('lr_multiplier'))
        train_op = optimizer.minimize(loss, var_list=self.policy.parameters('tf_var_list'))
        return loss, optimizer, train_op

    def _setup_value_func_loss(self):
        """Mean-squared error between the value prediction and the target, plus regularisation.

        :return: (loss, optimizer, train_op)
        """
        # todo update the value_func design
        loss = tf.reduce_mean(
            tf.square(tf.squeeze(self.value_func.v_tensor) - self.v_func_val_ph))
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                     scope=self.value_func.name_scope)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)

        optimizer = tf.train.AdamOptimizer(self.parameters('value_func_lr'))
        train_op = optimizer.minimize(loss, var_list=self.value_func.parameters('tf_var_list'))
        return loss, optimizer, train_op

    def _update_policy(self, state_set, action_set, advantage_set, train_iter, sess):
        """Run up to ``train_iter`` policy updates, then adapt ``beta`` and ``lr_multiplier``.

        Training stops early when the KL divergence exceeds 4 * kl_target.

        :return: dict with average loss / KL / entropy and the number of epochs run.
        """
        # Snapshot the current policy's distribution parameters; they are fed
        # into ``old_policy``'s placeholders for the ratio / KL terms.
        old_policy_feed_dict = dict()
        res = sess.run(
            [getattr(self.policy, tensor[1]) for tensor in self.old_dist_tensor],
            feed_dict={
                self.policy.parameters('state_input'): state_set,
                self.policy.parameters('action_input'): action_set,
                **self.parameters.return_tf_parameter_feed_dict()
            })
        for tensor, val in zip(self.old_dist_tensor, res):
            old_policy_feed_dict[tensor[0]] = val

        feed_dict = {
            self.policy.parameters('action_input'): action_set,
            self.old_policy.parameters('action_input'): action_set,
            self.policy.parameters('state_input'): state_set,
            self.advantages_ph: advantage_set,
            **self.parameters.return_tf_parameter_feed_dict(),
            **old_policy_feed_dict
        }
        average_loss, average_kl, average_entropy = 0.0, 0.0, 0.0
        total_epoch = 0
        kl = None
        kl_target = self.parameters('kl_target', require_true_value=True)
        for i in range(train_iter):
            _ = sess.run(self.policy_update_op, feed_dict=feed_dict)
            loss, kl, entropy = sess.run(
                [self.policy_loss, self.kl, self.average_entropy], feed_dict=feed_dict)
            average_loss += loss
            average_kl += kl
            average_entropy += entropy
            total_epoch = i + 1
            if kl > kl_target * 4:
                # early stopping if D_KL diverges badly
                break
        average_loss, average_kl, average_entropy = (average_loss / total_epoch,
                                                     average_kl / total_epoch,
                                                     average_entropy / total_epoch)

        if kl > kl_target * 2:
            # servo beta to reach D_KL target
            self.parameters.set(
                key='beta',
                new_val=np.minimum(35, 1.5 * self.parameters('beta', require_true_value=True)))
            if self.parameters('beta', require_true_value=True) > 30 and \
                    self.parameters('lr_multiplier', require_true_value=True) > 0.1:
                self.parameters.set(
                    key='lr_multiplier',
                    new_val=self.parameters('lr_multiplier', require_true_value=True) / 1.5)
        elif kl < kl_target / 2:
            self.parameters.set(
                key='beta',
                new_val=np.maximum(1 / 35, self.parameters('beta', require_true_value=True) / 1.5))
            if self.parameters('beta', require_true_value=True) < (1 / 30) and \
                    self.parameters('lr_multiplier', require_true_value=True) < 10:
                self.parameters.set(
                    key='lr_multiplier',
                    new_val=self.parameters('lr_multiplier', require_true_value=True) * 1.5)
        return dict(policy_average_loss=average_loss,
                    policy_average_kl=average_kl,
                    policy_average_entropy=average_entropy,
                    policy_total_train_epoch=total_epoch)

    def _update_value_func(self, state_set, discount_set, train_iter, sess):
        """Fit the value function on a bounded replay of recent (state, return) pairs.

        :return: dict with the post-training MSE and explained variance before/after.
        """
        y_hat = self.value_func.forward(obs=state_set).squeeze()
        old_exp_var = 1 - np.var(discount_set - y_hat) / np.var(discount_set)

        if self.value_func_train_data_buffer is None:
            self.value_func_train_data_buffer = (state_set, discount_set)
        else:
            self.value_func_train_data_buffer = (
                np.concatenate([self.value_func_train_data_buffer[0], state_set], axis=0),
                np.concatenate([self.value_func_train_data_buffer[1], discount_set], axis=0))
        memory_size = self.parameters('value_func_memory_size')
        if len(self.value_func_train_data_buffer[0]) > memory_size:
            # Keep only the most recent ``memory_size`` samples.
            self.value_func_train_data_buffer = tuple(
                np.array(data[-memory_size:]) for data in self.value_func_train_data_buffer)

        state_set_all, discount_set_all = self.value_func_train_data_buffer
        param_dict = self.parameters.return_tf_parameter_feed_dict()
        batch_size = self.parameters('value_func_train_batch_size')
        for _ in range(train_iter):
            # A permutation (not np.random.choice, which samples with
            # replacement) so every sample is used exactly once per epoch.
            perm = np.random.permutation(len(state_set_all))
            state_set_all = state_set_all[perm]
            discount_set_all = discount_set_all[perm]
            # Cover every mini-batch, including the last (possibly partial) one.
            for start in range(0, len(state_set_all), batch_size):
                state = np.array(state_set_all[start:start + batch_size])
                discount = discount_set_all[start:start + batch_size]
                sess.run([self.value_func_loss, self.value_func_update_op],
                         options=tf.RunOptions(report_tensor_allocations_upon_oom=True),
                         feed_dict={
                             self.value_func.state_input: state,
                             self.v_func_val_ph: discount,
                             **param_dict
                         })
        y_hat = self.value_func.forward(obs=state_set).squeeze()
        loss = np.mean(np.square(y_hat - discount_set))
        exp_var = 1 - np.var(discount_set - y_hat) / np.var(discount_set)
        return dict(value_func_loss=loss,
                    value_func_policy_exp_var=exp_var,
                    value_func_policy_old_exp_var=old_exp_var)
class DQN(ModelFreeAlgo, OffPolicyAlgo, MultiPlaceholderInput):
    """Deep Q-learning over a discrete action space with a softly-updated target network."""
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_DQN_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self,
                 env_spec,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: MLPQValueFunction,
                 schedule_param_list=None,
                 name: str = 'dqn',
                 replay_buffer=None):
        """
        :param env_spec: environment specification.
        :param config_or_config_dict: ``dict`` or ``DictConfig`` with the DQN hyper-parameters.
        :param value_func: Q-value function; a target copy is created from it.
        :param schedule_param_list: scheduler specifications forwarded to the parameters.
        :param name: algorithm name, also used as the TF variable scope.
        :param replay_buffer: optional ``BaseReplayBuffer`` instance; when None a
            ``UniformRandomReplayBuffer`` of size ``REPLAY_BUFFER_SIZE`` is created.
        :raises TypeError: if ``replay_buffer`` is not a ``BaseReplayBuffer`` instance.
        """
        ModelFreeAlgo.__init__(self, env_spec=env_spec, name=name)
        self.config = construct_dict_config(config_or_config_dict, self)

        # ``is not None``: an empty buffer may be falsy. ``isinstance``: the
        # buffer is used as an instance (``.sample``/``.append_batch``), and
        # ``issubclass`` raised TypeError for any instance.
        if replay_buffer is not None:
            if not isinstance(replay_buffer, BaseReplayBuffer):
                raise TypeError('replay_buffer must be a BaseReplayBuffer instance, got {}'.format(
                    type(replay_buffer).__name__))
            self.replay_buffer = replay_buffer
        else:
            self.replay_buffer = UniformRandomReplayBuffer(limit=self.config('REPLAY_BUFFER_SIZE'),
                                                           action_shape=self.env_spec.action_shape,
                                                           observation_shape=self.env_spec.obs_shape)
        self.q_value_func = value_func
        self.state_input = self.q_value_func.state_input
        self.action_input = self.q_value_func.action_input
        self.update_target_q_every_train = self.config('UPDATE_TARGET_Q_FREQUENCY') \
            if 'UPDATE_TARGET_Q_FREQUENCY' in self.config.config_dict else 1
        self.parameters = ParametersWithTensorflowVariable(tf_var_list=[],
                                                           rest_parameters=dict(),
                                                           to_scheduler_param_tuple=schedule_param_list,
                                                           name='{}_param'.format(name),
                                                           source_config=self.config,
                                                           require_snapshot=False)
        with tf.variable_scope(name):
            self.reward_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            self.next_state_input = tf.placeholder(shape=[None, self.env_spec.flat_obs_dim], dtype=tf.float32)
            self.done_input = tf.placeholder(shape=[None, 1], dtype=tf.bool)
            self.target_q_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            done = tf.cast(self.done_input, dtype=tf.float32)
            self.target_q_value_func = self.q_value_func.make_copy(
                name_scope='{}_target_q_value_net'.format(name),
                name='{}_target_q_value_net'.format(name),
                reuse=False)
            # Bellman target: r + gamma * (1 - done) * Q_target(s')
            self.predict_q_value = (1. - done) * self.config('GAMMA') * self.target_q_input + self.reward_input
            self.td_error = self.predict_q_value - self.q_value_func.q_tensor
            with tf.variable_scope('train'):
                self.q_value_func_loss, self.optimizer, self.update_q_value_func_op = self._set_up_loss()
                self.update_target_q_value_func_op = self._set_up_target_update()

        # redundant sort operation on var_list
        var_list = get_tf_collection_var_list(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='{}/train'.format(name)) + self.optimizer.variables()
        self.parameters.set_tf_var_list(tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))

        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[
                                           dict(obj=self.q_value_func, attr_name='q_value_func'),
                                           dict(obj=self.target_q_value_func, attr_name='target_q_value_func'),
                                       ],
                                       parameters=self.parameters)

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='INITED')
    def init(self, sess=None, source_obj=None):
        """Initialise the Q-function, copy it into the target, init parameters; optionally copy ``source_obj``."""
        super().init()
        self.q_value_func.init()
        self.target_q_value_func.init(source_obj=self.q_value_func)
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train_counter', under_status='TRAIN')
    def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
        """Run ``train_iter`` TD updates on ``batch_data`` (or replay-buffer samples).

        :raises TypeError: if ``batch_data`` is given but is not ``TransitionData``.
        :return: dict with the average loss over the iterations.
        """
        super(DQN, self).train()
        self.recorder.record()
        if batch_data and not isinstance(batch_data, TransitionData):
            raise TypeError()

        tf_sess = sess if sess else tf.get_default_session()
        train_iter = self.parameters("TRAIN_ITERATION") if not train_iter else train_iter
        average_loss = 0.0

        for i in range(train_iter):
            if batch_data is None:
                train_data = self.replay_buffer.sample(batch_size=self.parameters('BATCH_SIZE'))
            else:
                train_data = batch_data

            _, target_q_val_on_new_s = self.predict_target_with_q_val(obs=train_data.new_state_set,
                                                                      batch_flag=True)
            target_q_val_on_new_s = np.expand_dims(target_q_val_on_new_s, axis=1)
            assert target_q_val_on_new_s.shape[0] == train_data.state_set.shape[0]
            feed_dict = {
                self.reward_input: np.reshape(train_data.reward_set, [-1, 1]),
                self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
                self.state_input: train_data.state_set,
                self.done_input: np.reshape(train_data.done_set, [-1, 1]),
                self.target_q_input: target_q_val_on_new_s,
                **self.parameters.return_tf_parameter_feed_dict()
            }
            res, _ = tf_sess.run([self.q_value_func_loss, self.update_q_value_func_op],
                                 feed_dict=feed_dict)
            average_loss += res

        average_loss /= train_iter
        if update_target is True and self.get_status()['train_counter'] % self.update_target_q_every_train == 0:
            tf_sess.run(self.update_target_q_value_func_op,
                        feed_dict=self.parameters.return_tf_parameter_feed_dict())
        return dict(average_loss=average_loss)

    @register_counter_info_to_status_decorator(increment=1, info_key='test_counter', under_status='TEST')
    def test(self, *arg, **kwargs):
        return super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1, info_key='predict_counter')
    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        """Greedy action(s) for ``obs``: an ``int``, or a list of ints when ``batch_flag``."""
        if batch_flag:
            action, q_val = self._predict_batch_action(obs=obs,
                                                       q_value_tensor=self.q_value_func.q_tensor,
                                                       action_ph=self.action_input,
                                                       state_ph=self.state_input,
                                                       sess=sess)
            # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the same dtype.
            return action.astype(int).tolist()
        action, q_val = self._predict_action(obs=obs,
                                             q_value_tensor=self.q_value_func.q_tensor,
                                             action_ph=self.action_input,
                                             state_ph=self.state_input,
                                             sess=sess)
        return int(np.squeeze(action))

    def predict_target_with_q_val(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        """Greedy action(s) and Q-value(s) under the target network."""
        if batch_flag:
            action, q_val = self._predict_batch_action(obs=obs,
                                                       q_value_tensor=self.target_q_value_func.q_tensor,
                                                       action_ph=self.target_q_value_func.action_input,
                                                       state_ph=self.target_q_value_func.state_input,
                                                       sess=sess)
        else:
            action, q_val = self._predict_action(obs=obs,
                                                 q_value_tensor=self.target_q_value_func.q_tensor,
                                                 action_ph=self.target_q_value_func.action_input,
                                                 state_ph=self.target_q_value_func.state_input,
                                                 sess=sess)
        return action, q_val

    # Store Transition
    @register_counter_info_to_status_decorator(increment=1, info_key='append_to_memory')
    def append_to_memory(self, samples: TransitionData):
        """Append ``samples`` to the replay buffer and update the stored-sample counter."""
        self.replay_buffer.append_batch(obs0=samples.state_set,
                                        obs1=samples.new_state_set,
                                        action=samples.action_set,
                                        reward=samples.reward_set,
                                        terminal1=samples.done_set)
        self._status.update_info(info_key='replay_buffer_data_total_count', increment=len(samples))

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self, save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        MultiPlaceholderInput.load(self, path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _predict_action(self, obs: np.ndarray, q_value_tensor: tf.Tensor, action_ph: tf.Tensor,
                        state_ph: tf.Tensor, sess=None):
        """Evaluate Q(obs, a) for every one-hot action ``a`` and return (argmax, max) along axis 0.

        :raises StateOrActionOutOfBoundError: if ``obs`` is outside the observation space.
        """
        if self.env_spec.obs_space.contains(obs) is False:
            raise StateOrActionOutOfBoundError("obs {} out of bound {}".format(obs, self.env_spec.obs_space.bound()))
        # One copy of the observation per candidate action.
        obs = repeat_ndarray(obs, repeats=self.env_spec.flat_action_dim)
        tf_sess = sess if sess else tf.get_default_session()
        feed_dict = {action_ph: generate_n_actions_hot_code(n=self.env_spec.flat_action_dim),
                     state_ph: obs,
                     **self.parameters.return_tf_parameter_feed_dict()}
        res = tf_sess.run([q_value_tensor], feed_dict=feed_dict)[0]
        return np.argmax(res, axis=0), np.max(res, axis=0)

    def _predict_batch_action(self, obs: np.ndarray, q_value_tensor: tf.Tensor, action_ph: tf.Tensor,
                              state_ph: tf.Tensor, sess=None):
        """Greedy action and its Q-value for each observation in ``obs``.

        :return: tuple ``(actions, q_values)`` of 1-D arrays, one entry per observation.
        """
        actions = []
        q_values = []
        for obs_i in obs:
            action, q_val = self._predict_action(obs=obs_i,
                                                 q_value_tensor=q_value_tensor,
                                                 action_ph=action_ph,
                                                 state_ph=state_ph,
                                                 sess=sess)
            # ``_predict_action`` already returns the arg-max index (a size-1 array
            # when the Q tensor has a trailing unit axis); taking argmax of it again
            # always produced 0. Assumes one Q-value per state-action pair.
            actions.append(np.squeeze(action))
            q_values.append(np.squeeze(q_val))
        return np.array(actions), np.array(q_values)

    def _set_up_loss(self):
        """Summed squared TD error plus regularisation; returns (loss, optimizer, optimize_op)."""
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.q_value_func.name_scope)
        loss = tf.reduce_sum((self.predict_q_value - self.q_value_func.q_tensor) ** 2)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('LEARNING_RATE'))
        optimize_op = optimizer.minimize(loss=loss, var_list=self.q_value_func.parameters('tf_var_list'))
        return loss, optimizer, optimize_op

    # update target net
    def _set_up_target_update(self):
        """Soft update ops: target <- DECAY * target + (1 - DECAY) * online."""
        op = []
        for var, target_var in zip(self.q_value_func.parameters('tf_var_list'),
                                   self.target_q_value_func.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))
        return op

    def _evaluate_td_error(self, sess=None):
        """Not implemented; returns None. Sketch of the intended computation below."""
        # tf_sess = sess if sess else tf.get_default_session()
        # feed_dict = {
        #     self.reward_input: train_data.reward_set,
        #     self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
        #     self.state_input: train_data.state_set,
        #     self.done_input: train_data.done_set,
        #     self.target_q_input: target_q_val_on_new_s,
        #     **self.parameters.return_tf_parameter_feed_dict()
        # }
        # td_loss = tf_sess.run([self.td_error], feed_dict=feed_dict)
        pass
class DDPG(ModelFreeAlgo, OffPolicyAlgo, MultiPlaceholderInput):
    """Deep Deterministic Policy Gradient.

    Owns an actor (deterministic policy), a critic (Q function), a target copy of each,
    a replay buffer, and the TF ops for the critic TD loss, the actor loss and the soft
    target-network updates.
    """
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_DDPG_REQUIRED_KEY_LIST)

    @typechecked
    def __init__(self,
                 env_spec: EnvSpec,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: MLPQValueFunction,
                 policy: DeterministicMLPPolicy,
                 schedule_param_list=None,
                 name='ddpg',
                 replay_buffer=None):
        """
        :param env_spec: environment specifications, like action space or observation space
        :param config_or_config_dict: configuration dictionary, like learning rate or decay, if any
        :param value_func: value function (critic)
        :param policy: agent policy (actor)
        :param schedule_param_list: schedule parameter list, if any; forwarded to the parameters
            object as ``to_scheduler_param_tuple``
        :param name: name of algorithm class instance
        :param replay_buffer: replay buffer instance, if any; when omitted a
            ``UniformRandomReplayBuffer`` sized by ``REPLAY_BUFFER_SIZE`` is created
        :raises TypeError: if ``replay_buffer`` is given but is not a ``BaseReplayBuffer`` instance
        """
        ModelFreeAlgo.__init__(self, env_spec=env_spec, name=name)
        config = construct_dict_config(config_or_config_dict, self)

        self.config = config
        self.actor = policy
        self.target_actor = self.actor.make_copy(name_scope='{}_target_actor'.format(self.name),
                                                 name='{}_target_actor'.format(self.name),
                                                 reuse=False)
        self.critic = value_func
        self.target_critic = self.critic.make_copy(name_scope='{}_target_critic'.format(self.name),
                                                   name='{}_target_critic'.format(self.name),
                                                   reuse=False)

        self.state_input = self.actor.state_input

        if replay_buffer is not None:
            # The buffer is used as an object (`sample` / `append_batch`), so it must be an
            # instance. The former `issubclass` check raised TypeError for every instance,
            # and the truthiness test silently replaced a supplied-but-empty buffer.
            if not isinstance(replay_buffer, BaseReplayBuffer):
                raise TypeError('replay_buffer should be an instance of BaseReplayBuffer, got {}'.format(
                    type(replay_buffer).__name__))
            self.replay_buffer = replay_buffer
        else:
            self.replay_buffer = UniformRandomReplayBuffer(limit=self.config('REPLAY_BUFFER_SIZE'),
                                                           action_shape=self.env_spec.action_shape,
                                                           observation_shape=self.env_spec.obs_shape)

        # self.parameters contains all the parameters (variables) of the algorithm
        self.parameters = ParametersWithTensorflowVariable(tf_var_list=[],
                                                           rest_parameters=dict(),
                                                           to_scheduler_param_tuple=schedule_param_list,
                                                           name='ddpg_param',
                                                           source_config=config,
                                                           require_snapshot=False)
        # Critic copy (built with reuse=True) fed with the actor's action tensor; its Q
        # output defines the actor loss.
        self._critic_with_actor_output = self.critic.make_copy(reuse=True,
                                                               name='actor_input_{}'.format(self.critic.name),
                                                               state_input=self.state_input,
                                                               action_input=self.actor.action_tensor)
        # Target-critic copy fed with the target actor's action; its Q output is the
        # bootstrap target used by the critic update.
        self._target_critic_with_target_actor_output = self.target_critic.make_copy(
            reuse=True,
            name='target_critic_with_target_actor_output_{}'.format(self.critic.name),
            action_input=self.target_actor.action_tensor)

        with tf.variable_scope(name):
            self.reward_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            self.next_state_input = tf.placeholder(shape=[None, self.env_spec.flat_obs_dim], dtype=tf.float32)
            self.done_input = tf.placeholder(shape=[None, 1], dtype=tf.bool)
            self.target_q_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            done = tf.cast(self.done_input, dtype=tf.float32)
            # TD target: r + GAMMA * Q_target, with the bootstrap term zeroed on terminal steps.
            self.predict_q_value = (1. - done) * self.config('GAMMA') * self.target_q_input + self.reward_input
            with tf.variable_scope('train'):
                self.critic_loss, self.critic_update_op, self.target_critic_update_op, self.critic_optimizer, \
                    self.critic_grads = self._setup_critic_loss()
                self.actor_loss, self.actor_update_op, self.target_actor_update_op, self.action_optimizer, \
                    self.actor_grads = self._set_up_actor_loss()

        var_list = get_tf_collection_var_list(
            '{}/train'.format(name)) + self.critic_optimizer.variables() + self.action_optimizer.variables()
        self.parameters.set_tf_var_list(tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))

        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[dict(obj=self.target_actor,
                                                                        attr_name='target_actor', ),
                                                                   dict(obj=self.actor,
                                                                        attr_name='actor'),
                                                                   dict(obj=self.critic,
                                                                        attr_name='critic'),
                                                                   dict(obj=self.target_critic,
                                                                        attr_name='target_critic')
                                                                   ],
                                       parameters=self.parameters)

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='INITED')
    def init(self, sess=None, source_obj=None):
        """Initialise networks and parameters, optionally copying state from ``source_obj``."""
        self.actor.init()
        self.critic.init()
        # NOTE(review): the target critic is initialised from the critic but the target
        # actor is not initialised from the actor — confirm this asymmetry is intended.
        self.target_actor.init()
        self.target_critic.init(source_obj=self.critic)
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train', under_status='TRAIN')
    def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
        """Run ``train_iter`` critic+actor updates.

        :param batch_data: fixed training batch; when None, each iteration samples
            ``BATCH_SIZE`` transitions from the replay buffer. Trajectory data is converted
            to shuffled transition data first.
        :param train_iter: number of iterations; defaults to ``TRAIN_ITERATION``
        :param sess: TF session; defaults to the default session
        :param update_target: run the soft target-network updates after every iteration
        :return: dict with ``average_actor_loss`` and ``average_critic_loss``
        """
        super(DDPG, self).train()
        if isinstance(batch_data, TrajectoryData):
            batch_data = batch_data.return_as_transition_data(shuffle_flag=True)
        tf_sess = sess if sess else tf.get_default_session()
        train_iter = self.parameters("TRAIN_ITERATION") if not train_iter else train_iter
        average_critic_loss = 0.0
        average_actor_loss = 0.0
        for i in range(train_iter):
            train_batch = self.replay_buffer.sample(
                batch_size=self.parameters('BATCH_SIZE')) if batch_data is None else batch_data
            assert isinstance(train_batch, TransitionData)

            critic_loss, _ = self._critic_train(train_batch, tf_sess)

            actor_loss, _ = self._actor_train(train_batch, tf_sess)

            average_actor_loss += actor_loss
            average_critic_loss += critic_loss
            if update_target:
                tf_sess.run([self.target_actor_update_op, self.target_critic_update_op])
        return dict(average_actor_loss=average_actor_loss / train_iter,
                    average_critic_loss=average_critic_loss / train_iter)

    def _critic_train(self, batch_data, sess) -> ():
        """One critic step: compute target Q on next states, then minimise the TD loss.

        :return: (critic loss, critic gradients)
        """
        target_q = sess.run(
            self._target_critic_with_target_actor_output.q_tensor,
            feed_dict={
                self._target_critic_with_target_actor_output.state_input: batch_data.new_state_set,
                self.target_actor.state_input: batch_data.new_state_set
            }
        )
        loss, _, grads = sess.run(
            [self.critic_loss, self.critic_update_op, self.critic_grads
             ],
            feed_dict={
                self.target_q_input: target_q,
                self.critic.state_input: batch_data.state_set,
                self.critic.action_input: batch_data.action_set,
                self.done_input: np.reshape(batch_data.done_set, [-1, 1]),
                self.reward_input: np.reshape(batch_data.reward_set, [-1, 1]),
                **self.parameters.return_tf_parameter_feed_dict()
            }
        )
        return loss, grads

    def _actor_train(self, batch_data, sess) -> ():
        """One actor step on the batch states.

        :return: (actor loss, actor gradients)
        """
        target_q, loss, _, grads = sess.run(
            [self._critic_with_actor_output.q_tensor, self.actor_loss, self.actor_update_op, self.actor_grads],
            feed_dict={
                self.actor.state_input: batch_data.state_set,
                self._critic_with_actor_output.state_input: batch_data.state_set,
                **self.parameters.return_tf_parameter_feed_dict()
            }
        )
        return loss, grads

    @register_counter_info_to_status_decorator(increment=1, info_key='test', under_status='TEST')
    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        """Return the actor's action for ``obs`` (batched to the observation shape)."""
        tf_sess = sess if sess else tf.get_default_session()
        feed_dict = {
            self.state_input: make_batch(obs, original_shape=self.env_spec.obs_shape),
            **self.parameters.return_tf_parameter_feed_dict()
        }
        return self.actor.forward(obs=obs, sess=tf_sess, feed_dict=feed_dict)

    def append_to_memory(self, samples: TransitionData):
        """Push a batch of transitions into the replay buffer."""
        self.replay_buffer.append_batch(obs0=samples.state_set,
                                        obs1=samples.new_state_set,
                                        action=samples.action_set,
                                        reward=samples.reward_set,
                                        terminal1=samples.done_set)

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Checkpoint all placeholder inputs; returns a dict describing the checkpoint."""
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self, save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        """Restore all placeholder inputs; returns a dict describing the checkpoint."""
        MultiPlaceholderInput.load(self, path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _setup_critic_loss(self):
        """Build the critic TD loss, its (optionally clipped) Adam update and target update ops.

        :return: (loss, optimize op, soft target-update ops, optimizer, gradients)
        """
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.critic.name_scope)
        loss = tf.reduce_sum((self.predict_q_value - self.critic.q_tensor) ** 2)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('CRITIC_LEARNING_RATE'))
        grad_var_pair = optimizer.compute_gradients(loss=loss, var_list=self.critic.parameters('tf_var_list'))
        grads = [g[0] for g in grad_var_pair]
        if self.parameters('critic_clip_norm') is not None:
            grad_var_pair, grads = clip_grad(optimizer=optimizer,
                                             loss=loss,
                                             var_list=self.critic.parameters('tf_var_list'),
                                             clip_norm=self.parameters('critic_clip_norm'))
        optimize_op = optimizer.apply_gradients(grad_var_pair)
        # Soft update: target <- DECAY * target + (1 - DECAY) * online.
        op = []
        for var, target_var in zip(self.critic.parameters('tf_var_list'),
                                   self.target_critic.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))

        return loss, optimize_op, op, optimizer, grads

    def _set_up_actor_loss(self):
        """Build the actor loss -mean(Q(s, actor(s))), its update and target update ops.

        :return: (loss, optimize op, soft target-update ops, optimizer, gradients)
        """
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.actor.name_scope)

        loss = -tf.reduce_mean(self._critic_with_actor_output.q_tensor)

        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)
        # NOTE(review): the actor optimiser reads CRITIC_LEARNING_RATE; if the config has a
        # separate actor learning rate, this is likely a copy-paste slip — confirm.
        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('CRITIC_LEARNING_RATE'))
        grad_var_pair = optimizer.compute_gradients(loss=loss, var_list=self.actor.parameters('tf_var_list'))
        grads = [g[0] for g in grad_var_pair]
        if self.parameters('actor_clip_norm') is not None:
            # Clip with the actor's own norm; previously this used critic_clip_norm.
            grad_var_pair, grads = clip_grad(optimizer=optimizer,
                                             loss=loss,
                                             var_list=self.actor.parameters('tf_var_list'),
                                             clip_norm=self.parameters('actor_clip_norm'))
        optimize_op = optimizer.apply_gradients(grad_var_pair)
        # Soft update: target <- DECAY * target + (1 - DECAY) * online.
        op = []
        for var, target_var in zip(self.actor.parameters('tf_var_list'),
                                   self.target_actor.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))

        return loss, optimize_op, op, optimizer, grads
class ModelPredictiveControl(ModelBasedAlgo):
    """Sampling-based model predictive control on top of a learned dynamics model.

    At prediction time it rolls out ``SAMPLED_PATH_NUM`` paths of length
    ``SAMPLED_HORIZON`` in the dynamics environment using ``policy``, then returns the
    first action of the path with the highest cumulative reward.
    """
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST)

    def __init__(
            self,
            env_spec,
            dynamics_model: DynamicsModel,
            config_or_config_dict: (DictConfig, dict),
            policy: Policy,
            name='mpc',
    ):
        """
        :param env_spec: environment specification
        :param dynamics_model: model used both for training and for simulated rollouts
        :param config_or_config_dict: configuration (dict or DictConfig)
        :param policy: policy that proposes actions during the sampled rollouts
        :param name: name of this algorithm instance
        """
        super().__init__(env_spec, dynamics_model, name)
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = policy
        self.parameters = Parameters(parameters=dict(),
                                     source_config=self.config,
                                     name=name + '_' + 'mpc_param')
        # Collected real transitions; used as the default training batch in `train`.
        self.memory = TransitionData(env_spec=env_spec)

    def init(self, source_obj=None):
        """Initialise the algorithm, its parameters, model and policy; optionally copy state."""
        super().init()
        self.parameters.init()
        self._dynamics_model.init()
        self.policy.init()
        if source_obj:
            self.copy_from(source_obj)

    def train(self, *arg, **kwargs) -> dict:
        """Fit the dynamics model on ``batch_data`` (defaults to the stored memory).

        :return: training results, each key prefixed with ``mlp_dynamics_``
        """
        super(ModelPredictiveControl, self).train()
        res_dict = {}
        batch_data = kwargs['batch_data'] if 'batch_data' in kwargs else self.memory

        dynamics_train_res_dict = self._fit_dynamics_model(
            batch_data=batch_data,
            train_iter=self.parameters('dynamics_model_train_iter'))
        for key, val in dynamics_train_res_dict.items():
            res_dict["mlp_dynamics_{}".format(key)] = val
        return res_dict

    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def _fit_dynamics_model(self, batch_data: TransitionData, train_iter, sess=None) -> dict:
        res_dict = self._dynamics_model.train(batch_data, **dict(sess=sess, train_iter=train_iter))
        return res_dict

    def predict(self, obs, **kwargs):
        """Pick an action for ``obs``.

        While training, a random action is sampled from the action space. Otherwise the
        sampled rollouts are ranked by cumulative reward and the first action of the best
        one is returned.
        """
        if self.is_training is True:
            return self.env_spec.action_space.sample()
        rollout = TrajectoryData(env_spec=self.env_spec)
        for _ in range(self.parameters('SAMPLED_PATH_NUM')):
            path = TransitionData(env_spec=self.env_spec)
            # Every candidate path must start from the real current observation. Previously
            # `state` was set once before this loop, so each later path continued from where
            # the previous one ended, while its first action was still used for `obs`.
            state = obs
            # todo terminal_func signal problem to be consider?
            for _ in range(self.parameters('SAMPLED_HORIZON')):
                ac = self.policy.forward(obs=state)
                new_state, re, done, _ = self.dynamics_env.step(action=ac, state=state)
                path.append(state=state, action=ac, new_state=new_state, reward=re, done=done)
                state = new_state
            rollout.append(path)
        rollout.trajectories.sort(key=lambda x: x.cumulative_reward, reverse=True)
        ac = rollout.trajectories[0].action_set[0]
        assert self.env_spec.action_space.contains(ac)
        return ac

    def append_to_memory(self, samples: TransitionData):
        """Merge ``samples`` into the stored transition memory."""
        self.memory.union(samples)

    def copy_from(self, obj) -> bool:
        """Copy parameters and dynamics model from another instance of the same type.

        :raises TypeError: if ``obj`` is not an instance of this class
        """
        if not isinstance(obj, type(self)):
            raise TypeError('Wrong type of obj %s to be copied, which should be %s' % (type(obj), type(self)))
        self.parameters.copy_from(obj.parameters)
        self._dynamics_model.copy_from(obj._dynamics_model)
        ConsoleLogger().print('info', 'model: {} copied from {}'.format(self, obj))
        return True

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Save the dynamics model and the policy; returns a dict describing the checkpoint."""
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name

        self._dynamics_model.save(save_path=save_path, global_step=global_step, name=name, **kwargs)
        self.policy.save(save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        """Load the dynamics model and the policy; returns a dict describing the checkpoint."""
        self._dynamics_model.load(path_to_model, model_name, global_step, **kwargs)
        self.policy.load(path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)
class MEPPO(ModelBasedAlgo):
    """
    Model Ensemble, Proximal Policy Optimisation

    Alternates between fitting a model ensemble and training a model-free algorithm.
    """
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_ALGO_DYNA_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self, env_spec, dynamics_model: ModelEnsemble,
                 model_free_algo: ModelFreeAlgo,
                 config_or_config_dict: (DictConfig, dict),
                 name='sample_with_dynamics'
                 ):
        """
        :param env_spec: environment specification
        :param dynamics_model: ensemble whose first member must be a ContinuousMLPGlobalDynamicsModel
        :param model_free_algo: algorithm trained on the collected data
        :param config_or_config_dict: configuration (dict or DictConfig)
        :param name: name of this algorithm instance
        :raises TypeError: if the first ensemble member has the wrong type
        """
        if not isinstance(dynamics_model.model[0], ContinuousMLPGlobalDynamicsModel):
            raise TypeError("Model ensemble elements should be of type ContinuousMLPGlobalDynamicsModel")

        super().__init__(env_spec, dynamics_model, name)
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(),
                                name='dyna_param',
                                source_config=config)
        # NOTE(review): built but not used anywhere in this class.
        sub_placeholder_input_list = []
        if isinstance(dynamics_model, PlaceholderInput):
            sub_placeholder_input_list.append(dict(obj=dynamics_model, attr_name='dynamics_model'))
        if isinstance(model_free_algo, PlaceholderInput):
            sub_placeholder_input_list.append(dict(obj=model_free_algo, attr_name='model_free_algo'))
        self.model_free_algo = model_free_algo
        self.config = config
        self.parameters = parameters
        # Per-member rewards from the last `validate` call (empty until the first call).
        self.result = list()
        self.validation_result = [0] * len(dynamics_model)
        self._dynamics_model.__class__ = ModelEnsemble

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='JUST_INITED')
    def init(self):
        self.parameters.init()
        self.model_free_algo.init()
        self.dynamics_env.init()
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train_counter', under_status='TRAIN')
    def train(self, *args, **kwargs) -> dict:
        """Train the ensemble and/or the model-free algorithm.

        :param kwargs: ``batch_data`` (optional) and ``state``; ``state`` selects
            'state_dynamics_training' or 'state_agent_training'; when absent both run.
        :return: results keyed by ``<component name>_<result key>``
        """
        super(MEPPO, self).train()
        res_dict = {}
        batch_data = kwargs['batch_data'] if 'batch_data' in kwargs else None
        if 'state' in kwargs:
            assert kwargs['state'] in ('state_dynamics_training', 'state_agent_training')
            state = kwargs['state']
            kwargs.pop('state')
        else:
            state = None

        if not state or state == 'state_dynamics_training':
            dynamics_train_res_dict = self._fit_dynamics_model(batch_data=batch_data,
                                                               train_iter=self.parameters('dynamics_model_train_iter'))
            for key, val in dynamics_train_res_dict.items():
                res_dict["{}_{}".format(self._dynamics_model.name, key)] = val
        if not state or state == 'state_agent_training':
            model_free_algo_train_res_dict = self._train_model_free_algo(batch_data=batch_data,
                                                                         train_iter=self.parameters(
                                                                             'model_free_algo_train_iter'))

            for key, val in model_free_algo_train_res_dict.items():
                res_dict['{}_{}'.format(self.model_free_algo.name, key)] = val
        return res_dict

    @register_counter_info_to_status_decorator(increment=1, info_key='test_counter', under_status='TEST')
    def test(self, *arg, **kwargs):
        return super().test(*arg, **kwargs)

    def validate(self, *args, **kwargs):
        """Fraction of ensemble members whose one-step reward increased since the last call.

        Each member is wrapped as an environment via ``return_as_env`` and stepped once with
        the given arguments. Its reward is compared with the reward it produced on the
        previous call; on the first call nothing is compared, so the result is 0.

        :return: fraction of members that improved (also stored in ``validation_result``)
        """
        # Snapshot the previous rewards: the old code aliased `self.result`, indexed it
        # while it was still empty (IndexError) and compared each reward with itself.
        old_result = list(self.result)
        new_result = []
        improved = 0
        for a in range(len(self._dynamics_model)):
            individual_model = self._dynamics_model.model[a]
            env = individual_model.return_as_env()
            # Forward the caller's arguments; the old code passed `self` as the action and
            # required the info return value to be empty.
            _, reward, _, _ = env.step(*args, **kwargs)
            new_result.append(reward)
            if a < len(old_result) and reward > old_result[a]:
                improved += 1
        self.result = new_result
        self.validation_result = improved / len(self._dynamics_model)
        return self.validation_result

    @register_counter_info_to_status_decorator(increment=1, info_key='predict_counter')
    def predict(self, obs, **kwargs):
        return self.model_free_algo.predict(obs)

    def append_to_memory(self, *args, **kwargs):
        """Forward ``samples`` (keyword or first positional argument) to the model-free algorithm.

        :raises KeyError: if no samples are provided
        """
        if 'samples' in kwargs:
            samples = kwargs['samples']
        elif args:
            samples = args[0]
        else:
            raise KeyError('samples')
        self.model_free_algo.append_to_memory(samples)

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        """Save the model-free algorithm and dynamics env into sub-directories named after them."""
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name

        self.model_free_algo.save(global_step=global_step,
                                  name=None,
                                  save_path=os.path.join(save_path, self.model_free_algo.name))
        self.dynamics_env.save(global_step=global_step,
                               name=None,
                               save_path=os.path.join(save_path, self.dynamics_env.name))
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        """Load the model-free algorithm and dynamics env from their sub-directories."""
        self.model_free_algo.load(path_to_model=os.path.join(path_to_model, self.model_free_algo.name),
                                  model_name=self.model_free_algo.name,
                                  global_step=global_step)
        self.dynamics_env.load(global_step=global_step,
                               path_to_model=os.path.join(path_to_model, self.dynamics_env.name),
                               model_name=self.dynamics_env.name)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    @register_counter_info_to_status_decorator(increment=1, info_key='dyanmics_train_counter', under_status='TRAIN')
    def _fit_dynamics_model(self, batch_data: TransitionData, train_iter, sess=None) -> dict:
        res_dict = self._dynamics_model.train(batch_data, **dict(sess=sess, train_iter=train_iter))
        return res_dict

    @register_counter_info_to_status_decorator(increment=1, info_key='mode_free_algo_dyanmics_train_counter',
                                               under_status='TRAIN')
    def _train_model_free_algo(self, batch_data=None, train_iter=None, sess=None):
        res_dict = self.model_free_algo.train(**dict(batch_data=batch_data,
                                                     train_iter=train_iter,
                                                     sess=sess))
        return res_dict