Example #1
class MAMLNPO(BatchMAMLPolopt):
    """
    Natural Policy Optimization.
    """
    def __init__(self,
                 optimizer=None,
                 optimizer_args=None,
                 step_size=0.01,
                 use_maml=True,
                 **kwargs):
        assert optimizer is not None  # only for use with MAML TRPO

        self.optimizer = optimizer
        self.offPolicy_optimizer = FirstOrderOptimizer(max_epochs=1)
        self.step_size = step_size
        self.use_maml = use_maml
        self.kl_constrain_step = -1  # needs to be 0 or -1 (original pol params, or new pol params)
        super(MAMLNPO, self).__init__(**kwargs)

    def make_vars(self, stepnum='0'):
        # lists over the meta_batch_size
        obs_vars, action_vars, adv_vars, imp_vars = [], [], [], []
        for i in range(self.meta_batch_size):
            obs_vars.append(
                self.env.observation_space.new_tensor_variable(
                    'obs' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            action_vars.append(
                self.env.action_space.new_tensor_variable(
                    'action' + stepnum + '_' + str(i),
                    extra_dims=1,
                ))
            adv_vars.append(
                tensor_utils.new_tensor(
                    name='advantage' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))

            imp_vars.append(
                tensor_utils.new_tensor(
                    name='imp_ratios' + stepnum + '_' + str(i),
                    ndim=1,
                    dtype=tf.float32,
                ))

        return obs_vars, action_vars, adv_vars, imp_vars

    @overrides
    def init_opt(self):
        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported

        dist = self.policy.distribution

        old_dist_info_vars, old_dist_info_vars_list = [], []
        for i in range(self.meta_batch_size):
            old_dist_info_vars.append({
                k: tf.placeholder(tf.float32,
                                  shape=[None] + list(shape),
                                  name='old_%s_%s' % (i, k))
                for k, shape in dist.dist_info_specs
            })
            old_dist_info_vars_list += [
                old_dist_info_vars[i][k] for k in dist.dist_info_keys
            ]

        state_info_vars, state_info_vars_list = {}, []

        all_surr_objs, input_list = [], []
        new_params = None
        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, _ = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            new_params = []  # if there are several grad_updates, the new_params are overwritten
            kls = []

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)
                    if self.kl_constrain_step == 0:
                        kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                        kls.append(kl)
                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + state_info_vars_list
            if j == 0:
                # For computing the fast update for sampling
                self.policy.set_init_surr_obj(input_list, surr_objs)
                init_input_list = input_list

            all_surr_objs.append(surr_objs)

        # Post-update ("test") step: plug the adapted parameters new_params[i]
        # back into the policy symbolically and build a TRPO-style
        # likelihood-ratio surrogate against the pre-update distribution.
        obs_vars, action_vars, adv_vars, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):
            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])

            if self.kl_constrain_step == -1:  # if we only constrain the KL at the last step, the terms appended here form the overall constraint
                kl = dist.kl_sym(old_dist_info_vars[i], dist_info_vars)
                kls.append(kl)
            lr = dist.likelihood_ratio_sym(action_vars[i],
                                           old_dist_info_vars[i],
                                           dist_info_vars)
            surr_objs.append(-tf.reduce_mean(lr * adv_vars[i]))

        if self.use_maml:
            surr_obj = tf.reduce_mean(tf.stack(
                surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
            input_list += obs_vars + action_vars + adv_vars + old_dist_info_vars_list
        else:
            surr_obj = tf.reduce_mean(
                tf.stack(all_surr_objs[0],
                         0))  # if not meta, just use the first surr_obj
            input_list = init_input_list

        if self.use_maml:
            mean_kl = tf.reduce_mean(
                tf.concat(kls, 0)
            )  ##CF shouldn't this have the option of self.kl_constrain_step == -1?
            max_kl = tf.reduce_max(tf.concat(kls, 0))

            self.optimizer.update_opt(loss=surr_obj,
                                      target=self.policy,
                                      leq_constraint=(mean_kl, self.step_size),
                                      inputs=input_list,
                                      constraint_name="mean_kl")
        else:
            self.optimizer.update_opt(
                loss=surr_obj,
                target=self.policy,
                inputs=input_list,
            )
        return dict()

    @overrides
    def init_opt_offPolicy(self):

        is_recurrent = int(self.policy.recurrent)
        assert not is_recurrent  # not supported
        dist = self.policy.distribution
        state_info_vars, state_info_vars_list = {}, []
        all_surr_objs, input_list = [], []
        new_params = None

        for j in range(self.num_grad_updates):
            obs_vars, action_vars, adv_vars, imp_vars = self.make_vars(str(j))
            surr_objs = []

            cur_params = new_params
            new_params = []  # if there are several grad_updates, the new_params are overwritten

            for i in range(self.meta_batch_size):
                if j == 0:
                    dist_info_vars, params = self.policy.dist_info_sym(
                        obs_vars[i],
                        state_info_vars,
                        all_params=self.policy.all_params)

                else:
                    dist_info_vars, params = self.policy.updated_dist_info_sym(
                        i,
                        all_surr_objs[-1][i],
                        obs_vars[i],
                        params_dict=cur_params[i])

                new_params.append(params)
                logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)

                # formulate as a minimization problem
                # The gradient of the surrogate objective is the policy gradient
                surr_objs.append(-tf.reduce_mean(logli * imp_vars[i] *
                                                 adv_vars[i]))

            input_list += obs_vars + action_vars + adv_vars + imp_vars
            all_surr_objs.append(surr_objs)

        # Post-update ("test") step: maximize the log-likelihood of the actions
        # under the adapted policy; offPolicy_optimization_step feeds the expert
        # observations and actions into these placeholders.
        obs_vars, action_vars, _, _ = self.make_vars('test')
        surr_objs = []
        for i in range(self.meta_batch_size):

            dist_info_vars, _ = self.policy.updated_dist_info_sym(
                i,
                all_surr_objs[-1][i],
                obs_vars[i],
                params_dict=new_params[i])
            logli = dist.log_likelihood_sym(action_vars[i], dist_info_vars)
            surr_objs.append(-tf.reduce_mean(logli))

        surr_obj = tf.reduce_mean(tf.stack(
            surr_objs, 0))  # mean over meta_batch_size (the diff tasks)
        input_list += obs_vars + action_vars

        self.offPolicy_optimizer.update_opt(
            loss=surr_obj,
            target=self.policy,
            inputs=input_list,
        )

    def offPolicy_optimization_step(self, samples_data, expert_data):

        input_list = []
        #for step in range(len(all_samples_data)):  # these are the gradient steps
        obs_list, action_list, adv_list, imp_list, expert_obs_list, expert_action_list = [], [], [], [], [], []
        for i in range(self.meta_batch_size):

            inputs = ext.extract(samples_data[i], "observations", "actions",
                                 "advantages", "traj_imp_weights")
            obs_list.append(inputs[0])
            action_list.append(inputs[1])
            adv_list.append(inputs[2])
            imp_list.append(inputs[3])

            expert_inputs = ext.extract(expert_data[i], "observations",
                                        "actions")
            expert_obs_list.append(expert_inputs[0])
            expert_action_list.append(expert_inputs[1])

        input_list += obs_list + action_list + adv_list + imp_list + expert_obs_list + expert_action_list

        self.offPolicy_optimizer.optimize(input_list)

    @overrides
    def optimize_policy(self, itr, all_samples_data):
        # we collected rollouts for every gradient step plus the final test step
        assert len(all_samples_data) == self.num_grad_updates + 1

        if not self.use_maml:
            all_samples_data = [all_samples_data[0]]

        input_list = []
        for step in range(len(all_samples_data)):  # these are the gradient steps
            obs_list, action_list, adv_list = [], [], []
            for i in range(self.meta_batch_size):

                inputs = ext.extract(all_samples_data[step][i], "observations",
                                     "actions", "advantages")
                obs_list.append(inputs[0])
                action_list.append(inputs[1])
                adv_list.append(inputs[2])
            input_list += obs_list + action_list + adv_list  # [ [obs_0], [act_0], [adv_0], [obs_1], ... ]

            if step == 0:  ##CF not used?
                init_inputs = input_list

        if self.use_maml:
            dist_info_list = []
            for i in range(self.meta_batch_size):
                agent_infos = all_samples_data[
                    self.kl_constrain_step][i]['agent_infos']
                dist_info_list += [
                    agent_infos[k]
                    for k in self.policy.distribution.dist_info_keys
                ]
            input_list += tuple(dist_info_list)
            logger.log("Computing KL before")
            mean_kl_before = self.optimizer.constraint_val(input_list)

        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(input_list)
        logger.log("Optimizing")
        self.optimizer.optimize(input_list)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(input_list)
        if self.use_maml:
            logger.log("Computing KL after")
            mean_kl = self.optimizer.constraint_val(input_list)
            logger.record_tabular('MeanKLBefore',
                                  mean_kl_before)  # this now won't be 0!
            logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict()

    @overrides
    def get_itr_snapshot(self, itr, samples_data):
        return dict(
            itr=itr,
            policy=self.policy,
            baseline=self.baseline,
            env=self.env,
        )
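
The snippet above only defines the algorithm class, so here is a minimal, hypothetical wiring sketch as a reading aid. Only the constructor arguments shown in Example #1 (optimizer, optimizer_args, step_size, use_maml) are taken from the snippet; the optimizer class and every keyword routed through **kwargs (env, policy, baseline, meta_batch_size, num_grad_updates) are assumptions about the BatchMAMLPolopt base class, which is not shown here, as is the train() entry point.

# Hypothetical usage sketch; see the caveats above. None of the assumed names
# below (ConjugateGradientOptimizer, env, policy, baseline, train) are defined
# in this snippet.
algo = MAMLNPO(
    optimizer=ConjugateGradientOptimizer(),  # the assert rejects optimizer=None
    step_size=0.01,        # mean-KL trust-region size used in init_opt
    use_maml=True,         # build the post-update (meta) objective
    env=env,               # assumed base-class kwargs: the code above reads
    policy=policy,         # self.env, self.policy, self.baseline,
    baseline=baseline,     # self.meta_batch_size and self.num_grad_updates
    meta_batch_size=40,
    num_grad_updates=1,
)
algo.train()               # assumed to be provided by BatchMAMLPolopt
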
Example #2
class NPO(BatchPolopt):
    """
    Natural Policy Optimization.
    """
    def __init__(self,
                 optimizer_class=None,
                 optimizer_args=None,
                 step_size=0.01,
                 penalty=0.0,
                 **kwargs):

        self.optimizer_class = default(optimizer_class, PenaltyLbfgsOptimizer)
        self.optimizer_args = default(optimizer_args, dict())

        self.penalty = penalty
        self.constrain_together = penalty > 0

        self.step_size = step_size

        self.metrics = []
        super(NPO, self).__init__(**kwargs)

    @overrides
    def init_opt(self):

        ###############################
        #
        # Variable Definitions
        #
        ###############################

        all_task_dist_info_vars = []
        all_obs_vars = []

        for i, policy in enumerate(self.local_policies):

            task_obs_var = self.env_partitions[i].observation_space.new_tensor_variable(
                'obs%d' % i, extra_dims=1)
            task_dist_info_vars = []

            for j, other_policy in enumerate(self.local_policies):

                state_info_vars = dict()  # Not handling recurrent policies
                dist_info_vars = other_policy.dist_info_sym(
                    task_obs_var, state_info_vars)
                task_dist_info_vars.append(dist_info_vars)

            all_obs_vars.append(task_obs_var)
            all_task_dist_info_vars.append(task_dist_info_vars)

        obs_var = self.env.observation_space.new_tensor_variable('obs',
                                                                 extra_dims=1)
        action_var = self.env.action_space.new_tensor_variable('action',
                                                               extra_dims=1)
        advantage_var = tensor_utils.new_tensor('advantage',
                                                ndim=1,
                                                dtype=tf.float32)

        old_dist_info_vars = {
            k: tf.placeholder(tf.float32,
                              shape=[None] + list(shape),
                              name='old_%s' % k)
            for k, shape in self.policy.distribution.dist_info_specs
        }

        old_dist_info_vars_list = [
            old_dist_info_vars[k]
            for k in self.policy.distribution.dist_info_keys
        ]

        input_list = ([obs_var, action_var, advantage_var] +
                      old_dist_info_vars_list + all_obs_vars)

        ###############################
        #
        # Local Policy Optimization
        #
        ###############################

        self.optimizers = []
        self.metrics = []

        for n, policy in enumerate(self.local_policies):

            state_info_vars = dict()
            dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars)
            dist = policy.distribution

            kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
            lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                           dist_info_vars)
            surr_loss = -tf.reduce_mean(lr * advantage_var)

            if self.constrain_together:
                additional_loss = Metrics.kl_on_others(
                    n, dist, all_task_dist_info_vars)
            else:
                additional_loss = tf.constant(0.0)

            local_loss = surr_loss + self.penalty * additional_loss

            kl_metric = tensor_utils.compile_function(inputs=input_list,
                                                      outputs=additional_loss,
                                                      log_name="KLPenalty%d" % n)
            self.metrics.append(kl_metric)

            mean_kl_constraint = tf.reduce_mean(kl)

            optimizer = self.optimizer_class(**self.optimizer_args)
            optimizer.update_opt(
                loss=local_loss,
                target=policy,
                leq_constraint=(mean_kl_constraint, self.step_size),
                inputs=input_list,
                constraint_name="mean_kl_%d" % n,
            )
            self.optimizers.append(optimizer)

        ###############################
        #
        # Global Policy Optimization
        #
        ###############################

        # Behaviour Cloning (distillation) loss: regress the center policy's mean
        # action onto the actions fed into action_var (optimize_global_policy
        # supplies the local policies' agent_infos['mean'] here).

        state_info_vars = dict()
        center_dist_info_vars = self.policy.dist_info_sym(
            obs_var, state_info_vars)
        behaviour_cloning_loss = tf.losses.mean_squared_error(
            action_var, center_dist_info_vars['mean'])
        self.center_optimizer = FirstOrderOptimizer(max_epochs=1,
                                                    verbose=True,
                                                    batch_size=1000)
        self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy,
                                         [obs_var, action_var])

        # TRPO Loss

        kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars)
        lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
                                       center_dist_info_vars)
        center_trpo_loss = -tf.reduce_mean(lr * advantage_var)
        mean_kl_constraint = tf.reduce_mean(kl)

        optimizer = self.optimizer_class(**self.optimizer_args)
        optimizer.update_opt(
            loss=center_trpo_loss,
            target=self.policy,
            leq_constraint=(mean_kl_constraint, self.step_size),
            inputs=[obs_var, action_var, advantage_var] +
            old_dist_info_vars_list,
            constraint_name="mean_kl_center",
        )

        self.center_trpo_optimizer = optimizer

        # Reset Local Policies to Global Policy

        assignment_operations = []

        for policy in self.local_policies:
            for param_local, param_center in zip(
                    policy.get_params_internal(),
                    self.policy.get_params_internal()):
                if 'std' not in param_local.name:
                    assignment_operations.append(
                        tf.assign(param_local, param_center))

        self.reset_to_center = tf.group(*assignment_operations)

        return dict()

    def optimize_local_policies(self, itr, all_samples_data):

        dist_info_keys = self.policy.distribution.dist_info_keys
        for n, optimizer in enumerate(self.optimizers):

            obs_act_adv_values = tuple(
                ext.extract(all_samples_data[n], "observations", "actions",
                            "advantages"))
            dist_info_list = tuple([
                all_samples_data[n]["agent_infos"][k] for k in dist_info_keys
            ])
            all_task_obs_values = tuple([
                samples_data["observations"]
                for samples_data in all_samples_data
            ])

            all_input_values = obs_act_adv_values + dist_info_list + all_task_obs_values
            optimizer.optimize(all_input_values)

            kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
            logger.record_tabular('KLPenalty%d' % n, kl_penalty)

    def optimize_global_policy(self, itr, all_samples_data):

        all_observations = np.concatenate([
            samples_data['observations'] for samples_data in all_samples_data
        ])
        all_actions = np.concatenate([
            samples_data['agent_infos']['mean']
            for samples_data in all_samples_data
        ])

        num_itrs = 1 if itr % self.distillation_period != 0 else 30

        for _ in range(num_itrs):
            self.center_optimizer.optimize([all_observations, all_actions])

        paths = self.global_sampler.obtain_samples(itr)
        samples_data = self.global_sampler.process_samples(itr, paths)

        obs_values = tuple(
            ext.extract(samples_data, "observations", "actions", "advantages"))
        dist_info_list = [
            samples_data["agent_infos"][k]
            for k in self.policy.distribution.dist_info_keys
        ]

        all_input_values = obs_values + tuple(dist_info_list)

        self.center_trpo_optimizer.optimize(all_input_values)
        self.env.log_diagnostics(paths)

    @overrides
    def optimize_policy(self, itr, all_samples_data):

        self.optimize_local_policies(itr, all_samples_data)
        self.optimize_global_policy(itr, all_samples_data)

        if itr % self.distillation_period == 0:
            sess = tf.get_default_session()
            sess.run(self.reset_to_center)
            logger.log('Reset Local Policies to Global Policies')

        return dict()
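
The one call in Example #2 whose body is not shown is Metrics.kl_on_others(n, dist, all_task_dist_info_vars). Purely as an illustration of the quantity such a penalty could compute (the real helper may differ), the sketch below averages the symbolic KL between policy n and every other local policy over each task's observation batch, using the layout built in init_opt where all_task_dist_info_vars[i][j] holds policy j's dist_info on task i's observations.

import tensorflow as tf


def kl_on_others_sketch(n, dist, all_task_dist_info_vars):
    """Hypothetical stand-in for Metrics.kl_on_others: the mean KL between
    policy n and the other local policies, averaged over every task's
    observation placeholder. Illustrative only, not the actual helper."""
    kl_terms = []
    for task_dist_info_vars in all_task_dist_info_vars:
        ref = task_dist_info_vars[n]  # policy n's dist_info on this task's obs
        for j, other in enumerate(task_dist_info_vars):
            if j == n:
                continue
            kl_terms.append(tf.reduce_mean(dist.kl_sym(other, ref)))
    return tf.add_n(kl_terms) / float(len(kl_terms))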