Example #1
class ConstantActionPolicy(DeterministicPolicy):
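    """
    A deterministic policy that always returns the constant action configured by
    ACTION_VALUE, regardless of the arguments passed to forward().
    """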
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_CONSTANT_ACTION_POLICY_REQUIRED_KEY_LIST)

    def __init__(self, env_spec: EnvSpec, config_or_config_dict: (DictConfig, dict), name='policy'):
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(),
                                source_config=config)
        assert env_spec.action_space.contains(x=config('ACTION_VALUE'))
        super().__init__(env_spec, parameters, name)
        self.config = config

    def forward(self, *args, **kwargs):
        action = self.env_spec.action_space.unflatten(self.parameters('ACTION_VALUE'))
        assert self.env_spec.action_space.contains(x=action)
        return action

    def copy_from(self, obj) -> bool:
        super().copy_from(obj)
        self.parameters.copy_from(obj.parameters)
        return True

    def make_copy(self, *args, **kwargs):
        return ConstantActionPolicy(env_spec=self.env_spec,
                                    config_or_config_dict=deepcopy(self.config),
                                    *args, **kwargs)
Example #2
def construct_dict_config(config_or_config_dict, obj):
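    """
    Normalize config_or_config_dict: wrap a plain dict into a DictConfig built from
    obj.required_key_dict, or return an existing DictConfig unchanged.
    """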
    if isinstance(config_or_config_dict, dict):
        return DictConfig(required_key_dict=obj.required_key_dict,
                          config_dict=config_or_config_dict,
                          cls_name=type(obj).__name__)
    elif isinstance(config_or_config_dict, DictConfig):
        return config_or_config_dict
    else:
        raise TypeError('Type {} is not supported, use dict or DictConfig'.format(
            type(config_or_config_dict).__name__))
Example #3
    def __init__(self, parameters: dict,
                 source_config: DictConfig = None,
                 name='parameters',
                 to_scheduler_param_tuple: tuple = None,
                 default_save_param_key=None):
        self._parameters = parameters
        self.name = name
        self._source_config = source_config if source_config else DictConfig(
            required_key_dict=dict(), config_dict=dict())
        self.default_save_param_key = default_save_param_key
        self._scheduler_info_dict = dict()
        self.to_scheduler_param_list = to_scheduler_param_tuple
Example #4
class Dyna(ModelBasedAlgo):
    """
    Dyna algorithm (Sutton, R. S., 1991).
    You can replace the dynamics model with any dynamics model you want.
    """
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_ALGO_DYNA_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self,
                 env_spec,
                 dynamics_model: DynamicsModel,
                 model_free_algo: ModelFreeAlgo,
                 config_or_config_dict: (DictConfig, dict),
                 name='sample_with_dynamics'):
        super().__init__(env_spec, dynamics_model, name)
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(),
                                name='dyna_param',
                                source_config=config)
        sub_placeholder_input_list = []
        if isinstance(dynamics_model, PlaceholderInput):
            sub_placeholder_input_list.append(
                dict(obj=dynamics_model, attr_name='dynamics_model'))
        if isinstance(model_free_algo, PlaceholderInput):
            sub_placeholder_input_list.append(
                dict(obj=model_free_algo, attr_name='model_free_algo'))
        self.model_free_algo = model_free_algo
        self.config = config
        self.parameters = parameters

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='init',
                                               under_status='JUST_INITED')
    def init(self):
        self.parameters.init()
        self.model_free_algo.init()
        self.dynamics_env.init()
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='train_counter',
                                               under_status='TRAIN')
    def train(self, *args, **kwargs) -> dict:
        super(Dyna, self).train()
        res_dict = {}
        batch_data = kwargs['batch_data'] if 'batch_data' in kwargs else None
        if 'state' in kwargs:
            assert kwargs['state'] in ('state_dynamics_training',
                                       'state_agent_training')
            state = kwargs['state']
            kwargs.pop('state')
        else:
            state = None

        if not state or state == 'state_dynamics_training':

            dynamics_train_res_dict = self._fit_dynamics_model(
                batch_data=batch_data,
                train_iter=self.parameters('dynamics_model_train_iter'))
            for key, val in dynamics_train_res_dict.items():
                res_dict["{}_{}".format(self._dynamics_model.name, key)] = val
        if not state or state == 'state_agent_training':
            model_free_algo_train_res_dict = self._train_model_free_algo(
                batch_data=batch_data,
                train_iter=self.parameters('model_free_algo_train_iter'))

            for key, val in model_free_algo_train_res_dict.items():
                res_dict['{}_{}'.format(self.model_free_algo.name, key)] = val
        return res_dict

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='test_counter',
                                               under_status='TEST')
    def test(self, *arg, **kwargs):
        super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='predict_counter')
    def predict(self, obs, **kwargs):
        return self.model_free_algo.predict(obs)

    def append_to_memory(self, *args, **kwargs):
        self.model_free_algo.append_to_memory(kwargs['samples'])

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig(
        ).DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        self.model_free_algo.save(global_step=global_step,
                                  name=None,
                                  save_path=os.path.join(
                                      save_path, self.model_free_algo.name))
        self.dynamics_env.save(global_step=global_step,
                               name=None,
                               save_path=os.path.join(save_path,
                                                      self.dynamics_env.name))
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        self.model_free_algo.load(path_to_model=os.path.join(
            path_to_model, self.model_free_algo.name),
                                  model_name=self.model_free_algo.name,
                                  global_step=global_step)
        self.dynamics_env.load(global_step=global_step,
                               path_to_model=os.path.join(
                                   path_to_model, self.dynamics_env.name),
                               model_name=self.dynamics_env.name)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    @register_counter_info_to_status_decorator(
        increment=1, info_key='dynamics_train_counter', under_status='TRAIN')
    def _fit_dynamics_model(self,
                            batch_data: TransitionData,
                            train_iter,
                            sess=None) -> dict:
        res_dict = self._dynamics_model.train(
            batch_data, **dict(sess=sess, train_iter=train_iter))
        return res_dict

    @register_counter_info_to_status_decorator(
        increment=1,
        info_key='model_free_algo_dynamics_train_counter',
        under_status='TRAIN')
    def _train_model_free_algo(self,
                               batch_data=None,
                               train_iter=None,
                               sess=None):
        res_dict = self.model_free_algo.train(
            **dict(batch_data=batch_data, train_iter=train_iter, sess=sess))
        return res_dict
Example #5
class PPO(ModelFreeAlgo, OnPolicyAlgo, MultiPlaceholderInput):
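    """
    Proximal Policy Optimization (PPO).
    Uses either a clipped-surrogate objective (when 'clipping_range' is set) or an
    adaptive-KL-penalty objective, together with a separately trained V-value function.
    """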
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_PPO_REQUIRED_KEY_LIST)

    @typechecked
    def __init__(self,
                 env_spec: EnvSpec,
                 stochastic_policy: StochasticPolicy,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: VValueFunction,
                 warm_up_trajectories_number=5,
                 use_time_index_flag=False,
                 name='ppo'):
        ModelFreeAlgo.__init__(
            self,
            env_spec=env_spec,
            name=name,
            warm_up_trajectories_number=warm_up_trajectories_number)
        self.use_time_index_flag = use_time_index_flag
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = stochastic_policy
        self.value_func = value_func
        to_ph_parameter_dict = dict()
        self.trajectory_memory = TrajectoryData(env_spec=env_spec)
        self.transition_data_for_trajectory = TransitionData(env_spec=env_spec)
        self.value_func_train_data_buffer = None
        self.scaler = RunningStandardScaler(dims=self.env_spec.flat_obs_dim)
        if use_time_index_flag:
            scale_last_time_index_mean = self.scaler._mean
            scale_last_time_index_mean[-1] = 0
            scale_last_time_index_var = self.scaler._var
            scale_last_time_index_var[-1] = 1000 * 1000
            self.scaler.set_param(mean=scale_last_time_index_mean,
                                  var=scale_last_time_index_var)
        with tf.variable_scope(name):
            self.advantages_ph = tf.placeholder(tf.float32, (None, ),
                                                'advantages')
            self.v_func_val_ph = tf.placeholder(tf.float32, (None, ),
                                                'val_val_func')
            dist_info_list = self.policy.get_dist_info()
            self.old_dist_tensor = [
                (tf.placeholder(**dict(dtype=dist_info['dtype'],
                                       shape=dist_info['shape'],
                                       name=dist_info['name'])),
                 dist_info['name']) for dist_info in dist_info_list
            ]
            self.old_policy = self.policy.make_copy(
                reuse=False,
                name_scope='old_{}'.format(self.policy.name),
                name='old_{}'.format(self.policy.name),
                distribution_tensors_tuple=tuple(self.old_dist_tensor))
            to_ph_parameter_dict['beta'] = tf.placeholder(
                tf.float32, (), 'beta')
            to_ph_parameter_dict['eta'] = tf.placeholder(tf.float32, (), 'eta')
            to_ph_parameter_dict['kl_target'] = tf.placeholder(
                tf.float32, (), 'kl_target')
            to_ph_parameter_dict['lr_multiplier'] = tf.placeholder(
                tf.float32, (), 'lr_multiplier')

        self.parameters = ParametersWithTensorflowVariable(
            tf_var_list=[],
            rest_parameters=dict(
                advantages_ph=self.advantages_ph,
                v_func_val_ph=self.v_func_val_ph,
            ),
            to_ph_parameter_dict=to_ph_parameter_dict,
            name='ppo_param',
            save_rest_param_flag=False,
            source_config=self.config,
            require_snapshot=False)
        with tf.variable_scope(name):
            with tf.variable_scope('train'):
                self.kl = tf.reduce_mean(self.old_policy.kl(self.policy))
                self.average_entropy = tf.reduce_mean(self.policy.entropy())
                self.policy_loss, self.policy_optimizer, self.policy_update_op = self._setup_policy_loss(
                )
                self.value_func_loss, self.value_func_optimizer, self.value_func_update_op = self._setup_value_func_loss(
                )
        var_list = get_tf_collection_var_list(
            '{}/train'.format(name)) + self.policy_optimizer.variables(
            ) + self.value_func_optimizer.variables()
        self.parameters.set_tf_var_list(
            tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))
        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[
                                           dict(
                                               obj=self.value_func,
                                               attr_name='value_func',
                                           ),
                                           dict(obj=self.policy,
                                                attr_name='policy')
                                       ],
                                       parameters=self.parameters)

    def warm_up(self, trajectory_data: TrajectoryData):
        for traj in trajectory_data.trajectories:
            self.scaler.update_scaler(data=traj.state_set)
        if self.use_time_index_flag:
            scale_last_time_index_mean = self.scaler._mean
            scale_last_time_index_mean[-1] = 0
            scale_last_time_index_var = self.scaler._var
            scale_last_time_index_var[-1] = 1000 * 1000
            self.scaler.set_param(mean=scale_last_time_index_mean,
                                  var=scale_last_time_index_var)

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='init',
                                               under_status='INITED')
    def init(self, sess=None, source_obj=None):
        self.policy.init()
        self.value_func.init()
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='train',
                                               under_status='TRAIN')
    def train(self,
              trajectory_data: TrajectoryData = None,
              train_iter=None,
              sess=None) -> dict:
        super(PPO, self).train()
        if trajectory_data is None:
            trajectory_data = self.trajectory_memory
        if len(trajectory_data) == 0:
            raise MemoryBufferLessThanBatchSizeError(
                'not enough trajectory data')
        for i, traj in enumerate(trajectory_data.trajectories):
            trajectory_data.trajectories[i].append_new_set(
                name='state_set',
                shape=self.env_spec.obs_shape,
                data_set=np.reshape(
                    np.array(self.scaler.process(np.array(traj.state_set))),
                    [-1] + list(self.env_spec.obs_shape)))
            trajectory_data.trajectories[i].append_new_set(
                name='new_state_set',
                shape=self.env_spec.obs_shape,
                data_set=np.reshape(
                    np.array(self.scaler.process(np.array(
                        traj.new_state_set))),
                    [-1] + list(self.env_spec.obs_shape)))

        tf_sess = sess if sess else tf.get_default_session()
        SampleProcessor.add_estimated_v_value(trajectory_data,
                                              value_func=self.value_func)
        SampleProcessor.add_discount_sum_reward(trajectory_data,
                                                gamma=self.parameters('gamma'))
        SampleProcessor.add_gae(trajectory_data,
                                gamma=self.parameters('gamma'),
                                name='advantage_set',
                                lam=self.parameters('lam'),
                                value_func=self.value_func)
        trajectory_data = SampleProcessor.normalization(trajectory_data,
                                                        key='advantage_set')
        policy_res_dict = self._update_policy(
            state_set=np.concatenate(
                [t('state_set') for t in trajectory_data.trajectories],
                axis=0),
            action_set=np.concatenate(
                [t('action_set') for t in trajectory_data.trajectories],
                axis=0),
            advantage_set=np.concatenate(
                [t('advantage_set') for t in trajectory_data.trajectories],
                axis=0),
            train_iter=train_iter
            if train_iter else self.parameters('policy_train_iter'),
            sess=tf_sess)
        value_func_res_dict = self._update_value_func(
            state_set=np.concatenate(
                [t('state_set') for t in trajectory_data.trajectories],
                axis=0),
            discount_set=np.concatenate(
                [t('discount_set') for t in trajectory_data.trajectories],
                axis=0),
            train_iter=train_iter
            if train_iter else self.parameters('value_func_train_iter'),
            sess=tf_sess)
        trajectory_data.reset()
        self.trajectory_memory.reset()
        return {**policy_res_dict, **value_func_res_dict}

    @register_counter_info_to_status_decorator(increment=1,
                                               info_key='test',
                                               under_status='TEST')
    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1, info_key='predict')
    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        tf_sess = sess if sess else tf.get_default_session()
        ac = self.policy.forward(
            obs=self.scaler.process(
                data=make_batch(obs, original_shape=self.env_spec.obs_shape)),
            sess=tf_sess,
            feed_dict=self.parameters.return_tf_parameter_feed_dict())
        return ac

    def append_to_memory(self, samples: TrajectoryData):
        # todo: how to make sure the data is time-sequential
        obs_list = samples.trajectories[0].state_set
        for i in range(1, len(samples.trajectories)):
            obs_list = np.array(
                np.concatenate([obs_list, samples.trajectories[i].state_set],
                               axis=0))
        self.trajectory_memory.union(samples)
        self.scaler.update_scaler(data=np.array(obs_list))
        if self.use_time_index_flag:
            scale_last_time_index_mean = self.scaler._mean
            scale_last_time_index_mean[-1] = 0
            scale_last_time_index_var = self.scaler._var
            scale_last_time_index_var[-1] = 1000 * 1000
            self.scaler.set_param(mean=scale_last_time_index_mean,
                                  var=scale_last_time_index_var)

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig(
        ).DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self,
                                   save_path=save_path,
                                   global_step=global_step,
                                   name=name,
                                   **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        MultiPlaceholderInput.load(self, path_to_model, model_name,
                                   global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _setup_policy_loss(self):
        """
        Code adapted from pat-cody.
        Three loss terms:
            1) standard policy gradient
            2) D_KL(pi_old || pi_new)
            3) Hinge loss on [D_KL - kl_target]^2

        See: https://arxiv.org/pdf/1707.02286.pdf
        """

        if self.parameters('clipping_range') is not None:
            pg_ratio = tf.exp(self.policy.log_prob() -
                              self.old_policy.log_prob())
            clipped_pg_ratio = tf.clip_by_value(
                pg_ratio, 1 - self.parameters('clipping_range')[0],
                1 + self.parameters('clipping_range')[1])
            surrogate_loss = tf.minimum(self.advantages_ph * pg_ratio,
                                        self.advantages_ph * clipped_pg_ratio)
            loss = -tf.reduce_mean(surrogate_loss)
        else:
            loss1 = -tf.reduce_mean(
                self.advantages_ph *
                tf.exp(self.policy.log_prob() - self.old_policy.log_prob()))
            loss2 = tf.reduce_mean(self.parameters('beta') * self.kl)
            loss3 = self.parameters('eta') * tf.square(
                tf.maximum(0.0, self.kl - 2.0 * self.parameters('kl_target')))
            loss = loss1 + loss2 + loss3
            self.loss1 = loss1
            self.loss2 = loss2
            self.loss3 = loss3
        if isinstance(self.policy, PlaceholderInput):
            reg_list = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                         scope=self.policy.name_scope)
            if len(reg_list) > 0:
                reg_loss = tf.reduce_sum(reg_list)
                loss += reg_loss

        optimizer = tf.train.AdamOptimizer(
            learning_rate=self.parameters('policy_lr') *
            self.parameters('lr_multiplier'))
        train_op = optimizer.minimize(
            loss, var_list=self.policy.parameters('tf_var_list'))
        return loss, optimizer, train_op

    def _setup_value_func_loss(self):
        # todo update the value_func design
        loss = tf.reduce_mean(
            tf.square(
                tf.squeeze(self.value_func.v_tensor) - self.v_func_val_ph))
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                     scope=self.value_func.name_scope)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)
        optimizer = tf.train.AdamOptimizer(self.parameters('value_func_lr'))
        train_op = optimizer.minimize(
            loss, var_list=self.value_func.parameters('tf_var_list'))
        return loss, optimizer, train_op

    def _update_policy(self, state_set, action_set, advantage_set, train_iter,
                       sess):
        old_policy_feed_dict = dict()
        res = sess.run(
            [
                getattr(self.policy, tensor[1])
                for tensor in self.old_dist_tensor
            ],
            feed_dict={
                self.policy.parameters('state_input'): state_set,
                self.policy.parameters('action_input'): action_set,
                **self.parameters.return_tf_parameter_feed_dict()
            })

        for tensor, val in zip(self.old_dist_tensor, res):
            old_policy_feed_dict[tensor[0]] = val

        feed_dict = {
            self.policy.parameters('action_input'): action_set,
            self.old_policy.parameters('action_input'): action_set,
            self.policy.parameters('state_input'): state_set,
            self.advantages_ph: advantage_set,
            **self.parameters.return_tf_parameter_feed_dict(),
            **old_policy_feed_dict
        }
        average_loss, average_kl, average_entropy = 0.0, 0.0, 0.0
        total_epoch = 0
        kl = None
        for i in range(train_iter):
            _ = sess.run(self.policy_update_op, feed_dict=feed_dict)
            loss, kl, entropy = sess.run(
                [self.policy_loss, self.kl, self.average_entropy],
                feed_dict=feed_dict)
            average_loss += loss
            average_kl += kl
            average_entropy += entropy
            total_epoch = i + 1
            if kl > self.parameters('kl_target', require_true_value=True) * 4:
                # early stopping if D_KL diverges badly
                break
        average_loss /= total_epoch
        average_kl /= total_epoch
        average_entropy /= total_epoch

        if kl > self.parameters('kl_target', require_true_value=True
                                ) * 2:  # servo beta to reach D_KL target
            self.parameters.set(
                key='beta',
                new_val=np.minimum(
                    35,
                    1.5 * self.parameters('beta', require_true_value=True)))
            if self.parameters(
                    'beta', require_true_value=True) > 30 and self.parameters(
                        'lr_multiplier', require_true_value=True) > 0.1:
                self.parameters.set(
                    key='lr_multiplier',
                    new_val=self.parameters('lr_multiplier',
                                            require_true_value=True) / 1.5)
        elif kl < self.parameters('kl_target', require_true_value=True) / 2:
            self.parameters.set(
                key='beta',
                new_val=np.maximum(
                    1 / 35,
                    self.parameters('beta', require_true_value=True) / 1.5))

            if self.parameters('beta', require_true_value=True) < (
                    1 / 30) and self.parameters('lr_multiplier',
                                                require_true_value=True) < 10:
                self.parameters.set(
                    key='lr_multiplier',
                    new_val=self.parameters('lr_multiplier',
                                            require_true_value=True) * 1.5)
        return dict(policy_average_loss=average_loss,
                    policy_average_kl=average_kl,
                    policy_average_entropy=average_entropy,
                    policy_total_train_epoch=total_epoch)

    def _update_value_func(self, state_set, discount_set, train_iter, sess):
        y_hat = self.value_func.forward(obs=state_set).squeeze()
        old_exp_var = 1 - np.var(discount_set - y_hat) / np.var(discount_set)

        if self.value_func_train_data_buffer is None:
            self.value_func_train_data_buffer = (state_set, discount_set)
        else:
            self.value_func_train_data_buffer = (
                np.concatenate(
                    [self.value_func_train_data_buffer[0], state_set], axis=0),
                np.concatenate(
                    [self.value_func_train_data_buffer[1], discount_set],
                    axis=0))
        if len(self.value_func_train_data_buffer[0]) > self.parameters(
                'value_func_memory_size'):
            self.value_func_train_data_buffer = tuple(
                np.array(data[-self.parameters('value_func_memory_size'):])
                for data in self.value_func_train_data_buffer)
        state_set_all, discount_set_all = self.value_func_train_data_buffer

        param_dict = self.parameters.return_tf_parameter_feed_dict()

        for i in range(train_iter):
            random_index = np.random.choice(np.arange(len(state_set_all)),
                                            len(state_set_all))
            state_set_all = state_set_all[random_index]
            discount_set_all = discount_set_all[random_index]
            for index in range(
                    0,
                    len(state_set_all) -
                    self.parameters('value_func_train_batch_size'),
                    self.parameters('value_func_train_batch_size')):
                state = np.array(
                    state_set_all[index:index + self.
                                  parameters('value_func_train_batch_size')])
                discount = discount_set_all[
                    index:index +
                    self.parameters('value_func_train_batch_size')]
                loss, _ = sess.run(
                    [self.value_func_loss, self.value_func_update_op],
                    options=tf.RunOptions(
                        report_tensor_allocations_upon_oom=True),
                    feed_dict={
                        self.value_func.state_input: state,
                        self.v_func_val_ph: discount,
                        **param_dict
                    })
        y_hat = self.value_func.forward(obs=state_set).squeeze()
        loss = np.mean(np.square(y_hat - discount_set))
        exp_var = 1 - np.var(discount_set - y_hat) / np.var(discount_set)
        return dict(value_func_loss=loss,
                    value_func_policy_exp_var=exp_var,
                    value_func_policy_old_exp_var=old_exp_var)
Example #6
class DQN(ModelFreeAlgo, OffPolicyAlgo, MultiPlaceholderInput):
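    """
    Deep Q-Network (DQN) with a separate target Q network and, by default, a uniform
    random replay buffer.
    """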
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_DQN_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self,
                 env_spec,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: MLPQValueFunction,
                 schedule_param_list=None,
                 name: str = 'dqn',
                 replay_buffer=None):
        ModelFreeAlgo.__init__(self, env_spec=env_spec, name=name)
        self.config = construct_dict_config(config_or_config_dict, self)

        if replay_buffer:
            assert isinstance(replay_buffer, BaseReplayBuffer)
            self.replay_buffer = replay_buffer
        else:
            self.replay_buffer = UniformRandomReplayBuffer(limit=self.config('REPLAY_BUFFER_SIZE'),
                                                           action_shape=self.env_spec.action_shape,
                                                           observation_shape=self.env_spec.obs_shape)
        self.q_value_func = value_func
        self.state_input = self.q_value_func.state_input
        self.action_input = self.q_value_func.action_input
        self.update_target_q_every_train = self.config('UPDATE_TARGET_Q_FREQUENCY') \
            if 'UPDATE_TARGET_Q_FREQUENCY' in self.config.config_dict else 1
        self.parameters = ParametersWithTensorflowVariable(tf_var_list=[],
                                                           rest_parameters=dict(),
                                                           to_scheduler_param_tuple=schedule_param_list,
                                                           name='{}_param'.format(name),
                                                           source_config=self.config,
                                                           require_snapshot=False)

        with tf.variable_scope(name):
            self.reward_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            self.next_state_input = tf.placeholder(shape=[None, self.env_spec.flat_obs_dim], dtype=tf.float32)
            self.done_input = tf.placeholder(shape=[None, 1], dtype=tf.bool)
            self.target_q_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            done = tf.cast(self.done_input, dtype=tf.float32)
            self.target_q_value_func = self.q_value_func.make_copy(name_scope='{}_target_q_value_net'.format(name),
                                                                   name='{}_target_q_value_net'.format(name),
                                                                   reuse=False)
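            # TD target: r + (1 - done) * GAMMA * max_a' Q_target(s', a'),
            # where the max over the target network is fed in via target_q_input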
            self.predict_q_value = (1. - done) * self.config('GAMMA') * self.target_q_input + self.reward_input
            self.td_error = self.predict_q_value - self.q_value_func.q_tensor

            with tf.variable_scope('train'):
                self.q_value_func_loss, self.optimizer, self.update_q_value_func_op = self._set_up_loss()
                self.update_target_q_value_func_op = self._set_up_target_update()

        # redundant sort operation on var_list
        var_list = get_tf_collection_var_list(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='{}/train'.format(name)) + self.optimizer.variables()
        self.parameters.set_tf_var_list(tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))

        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[dict(obj=self.q_value_func, attr_name='q_value_func'),
                                                                   dict(obj=self.target_q_value_func, attr_name='target_q_value_func')],
                                       parameters=self.parameters)

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='INITED')
    def init(self, sess=None, source_obj=None):
        super().init()
        self.q_value_func.init()
        self.target_q_value_func.init(source_obj=self.q_value_func)
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train_counter', under_status='TRAIN')
    def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
        super(DQN, self).train()
        self.recorder.record()
        if batch_data and not isinstance(batch_data, TransitionData):
            raise TypeError()

        tf_sess = sess if sess else tf.get_default_session()
        train_iter = self.parameters("TRAIN_ITERATION") if not train_iter else train_iter
        average_loss = 0.0

        for i in range(train_iter):
            if batch_data is None:
                train_data = self.replay_buffer.sample(batch_size=self.parameters('BATCH_SIZE'))
            else:
                train_data = batch_data

            _, target_q_val_on_new_s = self.predict_target_with_q_val(obs=train_data.new_state_set, batch_flag=True)
            target_q_val_on_new_s = np.expand_dims(target_q_val_on_new_s, axis=1)
            assert target_q_val_on_new_s.shape[0] == train_data.state_set.shape[0]

            feed_dict = {
                self.reward_input: np.reshape(train_data.reward_set, [-1, 1]),
                self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
                self.state_input: train_data.state_set,
                self.done_input: np.reshape(train_data.done_set, [-1, 1]),
                self.target_q_input: target_q_val_on_new_s,
                **self.parameters.return_tf_parameter_feed_dict()
            }
            res, _ = tf_sess.run([self.q_value_func_loss, self.update_q_value_func_op],
                                 feed_dict=feed_dict)
            average_loss += res

        average_loss /= train_iter

        if update_target is True and self.get_status()['train_counter'] % self.update_target_q_every_train == 0:
            tf_sess.run(self.update_target_q_value_func_op,
                        feed_dict=self.parameters.return_tf_parameter_feed_dict())
        return dict(average_loss=average_loss)

    @register_counter_info_to_status_decorator(increment=1, info_key='test_counter', under_status='TEST')
    def test(self, *arg, **kwargs):
        return super().test(*arg, **kwargs)

    @register_counter_info_to_status_decorator(increment=1, info_key='predict_counter')
    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        if batch_flag:
            action, q_val = self._predict_batch_action(obs=obs,
                                                       q_value_tensor=self.q_value_func.q_tensor,
                                                       action_ph=self.action_input,
                                                       state_ph=self.state_input,
                                                       sess=sess)
        else:
            action, q_val = self._predict_action(obs=obs,
                                                 q_value_tensor=self.q_value_func.q_tensor,
                                                 action_ph=self.action_input,
                                                 state_ph=self.state_input,
                                                 sess=sess)
        if not batch_flag:
            return int(action)
        else:
            return action.astype(np.int).tolist()

    def predict_target_with_q_val(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        if batch_flag:
            action, q_val = self._predict_batch_action(obs=obs,
                                                       q_value_tensor=self.target_q_value_func.q_tensor,
                                                       action_ph=self.target_q_value_func.action_input,
                                                       state_ph=self.target_q_value_func.state_input,
                                                       sess=sess)
        else:
            action, q_val = self._predict_action(obs=obs,
                                                 q_value_tensor=self.target_q_value_func.q_tensor,
                                                 action_ph=self.target_q_value_func.action_input,
                                                 state_ph=self.target_q_value_func.state_input,
                                                 sess=sess)
        return action, q_val

    # Store Transition
    @register_counter_info_to_status_decorator(increment=1, info_key='append_to_memory')
    def append_to_memory(self, samples: TransitionData):
        self.replay_buffer.append_batch(obs0=samples.state_set,
                                        obs1=samples.new_state_set,
                                        action=samples.action_set,
                                        reward=samples.reward_set,
                                        terminal1=samples.done_set)
        self._status.update_info(info_key='replay_buffer_data_total_count', increment=len(samples))

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self, save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path, check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        MultiPlaceholderInput.load(self, path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model, check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _predict_action(self, obs: np.ndarray, q_value_tensor: tf.Tensor, action_ph: tf.Tensor, state_ph: tf.Tensor, sess=None):
        if self.env_spec.obs_space.contains(obs) is False:
            raise StateOrActionOutOfBoundError("obs {} out of bound {}".format(obs, self.env_spec.obs_space.bound()))
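        # Tile the single observation across every one-hot-encoded discrete action,
        # evaluate Q(s, a) for all of them in one forward pass, and act greedily.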
        obs = repeat_ndarray(obs, repeats=self.env_spec.flat_action_dim)
        tf_sess = sess if sess else tf.get_default_session()
        feed_dict = {action_ph: generate_n_actions_hot_code(n=self.env_spec.flat_action_dim),
                     state_ph: obs, **self.parameters.return_tf_parameter_feed_dict()}
        res = tf_sess.run([q_value_tensor],
                          feed_dict=feed_dict)[0]
        return np.argmax(res, axis=0), np.max(res, axis=0)

    def _predict_batch_action(self, obs: np.ndarray, q_value_tensor: tf.Tensor, action_ph: tf.Tensor,
                              state_ph: tf.Tensor, sess=None):
        actions = []
        q_values = []
        for obs_i in obs:
            action, q_val = self._predict_action(obs=obs_i,
                                                 q_value_tensor=q_value_tensor,
                                                 action_ph=action_ph,
                                                 state_ph=state_ph,
                                                 sess=sess)
            actions.append(np.argmax(action, axis=0))
            q_values.append(np.max(q_val, axis=0))
        return np.array(actions), np.array(q_values)

    def _set_up_loss(self):
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.q_value_func.name_scope)
        loss = tf.reduce_sum((self.predict_q_value - self.q_value_func.q_tensor) ** 2)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('LEARNING_RATE'))
        optimize_op = optimizer.minimize(loss=loss, var_list=self.q_value_func.parameters('tf_var_list'))
        return loss, optimizer, optimize_op

    # update target net
    def _set_up_target_update(self):
        op = []
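        # Soft (Polyak) update: target <- DECAY * target + (1 - DECAY) * online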
        for var, target_var in zip(self.q_value_func.parameters('tf_var_list'),
                                   self.target_q_value_func.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))
        return op

    def _evaluate_td_error(self, sess=None):
        # tf_sess = sess if sess else tf.get_default_session()
        # feed_dict = {
        #     self.reward_input: train_data.reward_set,
        #     self.action_input: flatten_n(self.env_spec.action_space, train_data.action_set),
        #     self.state_input: train_data.state_set,
        #     self.done_input: train_data.done_set,
        #     self.target_q_input: target_q_val_on_new_s,
        #     **self.parameters.return_tf_parameter_feed_dict()
        # }
        # td_loss = tf_sess.run([self.td_error], feed_dict=feed_dict)
        pass
Example #7
class DDPG(ModelFreeAlgo, OffPolicyAlgo, MultiPlaceholderInput):
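    """
    Deep Deterministic Policy Gradient (DDPG): off-policy actor-critic with target
    networks for both actor and critic, updated by soft (Polyak) averaging.
    """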
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_DDPG_REQUIRED_KEY_LIST)

    @typechecked
    def __init__(self,
                 env_spec: EnvSpec,
                 config_or_config_dict: (DictConfig, dict),
                 value_func: MLPQValueFunction,
                 policy: DeterministicMLPPolicy,
                 schedule_param_list=None,
                 name='ddpg',
                 replay_buffer=None):
        """

        :param env_spec: environment specifications, like action apace or observation space
        :param config_or_config_dict: configuraion dictionary, like learning rate or decay, if any
        :param value_func: value function
        :param policy: agent policy
        :param schedule_param_list: schedule parameter list, if any  initla final function to schedule learning process
        :param name: name of algorithm class instance
        :param replay_buffer: replay buffer, if any
        """
        ModelFreeAlgo.__init__(self, env_spec=env_spec, name=name)
        config = construct_dict_config(config_or_config_dict, self)

        self.config = config
        self.actor = policy
        self.target_actor = self.actor.make_copy(name_scope='{}_target_actor'.format(self.name),
                                                 name='{}_target_actor'.format(self.name),
                                                 reuse=False)
        self.critic = value_func
        self.target_critic = self.critic.make_copy(name_scope='{}_target_critic'.format(self.name),
                                                   name='{}_target_critic'.format(self.name),
                                                   reuse=False)

        self.state_input = self.actor.state_input

        if replay_buffer:
            assert isinstance(replay_buffer, BaseReplayBuffer)
            self.replay_buffer = replay_buffer
        else:
            self.replay_buffer = UniformRandomReplayBuffer(limit=self.config('REPLAY_BUFFER_SIZE'),
                                                           action_shape=self.env_spec.action_shape,
                                                           observation_shape=self.env_spec.obs_shape)
        """
        self.parameters contains all the parameters (variables) of the algorithm
        """
        self.parameters = ParametersWithTensorflowVariable(tf_var_list=[],
                                                           rest_parameters=dict(),
                                                           to_scheduler_param_tuple=schedule_param_list,
                                                           name='ddpg_param',
                                                           source_config=config,
                                                           require_snapshot=False)
        self._critic_with_actor_output = self.critic.make_copy(reuse=True,
                                                               name='actor_input_{}'.format(self.critic.name),
                                                               state_input=self.state_input,
                                                               action_input=self.actor.action_tensor)
        self._target_critic_with_target_actor_output = self.target_critic.make_copy(reuse=True,
                                                                                    name='target_critic_with_target_actor_output_{}'.format(
                                                                                        self.critic.name),
                                                                                    action_input=self.target_actor.action_tensor)

        with tf.variable_scope(name):
            self.reward_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            self.next_state_input = tf.placeholder(shape=[None, self.env_spec.flat_obs_dim], dtype=tf.float32)
            self.done_input = tf.placeholder(shape=[None, 1], dtype=tf.bool)
            self.target_q_input = tf.placeholder(shape=[None, 1], dtype=tf.float32)
            done = tf.cast(self.done_input, dtype=tf.float32)
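            # TD target: r + (1 - done) * GAMMA * Q_target(s', mu_target(s')), fed in via target_q_input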
            self.predict_q_value = (1. - done) * self.config('GAMMA') * self.target_q_input + self.reward_input
            with tf.variable_scope('train'):
                self.critic_loss, self.critic_update_op, self.target_critic_update_op, self.critic_optimizer, \
                self.critic_grads = self._setup_critic_loss()
                self.actor_loss, self.actor_update_op, self.target_actor_update_op, self.action_optimizer, \
                self.actor_grads = self._set_up_actor_loss()

        var_list = get_tf_collection_var_list(
            '{}/train'.format(name)) + self.critic_optimizer.variables() + self.action_optimizer.variables()
        self.parameters.set_tf_var_list(tf_var_list=sorted(list(set(var_list)), key=lambda x: x.name))
        MultiPlaceholderInput.__init__(self,
                                       sub_placeholder_input_list=[dict(obj=self.target_actor,
                                                                        attr_name='target_actor',
                                                                        ),
                                                                   dict(obj=self.actor,
                                                                        attr_name='actor'),
                                                                   dict(obj=self.critic,
                                                                        attr_name='critic'),
                                                                   dict(obj=self.target_critic,
                                                                        attr_name='target_critic')
                                                                   ],
                                       parameters=self.parameters)

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='INITED')
    def init(self, sess=None, source_obj=None):
        self.actor.init()
        self.critic.init()
        self.target_actor.init()
        self.target_critic.init(source_obj=self.critic)
        self.parameters.init()
        if source_obj:
            self.copy_from(source_obj)
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train', under_status='TRAIN')
    def train(self, batch_data=None, train_iter=None, sess=None, update_target=True) -> dict:
        super(DDPG, self).train()
        if isinstance(batch_data, TrajectoryData):
            batch_data = batch_data.return_as_transition_data(shuffle_flag=True)
        tf_sess = sess if sess else tf.get_default_session()
        train_iter = self.parameters("TRAIN_ITERATION") if not train_iter else train_iter
        average_critic_loss = 0.0
        average_actor_loss = 0.0
        for i in range(train_iter):
            train_batch = self.replay_buffer.sample(
                batch_size=self.parameters('BATCH_SIZE')) if batch_data is None else batch_data
            assert isinstance(train_batch, TransitionData)

            critic_loss, _ = self._critic_train(train_batch, tf_sess)

            actor_loss, _ = self._actor_train(train_batch, tf_sess)

            average_actor_loss += actor_loss
            average_critic_loss += critic_loss
        if update_target:
            tf_sess.run([self.target_actor_update_op, self.target_critic_update_op])
        return dict(average_actor_loss=average_actor_loss / train_iter,
                    average_critic_loss=average_critic_loss / train_iter)

    def _critic_train(self, batch_data, sess) -> tuple:
        target_q = sess.run(
            self._target_critic_with_target_actor_output.q_tensor,
            feed_dict={
                self._target_critic_with_target_actor_output.state_input: batch_data.new_state_set,
                self.target_actor.state_input: batch_data.new_state_set
            }
        )
        loss, _, grads = sess.run(
            [self.critic_loss, self.critic_update_op, self.critic_grads
             ],
            feed_dict={
                self.target_q_input: target_q,
                self.critic.state_input: batch_data.state_set,
                self.critic.action_input: batch_data.action_set,
                self.done_input: np.reshape(batch_data.done_set, [-1, 1]),
                self.reward_input: np.reshape(batch_data.reward_set, [-1, 1]),
                **self.parameters.return_tf_parameter_feed_dict()
            }
        )
        return loss, grads

    def _actor_train(self, batch_data, sess) -> tuple:
        target_q, loss, _, grads = sess.run(
            [self._critic_with_actor_output.q_tensor, self.actor_loss, self.actor_update_op, self.actor_grads],
            feed_dict={
                self.actor.state_input: batch_data.state_set,
                self._critic_with_actor_output.state_input: batch_data.state_set,
                **self.parameters.return_tf_parameter_feed_dict()
            }
        )
        return loss, grads

    @register_counter_info_to_status_decorator(increment=1, info_key='test', under_status='TEST')
    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def predict(self, obs: np.ndarray, sess=None, batch_flag: bool = False):
        tf_sess = sess if sess else tf.get_default_session()
        feed_dict = {
            self.state_input: make_batch(obs, original_shape=self.env_spec.obs_shape),
            **self.parameters.return_tf_parameter_feed_dict()
        }
        return self.actor.forward(obs=obs, sess=tf_sess, feed_dict=feed_dict)

    def append_to_memory(self, samples: TransitionData):

        self.replay_buffer.append_batch(obs0=samples.state_set,
                                        obs1=samples.new_state_set,
                                        action=samples.action_set,
                                        reward=samples.reward_set,
                                        terminal1=samples.done_set)

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        MultiPlaceholderInput.save(self, save_path=save_path, global_step=global_step, name=name, **kwargs)
        return dict(check_point_save_path=save_path, check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        MultiPlaceholderInput.load(self, path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model, check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    def _setup_critic_loss(self):
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.critic.name_scope)
        loss = tf.reduce_sum((self.predict_q_value - self.critic.q_tensor) ** 2)
        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('CRITIC_LEARNING_RATE'))
        grad_var_pair = optimizer.compute_gradients(loss=loss, var_list=self.critic.parameters('tf_var_list'))
        grads = [g[0] for g in grad_var_pair]
        if self.parameters('critic_clip_norm') is not None:
            grad_var_pair, grads = clip_grad(optimizer=optimizer,
                                             loss=loss,
                                             var_list=self.critic.parameters('tf_var_list'),
                                             clip_norm=self.parameters('critic_clip_norm'))
        optimize_op = optimizer.apply_gradients(grad_var_pair)
        op = []
        for var, target_var in zip(self.critic.parameters('tf_var_list'),
                                   self.target_critic.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))

        return loss, optimize_op, op, optimizer, grads

    def _set_up_actor_loss(self):
        reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=self.actor.name_scope)
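        # Deterministic policy gradient: maximize E[Q(s, mu(s))] by minimizing its negative mean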
        loss = -tf.reduce_mean(self._critic_with_actor_output.q_tensor)

        if len(reg_loss) > 0:
            loss += tf.reduce_sum(reg_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=self.parameters('CRITIC_LEARNING_RATE'))
        grad_var_pair = optimizer.compute_gradients(loss=loss, var_list=self.actor.parameters('tf_var_list'))
        grads = [g[0] for g in grad_var_pair]
        if self.parameters('actor_clip_norm') is not None:
            grad_var_pair, grads = clip_grad(optimizer=optimizer,
                                             loss=loss,
                                             var_list=self.actor.parameters('tf_var_list'),
                                             clip_norm=self.parameters('actor_clip_norm'))
        optimize_op = optimizer.apply_gradients(grad_var_pair)
        op = []
        for var, target_var in zip(self.actor.parameters('tf_var_list'),
                                   self.target_actor.parameters('tf_var_list')):
            ref_val = self.parameters('DECAY') * target_var + (1.0 - self.parameters('DECAY')) * var
            op.append(tf.assign(target_var, ref_val))

        return loss, optimize_op, op, optimizer, grads
Example #8
class ModelPredictiveControl(ModelBasedAlgo):
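    """
    Model predictive control on a learned dynamics model: candidate action sequences are
    rolled out through the dynamics environment and the first action of the best rollout
    is returned by predict().
    """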
    required_key_dict = DictConfig.load_json(
        file_path=GlobalConfig().DEFAULT_MPC_REQUIRED_KEY_LIST)

    def __init__(
        self,
        env_spec,
        dynamics_model: DynamicsModel,
        config_or_config_dict: (DictConfig, dict),
        policy: Policy,
        name='mpc',
    ):
        super().__init__(env_spec, dynamics_model, name)
        self.config = construct_dict_config(config_or_config_dict, self)
        self.policy = policy
        self.parameters = Parameters(parameters=dict(),
                                     source_config=self.config,
                                     name=name + '_' + 'mpc_param')
        self.memory = TransitionData(env_spec=env_spec)

    def init(self, source_obj=None):
        super().init()
        self.parameters.init()
        self._dynamics_model.init()
        self.policy.init()
        if source_obj:
            self.copy_from(source_obj)

    def train(self, *arg, **kwargs) -> dict:
        super(ModelPredictiveControl, self).train()
        res_dict = {}
        batch_data = kwargs.get('batch_data', self.memory)

        dynamics_train_res_dict = self._fit_dynamics_model(
            batch_data=batch_data,
            train_iter=self.parameters('dynamics_model_train_iter'))
        for key, val in dynamics_train_res_dict.items():
            res_dict["mlp_dynamics_{}".format(key)] = val
        return res_dict

    def test(self, *arg, **kwargs) -> dict:
        return super().test(*arg, **kwargs)

    def _fit_dynamics_model(self,
                            batch_data: TransitionData,
                            train_iter,
                            sess=None) -> dict:
        res_dict = self._dynamics_model.train(
            batch_data, **dict(sess=sess, train_iter=train_iter))
        return res_dict

    def predict(self, obs, **kwargs):
        if self.is_training is True:
            return self.env_spec.action_space.sample()
        rollout = TrajectoryData(env_spec=self.env_spec)
        for i in range(self.parameters('SAMPLED_PATH_NUM')):
            state = obs  # every candidate path is rolled out from the current observation
            path = TransitionData(env_spec=self.env_spec)
            # TODO: handling of the terminal signal from terminal_func is still to be decided
            for _ in range(self.parameters('SAMPLED_HORIZON')):
                ac = self.policy.forward(obs=state)
                new_state, re, done, _ = self.dynamics_env.step(action=ac,
                                                                state=state)
                path.append(state=state,
                            action=ac,
                            new_state=new_state,
                            reward=re,
                            done=done)
                state = new_state
            rollout.append(path)
        rollout.trajectories.sort(key=lambda x: x.cumulative_reward,
                                  reverse=True)
        ac = rollout.trajectories[0].action_set[0]
        assert self.env_spec.action_space.contains(ac)
        return ac

    def append_to_memory(self, samples: TransitionData):
        self.memory.union(samples)

    def copy_from(self, obj) -> bool:
        if not isinstance(obj, type(self)):
            raise TypeError(
                'Wrong type of obj %s to be copied, which should be %s' %
                (type(obj), type(self)))
        self.parameters.copy_from(obj.parameters)
        self._dynamics_model.copy_from(obj._dynamics_model)
        ConsoleLogger().print('info',
                              'model: {} copied from {}'.format(self, obj))
        return True

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name

        self._dynamics_model.save(save_path=save_path,
                                  global_step=global_step,
                                  name=name,
                                  **kwargs)
        self.policy.save(save_path=save_path,
                         global_step=global_step,
                         name=name,
                         **kwargs)
        return dict(check_point_save_path=save_path,
                    check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        self._dynamics_model.load(path_to_model, model_name, global_step,
                                  **kwargs)
        self.policy.load(path_to_model, model_name, global_step, **kwargs)
        return dict(check_point_load_path=path_to_model,
                    check_point_load_global_step=global_step,
                    check_point_load_name=model_name)
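For reference, predict above is a random-shooting controller: outside training it rolls out SAMPLED_PATH_NUM candidate paths of length SAMPLED_HORIZON through the learned dynamics and returns the first action of the best-scoring path. The minimal sketch below restates that idea; `model_step`, `sample_action`, and `random_shooting_action` are hypothetical stand-ins, not library APIs.

# Minimal random-shooting sketch; `model_step(state, action) -> (next_state, reward)` and
# `sample_action()` are hypothetical stand-ins for the learned dynamics model and the
# proposal policy used by ModelPredictiveControl.
import numpy as np


def random_shooting_action(obs, model_step, sample_action, n_paths=10, horizon=15):
    best_return, best_first_action = -np.inf, None
    for _ in range(n_paths):
        state, path_return, first_action = obs, 0.0, None
        for t in range(horizon):
            action = sample_action()
            if t == 0:
                first_action = action  # only the first action of the best path is executed
            state, reward = model_step(state, action)
            path_return += reward
        if path_return > best_return:
            best_return, best_first_action = path_return, first_action
    return best_first_action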
Example #9
class MEPPO(ModelBasedAlgo):
    """
    Model Ensemble, Proximal Policy Optimisation

    """
    required_key_dict = DictConfig.load_json(file_path=GlobalConfig().DEFAULT_ALGO_DYNA_REQUIRED_KEY_LIST)

    @init_func_arg_record_decorator()
    @typechecked
    def __init__(self, env_spec, dynamics_model: ModelEnsemble,
                 model_free_algo: ModelFreeAlgo,
                 config_or_config_dict: (DictConfig, dict),
                 name='sample_with_dynamics'
                 ):
        if not isinstance(dynamics_model.model[0], ContinuousMLPGlobalDynamicsModel):
            raise TypeError("Model ensemble elements should be of type ContinuousMLPGlobalDynamicsModel")
        super().__init__(env_spec, dynamics_model, name)
        config = construct_dict_config(config_or_config_dict, self)
        parameters = Parameters(parameters=dict(),
                                name='dyna_param',
                                source_config=config)
        sub_placeholder_input_list = []
        if isinstance(dynamics_model, PlaceholderInput):
            sub_placeholder_input_list.append(dict(obj=dynamics_model,
                                                   attr_name='dynamics_model'))
        if isinstance(model_free_algo, PlaceholderInput):
            sub_placeholder_input_list.append(dict(obj=model_free_algo,
                                                   attr_name='model_free_algo'))
        self.model_free_algo = model_free_algo
        self.config = config
        self.parameters = parameters
        self.result = [0] * len(dynamics_model)
        self.validation_result = [0] * len(dynamics_model)
        self._dynamics_model.__class__ = ModelEnsemble

    @register_counter_info_to_status_decorator(increment=1, info_key='init', under_status='JUST_INITED')
    def init(self):
        self.parameters.init()
        self.model_free_algo.init()
        self.dynamics_env.init()
        super().init()

    @record_return_decorator(which_recorder='self')
    @register_counter_info_to_status_decorator(increment=1, info_key='train_counter', under_status='TRAIN')
    def train(self, *args, **kwargs) -> dict:
        super(MEPPO, self).train()
        res_dict = {}
        batch_data = kwargs['batch_data'] if 'batch_data' in kwargs else None
        if 'state' in kwargs:
            assert kwargs['state'] in ('state_dynamics_training', 'state_agent_training')
            state = kwargs['state']
            kwargs.pop('state')
        else:
            state = None

        if not state or state == 'state_dynamics_training':

            dynamics_train_res_dict = self._fit_dynamics_model(batch_data=batch_data,
                                                               train_iter=self.parameters('dynamics_model_train_iter'))
            for key, val in dynamics_train_res_dict.items():
                res_dict["{}_{}".format(self._dynamics_model.name, key)] = val
        if not state or state == 'state_agent_training':
            model_free_algo_train_res_dict = self._train_model_free_algo(batch_data=batch_data,
                                                                         train_iter=self.parameters(
                                                                             'model_free_algo_train_iter'))

            for key, val in model_free_algo_train_res_dict.items():
                res_dict['{}_{}'.format(self.model_free_algo.name, key)] = val
        return res_dict

    @register_counter_info_to_status_decorator(increment=1, info_key='test_counter', under_status='TEST')
    def test(self, *arg, **kwargs):
        return super().test(*arg, **kwargs)

    def validate(self, *args, **kwargs):
        old_result = self.result
        self.validation_result = 0
        for a in range(len(self._dynamics_model)):
            individual_model = self._dynamics_model.model[a]
            env = individual_model.return_as_env()
            new_state, reward, terminal, _ = env.step(*args, **kwargs)
            self.result[a] = reward
            if reward > old_result[a]:
                self.validation_result += 1

        self.validation_result = self.validation_result / len(self._dynamics_model)

        return self.validation_result

    @register_counter_info_to_status_decorator(increment=1, info_key='predict_counter')
    def predict(self, obs, **kwargs):
        return self.model_free_algo.predict(obs)

    def append_to_memory(self, *args, **kwargs):
        self.model_free_algo.append_to_memory(kwargs['samples'])

    @record_return_decorator(which_recorder='self')
    def save(self, global_step, save_path=None, name=None, **kwargs):
        save_path = save_path if save_path else GlobalConfig().DEFAULT_MODEL_CHECKPOINT_PATH
        name = name if name else self.name
        self.model_free_algo.save(global_step=global_step,
                                  name=None,
                                  save_path=os.path.join(save_path, self.model_free_algo.name))
        self.dynamics_env.save(global_step=global_step,
                               name=None,
                               save_path=os.path.join(save_path, self.dynamics_env.name))
        return dict(check_point_save_path=save_path, check_point_save_global_step=global_step,
                    check_point_save_name=name)

    @record_return_decorator(which_recorder='self')
    def load(self, path_to_model, model_name, global_step=None, **kwargs):
        self.model_free_algo.load(path_to_model=os.path.join(path_to_model, self.model_free_algo.name),
                                  model_name=self.model_free_algo.name,
                                  global_step=global_step)
        self.dynamics_env.load(global_step=global_step,
                               path_to_model=os.path.join(path_to_model, self.dynamics_env.name),
                               model_name=self.dynamics_env.name)
        return dict(check_point_load_path=path_to_model, check_point_load_global_step=global_step,
                    check_point_load_name=model_name)

    @register_counter_info_to_status_decorator(increment=1, info_key='dynamics_train_counter', under_status='TRAIN')
    def _fit_dynamics_model(self, batch_data: TransitionData, train_iter, sess=None) -> dict:
        res_dict = self._dynamics_model.train(batch_data, **dict(sess=sess,
                                                                 train_iter=train_iter))
        return res_dict

    @register_counter_info_to_status_decorator(increment=1, info_key='model_free_algo_dynamics_train_counter',
                                               under_status='TRAIN')
    def _train_model_free_algo(self, batch_data=None, train_iter=None, sess=None):
        res_dict = self.model_free_algo.train(**dict(batch_data=batch_data,
                                                     train_iter=train_iter,
                                                     sess=sess))
        return res_dict
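MEPPO.train dispatches on the optional 'state' keyword, so an outer driver is expected to alternate dynamics fitting and agent training each epoch. A hypothetical loop is sketched below; the `sampler` object and its `sample` call are assumptions used only for illustration.

# Hypothetical driver loop for MEPPO; `sampler.sample(...)` is an assumed helper that
# collects real-environment transitions as a TransitionData batch.
def meppo_training_loop(algo, sampler, env, num_epochs):
    for _ in range(num_epochs):
        real_batch = sampler.sample(env=env, policy=algo)
        algo.append_to_memory(samples=real_batch)
        # Phase 1: fit every dynamics model in the ensemble on real transitions.
        algo.train(batch_data=real_batch, state='state_dynamics_training')
        # Phase 2: update the model-free learner with the configured train iterations.
        algo.train(batch_data=real_batch, state='state_agent_training')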