Example #1
    def make_vars(self, stepnum='0'):
        # lists over the meta_batch_size
        obs_vars, action_vars, adv_vars = [], [], []
        for i in range(self.meta_batch_size):
            obs_vars.append(self.env.observation_space.new_tensor_variable(
                'obs' + stepnum + '_' + str(i),
                extra_dims=1,
            ))
            action_vars.append(self.env.action_space.new_tensor_variable(
                'action' + stepnum + '_' + str(i),
                extra_dims=1,
            ))
            adv_vars.append(tensor_utils.new_tensor(
                name='advantage' + stepnum + '_' + str(i),
                ndim=1, dtype=tf.float32,
            ))
        return obs_vars, action_vars, adv_vars
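The helper above creates one observation, one action, and one advantage placeholder per task in the meta-batch, named by step number and task index. As a rough sketch of what those calls are assumed to produce for a Box observation space (obs_dim and the equivalences below are illustrative, not the library's actual implementation):

import tensorflow as tf

obs_dim = 4  # illustrative Box observation dimension

# Assumed rough equivalent of observation_space.new_tensor_variable('obs0_0', extra_dims=1):
# a float placeholder with one extra leading batch dimension.
obs0_0 = tf.placeholder(tf.float32, shape=[None, obs_dim], name='obs0_0')

# Assumed rough equivalent of tensor_utils.new_tensor(name='advantage0_0', ndim=1, dtype=tf.float32).
advantage0_0 = tf.placeholder(tf.float32, shape=[None], name='advantage0_0')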
Example #2
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)

        obs_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + is_recurrent,
        )
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1 + is_recurrent,
        )
        advantage_var = tensor_utils.new_tensor(
            name='advantage',
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )
        dist = self.policy.distribution

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] * (1 + is_recurrent) + list(shape),
                              name='old_%s' % k)
            for k, shape in dist.dist_info_specs
        }
        old_dist_info_vars_list = [
            old_dist_info_vars[k] for k in dist.dist_info_keys
        ]

        state_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] * (1 + is_recurrent) + list(shape),
                              name=k)
            for k, shape in self.policy.state_info_specs
        }
        state_info_vars_list = [
            state_info_vars[k] for k in self.policy.state_info_keys
        ]

        if is_recurrent:
            valid_var = tf.placeholder(tf.float32,
                                       shape=[None, None],
                                       name="valid")
        else:
            valid_var = None

        dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
        logli = dist.log_likelihood_sym(action_var, dist_info_vars)
        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)

        # Formulate as a minimization problem: the gradient of the negated
        # surrogate objective is the policy gradient.
        if is_recurrent:
            surr_obj = -tf.reduce_sum(
                logli * advantage_var * valid_var) / tf.reduce_sum(valid_var)
            mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
            max_kl = tf.reduce_max(kl * valid_var)
        else:
            surr_obj = -tf.reduce_mean(logli * advantage_var)
            mean_kl = tf.reduce_mean(kl)
            max_kl = tf.reduce_max(kl)

        input_list = [obs_var, action_var, advantage_var] + state_info_vars_list
        if is_recurrent:
            input_list.append(valid_var)

        #self.policy.set_init_surr_obj(input_list, [surr_obj])  # debugging
        self.optimizer.update_opt(loss=surr_obj,
                                  target=self.policy,
                                  inputs=input_list)

        f_kl = tensor_utils.compile_function(
            inputs=input_list + old_dist_info_vars_list,
            outputs=[mean_kl, max_kl],
        )
        self.opt_info = dict(f_kl=f_kl)
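This init_opt builds a vanilla policy gradient surrogate: the negated log-likelihood weighted by advantages, averaged over all samples, or, when the policy is recurrent, only over the timesteps marked by valid_var. A minimal self-contained sketch of that masked average with made-up numbers (TF 1.x style, matching the placeholders above):

import tensorflow as tf

# Toy values: 2 trajectories x 3 timesteps; the last step of the second
# trajectory is padding, so its valid entry is 0.
logli = tf.constant([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
advantage = tf.constant([[1.0, 0.5, 0.2], [0.3, 0.1, 0.0]])
valid = tf.constant([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])

# Same form as the recurrent branch above: sum over valid entries,
# then divide by the number of valid entries.
surr_obj = -tf.reduce_sum(logli * advantage * valid) / tf.reduce_sum(valid)

with tf.Session() as sess:
    print(sess.run(surr_obj))  # averages over the 5 valid timesteps only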
Example #3
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)
        obs_var = self.env.observation_space.new_tensor_variable(
            'obs',
            extra_dims=1 + is_recurrent,
        )
        action_var = self.env.action_space.new_tensor_variable(
            'action',
            extra_dims=1 + is_recurrent,
        )
        advantage_var = tensor_utils.new_tensor(
            'advantage',
            ndim=1 + is_recurrent,
            dtype=tf.float32,
        )
        dist = self.policy.distribution

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] * (1 + is_recurrent) + list(shape),
                              name='old_%s' % k)
            for k, shape in dist.dist_info_specs
        }
        old_dist_info_vars_list = [
            old_dist_info_vars[k] for k in dist.dist_info_keys
        ]

        state_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] * (1 + is_recurrent) + list(shape),
                              name=k)
            for k, shape in self.policy.state_info_specs
        }
        state_info_vars_list = [
            state_info_vars[k] for k in self.policy.state_info_keys
        ]

        if is_recurrent:
            valid_var = tf.placeholder(tf.float32,
                                       shape=[None, None],
                                       name="valid")
        else:
            valid_var = None

        dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
        kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                       dist_info_vars)
        if is_recurrent:
            mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
            surr_loss = -tf.reduce_sum(
                lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
        else:
            mean_kl = tf.reduce_mean(kl)
            surr_loss = -tf.reduce_mean(lr * advantage_var)

        input_list = [
            obs_var,
            action_var,
            advantage_var,
        ] + state_info_vars_list + old_dist_info_vars_list
        if is_recurrent:
            input_list.append(valid_var)

        self.optimizer.update_opt(loss=surr_loss,
                                  target=self.policy,
                                  leq_constraint=(mean_kl, self.step_size),
                                  inputs=input_list,
                                  constraint_name="mean_kl")
        return dict()
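This variant replaces the log-likelihood with the likelihood ratio between the new and old policy, and passes a KL constraint (mean_kl bounded by self.step_size) to update_opt, which is consistent with a TRPO-style trust-region update. A small NumPy sketch of what likelihood_ratio_sym is assumed to compute, exp(log pi_new(a) - log pi_old(a)), for a one-dimensional Gaussian policy with made-up parameters:

import numpy as np

def gaussian_log_likelihood(action, mean, log_std):
    # Log-density of a 1-D Gaussian with the given mean and std exp(log_std).
    std = np.exp(log_std)
    return -0.5 * ((action - mean) / std) ** 2 - log_std - 0.5 * np.log(2 * np.pi)

action = 0.3
old_mean, old_log_std = 0.0, 0.0   # old policy parameters (made up)
new_mean, new_log_std = 0.1, -0.1  # new policy parameters (made up)

# Weighting this ratio by the advantage and averaging (with the sign flipped)
# gives a surrogate loss of the same shape as surr_loss above.
lr = np.exp(gaussian_log_likelihood(action, new_mean, new_log_std)
            - gaussian_log_likelihood(action, old_mean, old_log_std))
print(lr)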