Example #1
    def _build_inference_loss(self, i):
        """ Build loss function for the inference network """

        infer_dist = self.inference._dist
        with tf.name_scope("infer_loss"):
            traj_ll_flat = self.inference.log_likelihood_sym(
                i.flat.trajectory_var, i.flat.latent_var, name="traj_ll_flat")
            traj_ll = tf.reshape(traj_ll_flat, [-1, self.max_path_length],
                                 name="traj_ll")

            # Calculate loss
            traj_gammas = tf.constant(float(self.discount),
                                      dtype=tf.float32,
                                      shape=[self.max_path_length])
            traj_discounts = tf.cumprod(traj_gammas,
                                        exclusive=True,
                                        name="traj_discounts")
            discount_traj_ll = traj_discounts * traj_ll
            discount_traj_ll_flat = flatten_batch(discount_traj_ll,
                                                  name="discount_traj_ll_flat")
            discount_traj_ll_valid = filter_valids(
                discount_traj_ll_flat,
                i.flat.valid_var,
                name="discount_traj_ll_valid")

            with tf.name_scope("loss"):
                infer_loss = -tf.reduce_mean(discount_traj_ll_valid,
                                             name="infer_loss")

            with tf.name_scope("kl"):
                # Calculate predicted embedding distributions for each timestep
                infer_dist_info_flat = self.inference.dist_info_sym(
                    i.flat.trajectory_var,
                    i.flat.infer_state_info_vars,
                    name="infer_dist_info_flat")

                infer_dist_info_valid = filter_valids_dict(
                    infer_dist_info_flat,
                    i.flat.valid_var,
                    name="infer_dist_info_valid")

                # Calculate KL divergence
                kl = infer_dist.kl_sym(i.valid.infer_old_dist_info_vars,
                                       infer_dist_info_valid)
                infer_kl = tf.reduce_mean(kl, name="infer_kl")

            return infer_loss, infer_kl
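A quick note on the discounting trick above: filling a vector with a constant gamma and taking an exclusive cumulative product yields the per-timestep factors [1, g, g^2, ...]. The NumPy stand-in below (illustrative names, not part of the example) shows the same computation that tf.cumprod(..., exclusive=True) performs in-graph.

import numpy as np

def traj_discounts(gamma, max_path_length):
    """Per-timestep discount factors [1, g, g**2, ...] for one padded path."""
    gammas = np.full(max_path_length, gamma, dtype=np.float32)
    # Exclusive cumulative product: element t is the product of elements 0..t-1.
    return np.concatenate(([1.0], np.cumprod(gammas)[:-1])).astype(np.float32)

print(traj_discounts(0.99, 5))  # ~ [1.0, 0.99, 0.9801, 0.9703, 0.9606]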
Example #2
    def _build_policy_loss(self, i):
        """ Build policy network loss """
        pol_dist = self.policy._dist

        # Entropy terms
        embedding_entropy, inference_ce, policy_entropy = \
            self._build_entropy_terms(i)

        # Augment the path rewards with entropy terms
        with tf.name_scope("augmented_rewards"):
            rewards = i.reward_var \
                      - (self.inference_ce_coeff * inference_ce) \
                      + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope("policy_loss"):
            with tf.name_scope("advantages"):
                advantages = compute_advantages(self.discount,
                                                self.gae_lambda,
                                                self.max_path_length,
                                                i.baseline_var,
                                                rewards,
                                                name="advantages")

                # Flatten and filter valids
                adv_flat = flatten_batch(advantages, name="adv_flat")
                adv_valid = filter_valids(adv_flat,
                                          i.flat.valid_var,
                                          name="adv_valid")

            policy_dist_info_flat = self.policy.dist_info_sym(
                i.flat.task_var,
                i.flat.obs_var,
                i.flat.policy_state_info_vars,
                name="policy_dist_info_flat")
            policy_dist_info_valid = filter_valids_dict(
                policy_dist_info_flat,
                i.flat.valid_var,
                name="policy_dist_info_valid")

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope("center_adv"):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope("positive_adv"):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            # Calculate loss function and KL divergence
            with tf.name_scope("kl"):
                kl = pol_dist.kl_sym(
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                )
                pol_mean_kl = tf.reduce_mean(kl)

            # Calculate surrogate loss
            with tf.name_scope("surr_loss"):
                lr = pol_dist.likelihood_ratio_sym(
                    i.valid.action_var,
                    i.valid.policy_old_dist_info_vars,
                    policy_dist_info_valid,
                    name="lr")

                # Policy gradient surrogate objective
                surr_vanilla = lr * adv_valid

                if self._pg_loss == PGLoss.VANILLA:
                    # VPG, TRPO use the standard surrogate objective
                    surr_obj = tf.identity(surr_vanilla, name="surr_obj")
                elif self._pg_loss == PGLoss.CLIP:
                    # PPO uses a surrogate objective with clipped LR
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.lr_clip_range,
                                               1 + self.lr_clip_range,
                                               name="lr_clip")
                    surr_clip = lr_clip * adv_valid
                    surr_obj = tf.minimum(surr_vanilla,
                                          surr_clip,
                                          name="surr_obj")
                else:
                    raise NotImplementedError("Unknown PGLoss")

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                surr_loss = -tf.reduce_mean(surr_obj)

                # Embedding entropy bonus
                surr_loss -= self.embedding_ent_coeff * embedding_entropy

            embed_mean_kl = self._build_embedding_kl(i)

        # Diagnostic functions
        self.f_policy_kl = tensor_utils.compile_function(
            flatten_inputs(self._policy_opt_inputs),
            pol_mean_kl,
            log_name="f_policy_kl")

        self.f_rewards = tensor_utils.compile_function(flatten_inputs(
            self._policy_opt_inputs),
                                                       rewards,
                                                       log_name="f_rewards")

        # returns = self._build_returns(rewards)
        returns = discounted_returns(self.discount,
                                     self.max_path_length,
                                     rewards,
                                     name="returns")
        self.f_returns = tensor_utils.compile_function(flatten_inputs(
            self._policy_opt_inputs),
                                                       returns,
                                                       log_name="f_returns")

        return surr_loss, pol_mean_kl, embed_mean_kl
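For reference, the PGLoss.CLIP branch above is the PPO clipped surrogate. The sketch below restates it over bare TF 1.x tensors; lr, adv, and clip_range are assumed stand-ins for the likelihood ratio, the filtered advantages, and self.lr_clip_range, not the garage input bundles used in the example.

import tensorflow as tf

def clipped_surrogate_loss(lr, adv, clip_range):
    """Negated mean of min(lr * A, clip(lr, 1 - eps, 1 + eps) * A)."""
    surr_vanilla = lr * adv
    lr_clip = tf.clip_by_value(lr, 1.0 - clip_range, 1.0 + clip_range)
    surr_clip = lr_clip * adv
    # Element-wise minimum keeps the pessimistic (lower) objective per sample.
    surr_obj = tf.minimum(surr_vanilla, surr_clip)
    # Optimizers minimize, so return the negated mean objective.
    return -tf.reduce_mean(surr_obj)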
Example #3
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution

        policy_entropy = self._build_entropy_term(i)

        with tf.name_scope("augmented_rewards"):
            rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope("policy_loss"):
            advantages = compute_advantages(
                self.discount,
                self.gae_lambda,
                self.max_path_length,
                i.baseline_var,
                rewards,
                name="advantages")

            adv_flat = flatten_batch(advantages, name="adv_flat")
            adv_valid = filter_valids(
                adv_flat, i.flat.valid_var, name="adv_valid")

            if self.policy.recurrent:
                advantages = tf.reshape(advantages, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope("center_adv"):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope("positive_adv"):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name="policy_dist_info")
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name="policy_dist_info_flat")

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name="policy_dist_info_valid")

            # Calculate loss function and KL divergence
            with tf.name_scope("kl"):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate vanilla loss
            with tf.name_scope("vanilla_loss"):
                if self.policy.recurrent:
                    ll = pol_dist.log_likelihood_sym(
                        i.action_var, policy_dist_info, name="log_likelihood")

                    vanilla = ll * advantages * i.valid_var
                else:
                    ll = pol_dist.log_likelihood_sym(
                        i.valid.action_var,
                        policy_dist_info_valid,
                        name="log_likelihood")

                    vanilla = ll * adv_valid

            # Calculate surrogate loss
            with tf.name_scope("surrogate_loss"):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name="lr")

                    surrogate = lr * advantages * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name="lr")

                    surrogate = lr * adv_valid

            # Finalize objective function
            with tf.name_scope("loss"):
                if self._pg_loss == PGLoss.VANILLA:
                    # VPG uses the vanilla objective
                    obj = tf.identity(vanilla, name="vanilla_obj")
                elif self._pg_loss == PGLoss.SURROGATE:
                    # TRPO uses the standard surrogate objective
                    obj = tf.identity(surrogate, name="surr_obj")
                elif self._pg_loss == PGLoss.SURROGATE_CLIP:
                    lr_clip = tf.clip_by_value(
                        lr,
                        1 - self.lr_clip_range,
                        1 + self.lr_clip_range,
                        name="lr_clip")
                    if self.policy.recurrent:
                        surr_clip = lr_clip * advantages * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    obj = tf.minimum(surrogate, surr_clip, name="surr_obj")
                else:
                    raise NotImplementedError("Unknown PGLoss")

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
                else:
                    loss = -tf.reduce_mean(obj)

            # Diagnostic functions
            self.f_policy_kl = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                pol_mean_kl,
                log_name="f_policy_kl")

            self.f_rewards = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                rewards,
                log_name="f_rewards")

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                returns,
                log_name="f_returns")

            return loss, pol_mean_kl
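The recurrent branches above repeatedly reduce tensors as sum(x * valid) / sum(valid), so padded timesteps contribute nothing to the average. A minimal sketch of that masked mean (names are illustrative):

import tensorflow as tf

def masked_mean(x, valid):
    """Mean of x over entries where valid == 1; x and valid have shape [N, T]."""
    # Padded steps are zeroed by the mask and excluded from the denominator.
    return tf.reduce_sum(x * valid) / tf.reduce_sum(valid)

With valid as a 0/1 float mask, this equals the ordinary mean taken over only the valid steps, which is why the recurrent KL and loss reductions above divide by tf.reduce_sum(i.valid_var).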
Example #4
    def _build_inputs(self):
        """
        Builds input variables (and trivial views thereof) for the loss
        function network
        """

        observation_space = self.policy.observation_space
        action_space = self.policy.action_space
        task_space = self.policy.task_space
        latent_space = self.policy.latent_space
        trajectory_space = self.inference.input_space

        policy_dist = self.policy._dist
        embed_dist = self.policy.embedding._dist
        infer_dist = self.inference._dist

        with tf.name_scope("inputs"):
            obs_var = observation_space.new_tensor_variable(
                'obs',
                extra_dims=1 + 1,
            )

            task_var = task_space.new_tensor_variable(
                'task',
                extra_dims=1 + 1,
            )

            action_var = action_space.new_tensor_variable(
                'action',
                extra_dims=1 + 1,
            )

            reward_var = tensor_utils.new_tensor(
                'reward',
                ndim=1 + 1,
                dtype=tf.float32,
            )

            latent_var = latent_space.new_tensor_variable(
                'latent',
                extra_dims=1 + 1,
            )

            baseline_var = tensor_utils.new_tensor(
                'baseline',
                ndim=1 + 1,
                dtype=tf.float32,
            )

            trajectory_var = trajectory_space.new_tensor_variable(
                'trajectory',
                extra_dims=1 + 1,
            )

            valid_var = tf.placeholder(tf.float32,
                                       shape=[None, None],
                                       name="valid")

            # Policy state (for RNNs)
            policy_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

            # Old policy distribution (for KL)
            policy_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='policy_old_%s' % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            # Embedding state (for RNNs)
            embed_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='embed_%s' % k)
                for k, shape in self.policy.embedding.state_info_specs
            }
            embed_state_info_vars_list = [
                embed_state_info_vars[k]
                for k in self.policy.embedding.state_info_keys
            ]

            # Old embedding distribution (for KL)
            embed_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='embed_old_%s' % k)
                for k, shape in embed_dist.dist_info_specs
            }
            embed_old_dist_info_vars_list = [
                embed_old_dist_info_vars[k] for k in embed_dist.dist_info_keys
            ]

            # Inference distribution state (for RNNs)
            infer_state_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='infer_%s' % k)
                for k, shape in self.inference.state_info_specs
            }
            infer_state_info_vars_list = [
                infer_state_info_vars[k]
                for k in self.inference.state_info_keys
            ]

            # Old inference distribution (for KL)
            infer_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * (1 + 1) + list(shape),
                                  name='infer_old_%s' % k)
                for k, shape in infer_dist.dist_info_specs
            }
            infer_old_dist_info_vars_list = [
                infer_old_dist_info_vars[k] for k in infer_dist.dist_info_keys
            ]

            # Flattened view
            with tf.name_scope("flat"):
                obs_flat = flatten_batch(obs_var, name="obs_flat")
                task_flat = flatten_batch(task_var, name="task_flat")
                action_flat = flatten_batch(action_var, name="action_flat")
                reward_flat = flatten_batch(reward_var, name="reward_flat")
                latent_flat = flatten_batch(latent_var, name="latent_flat")
                trajectory_flat = flatten_batch(trajectory_var,
                                                name="trajectory_flat")
                valid_flat = flatten_batch(valid_var, name="valid_flat")
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars, name="policy_state_info_vars_flat")
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name="policy_old_dist_info_vars_flat")
                embed_state_info_vars_flat = flatten_batch_dict(
                    embed_state_info_vars, name="embed_state_info_vars_flat")
                embed_old_dist_info_vars_flat = flatten_batch_dict(
                    embed_old_dist_info_vars,
                    name="embed_old_dist_info_vars_flat")
                infer_state_info_vars_flat = flatten_batch_dict(
                    infer_state_info_vars, name="infer_state_info_vars_flat")
                infer_old_dist_info_vars_flat = flatten_batch_dict(
                    infer_old_dist_info_vars,
                    name="infer_old_dist_info_vars_flat")

            # Valid view
            with tf.name_scope("valid"):
                action_valid = filter_valids(action_flat,
                                             valid_flat,
                                             name="action_valid")
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name="policy_state_info_vars_valid")
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name="policy_old_dist_info_vars_valid")
                embed_old_dist_info_vars_valid = filter_valids_dict(
                    embed_old_dist_info_vars_flat,
                    valid_flat,
                    name="embed_old_dist_info_vars_valid")
                infer_old_dist_info_vars_valid = filter_valids_dict(
                    infer_old_dist_info_vars_flat,
                    valid_flat,
                    name="infer_old_dist_info_vars_valid")

        # Policy and embedding network loss and optimizer inputs
        pol_flat = graph_inputs(
            "PolicyLossInputsFlat",
            obs_var=obs_flat,
            task_var=task_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            latent_var=latent_flat,
            trajectory_var=trajectory_flat,
            valid_var=valid_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
            embed_state_info_vars=embed_state_info_vars_flat,
            embed_old_dist_info_vars=embed_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            "PolicyLossInputsValid",
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
            embed_old_dist_info_vars=embed_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            "PolicyLossInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            trajectory_var=trajectory_var,
            task_var=task_var,
            latent_var=latent_var,
            valid_var=valid_var,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            embed_state_info_vars=embed_state_info_vars,
            embed_old_dist_info_vars=embed_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        # Special variant for the optimizer
        # * Uses lists instead of dicts for the distribution parameters
        # * Omits flats and valids
        # TODO: eliminate
        policy_opt_inputs = graph_inputs(
            "PolicyOptInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            trajectory_var=trajectory_var,
            task_var=task_var,
            latent_var=latent_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
            embed_state_info_vars_list=embed_state_info_vars_list,
            embed_old_dist_info_vars_list=embed_old_dist_info_vars_list,
        )

        # Inference network loss and optimizer inputs
        infer_flat = graph_inputs(
            "InferenceLossInputsFlat",
            latent_var=latent_flat,
            trajectory_var=trajectory_flat,
            valid_var=valid_flat,
            infer_state_info_vars=infer_state_info_vars_flat,
            infer_old_dist_info_vars=infer_old_dist_info_vars_flat,
        )
        infer_valid = graph_inputs(
            "InferenceLossInputsValid",
            infer_old_dist_info_vars=infer_old_dist_info_vars_valid,
        )
        inference_loss_inputs = graph_inputs(
            "InferenceLossInputs",
            latent_var=latent_var,
            trajectory_var=trajectory_var,
            valid_var=valid_var,
            infer_state_info_vars=infer_state_info_vars,
            infer_old_dist_info_vars=infer_old_dist_info_vars,
            flat=infer_flat,
            valid=infer_valid,
        )
        # Special variant for the optimizer
        # * Uses lists instead of dicts for the distribution parameters
        # * Omits flats and valids
        # TODO: eliminate
        inference_opt_inputs = graph_inputs(
            "InferenceOptInputs",
            latent_var=latent_var,
            trajectory_var=trajectory_var,
            valid_var=valid_var,
            infer_state_info_vars_list=infer_state_info_vars_list,
            infer_old_dist_info_vars_list=infer_old_dist_info_vars_list,
        )

        return (policy_loss_inputs, policy_opt_inputs, inference_loss_inputs,
                inference_opt_inputs)
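The flat and valid views above depend on flatten_batch, flatten_batch_dict, filter_valids, and filter_valids_dict, whose implementations are not shown here. A plausible minimal version, assuming the usual [N, T, ...] -> [N*T, ...] reshape and a 0/1 valid mask (an assumption, not the library's actual code):

import tensorflow as tf

def flatten_batch_sketch(t, name=None):
    """Collapse the leading (path, timestep) dims: [N, T, ...] -> [N * T, ...]."""
    return tf.reshape(t, [-1] + t.get_shape().as_list()[2:], name=name)

def filter_valids_sketch(t, valid, name=None):
    """Keep rows of an already-flattened tensor whose valid flag is nonzero."""
    return tf.boolean_mask(t, tf.cast(valid, tf.bool), name=name)

def flatten_batch_dict_sketch(d, name=None):
    """Apply the flatten sketch to every value of a dict of tensors."""
    with tf.name_scope(name, 'flatten_batch_dict'):
        return {k: flatten_batch_sketch(v) for k, v in d.items()}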
Example #5
    def _build_inputs(self):
        """Decalre graph inputs variables."""
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space
        policy_dist = self.policy.distribution

        with tf.name_scope('inputs'):
            obs_var = observation_space.to_tf_placeholder(
                name='obs',
                batch_dims=2)   # yapf: disable
            action_var = action_space.to_tf_placeholder(
                name='action',
                batch_dims=2)   # yapf: disable
            reward_var = tensor_utils.new_tensor(
                name='reward',
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            valid_var = tensor_utils.new_tensor(
                name='valid',
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            feat_diff = tensor_utils.new_tensor(
                name='feat_diff',
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            param_v = tensor_utils.new_tensor(
                name='param_v',
                ndim=1,
                dtype=tf.float32)   # yapf: disable
            param_eta = tensor_utils.new_tensor(
                name='param_eta',
                ndim=0,
                dtype=tf.float32)   # yapf: disable
            policy_state_info_vars = {
                k: tf.placeholder(
                    tf.float32,
                    shape=[None] * 2 + list(shape),
                    name=k)
                for k, shape in self.policy.state_info_specs
            }   # yapf: disable
            policy_state_info_vars_list = [
                policy_state_info_vars[k]
                for k in self.policy.state_info_keys
            ]   # yapf: disable

            policy_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * 2 + list(shape),
                                  name='policy_old_%s' % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            with tf.name_scope('flat'):
                obs_flat = flatten_batch(obs_var, name='obs_flat')
                action_flat = flatten_batch(action_var, name='action_flat')
                reward_flat = flatten_batch(reward_var, name='reward_flat')
                valid_flat = flatten_batch(valid_var, name='valid_flat')
                feat_diff_flat = flatten_batch(
                    feat_diff,
                    name='feat_diff_flat')  # yapf: disable
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars,
                    name='policy_state_info_vars_flat')  # yapf: disable
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name='policy_old_dist_info_vars_flat')

            with tf.name_scope('valid'):
                reward_valid = filter_valids(
                    reward_flat,
                    valid_flat,
                    name='reward_valid')   # yapf: disable
                action_valid = filter_valids(
                    action_flat,
                    valid_flat,
                    name='action_valid')    # yapf: disable
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name='policy_state_info_vars_valid')
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name='policy_old_dist_info_vars_valid')

        pol_flat = graph_inputs(
            'PolicyLossInputsFlat',
            obs_var=obs_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            valid_var=valid_flat,
            feat_diff=feat_diff_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            'PolicyLossInputsValid',
            reward_var=reward_valid,
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            'PolicyLossInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        policy_opt_inputs = graph_inputs(
            'PolicyOptInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )
        dual_opt_inputs = graph_inputs(
            'DualOptInputs',
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
Example #6
    def _build_inputs(self):
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space

        policy_dist = self.policy.distribution

        with tf.name_scope("inputs"):
            obs_var = observation_space.new_tensor_variable(
                name="obs", extra_dims=2)
            action_var = action_space.new_tensor_variable(
                name="action", extra_dims=2)
            reward_var = tensor_utils.new_tensor(
                name="reward", ndim=2, dtype=tf.float32)
            valid_var = tf.placeholder(
                tf.float32, shape=[None, None], name="valid")
            baseline_var = tensor_utils.new_tensor(
                name="baseline", ndim=2, dtype=tf.float32)

            policy_state_info_vars = {
                k: tf.placeholder(
                    tf.float32, shape=[None] * 2 + list(shape), name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

            # old policy distribution
            policy_old_dist_info_vars = {
                k: tf.placeholder(
                    tf.float32,
                    shape=[None] * 2 + list(shape),
                    name="policy_old_%s" % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            # flattened view
            with tf.name_scope("flat"):
                obs_flat = flatten_batch(obs_var, name="obs_flat")
                action_flat = flatten_batch(action_var, name="action_flat")
                reward_flat = flatten_batch(reward_var, name="reward_flat")
                valid_flat = flatten_batch(valid_var, name="valid_flat")
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars, name="policy_state_info_vars_flat")
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name="policy_old_dist_info_vars_flat")

            # valid view
            with tf.name_scope("valid"):
                action_valid = filter_valids(
                    action_flat, valid_flat, name="action_valid")
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name="policy_state_info_vars_valid")
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name="policy_old_dist_info_vars_valid")

        # policy loss and optimizer inputs
        pol_flat = graph_inputs(
            "PolicyLossInputsFlat",
            obs_var=obs_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            valid_var=valid_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            "PolicyLossInputsValid",
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            "PolicyLossInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        policy_opt_inputs = graph_inputs(
            "PolicyOptInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs
Example #7
    def _build_graph(self):

        self.policy = self.policy_cls(**self.policy_args,
                                      name='ddopg_model_policy')

        self.var_list = self.policy.get_params(trainable=True)
        self.var_shapes = [var.shape for var in self.var_list]
        self.n_params = sum(
            [shape.num_elements() for shape in self.var_shapes])

        self.policy_dist = self.policy.distribution
        self.policy_params_shapes = [
            param.shape for param in self.policy.get_params(trainable=True)
        ]

        observation_space = self.policy.observation_space
        action_space = self.policy.action_space

        self.obs_var = observation_space.new_tensor_variable(
            name='obs', extra_dims=2)  # (n_paths, H, Dy)
        self.action_var = action_space.new_tensor_variable(
            name='action', extra_dims=2)  # (n_paths, H, Du)
        self.path_return_var = tf.placeholder(tf.float32, [None],
                                              'path_return')  # (n_paths, )
        self.valid_var = tf.placeholder(tf.float32, [None, None],
                                        'valid')  # (n_paths, H)
        self.train_logps = tf.placeholder(
            dtype=tf.float32, shape=[None, None],
            name='train_logps_pre')  # (n_train, n_paths)

        self.log_std_var = tf.placeholder(dtype=tf.float32,
                                          shape=[],
                                          name='log_std')
        self.delta_var = tf.placeholder(dtype=tf.float32,
                                        shape=[],
                                        name='delta')

        self.input_vars = [
            self.obs_var, self.action_var, self.path_return_var,
            self.valid_var, self.train_logps, self.delta_var, self.log_std_var
        ]

        # Flatten observation and actions for vectorized computations
        self.obs_flat = flatten_batch(self.obs_var,
                                      name='obs_flat')  # (n_paths * H, Dy)
        self.action_flat = flatten_batch(
            self.action_var, name='action_flat')  # (n_paths * H, Du)

        # Shape of the rollout batch: (# of paths, path horizon) = (n_paths, H)
        self.batch_shape = tf.shape(self.obs_var)[0:2]

        # Compute log-likelihoods of the actions under the current (test) policy
        dist_info_flat = self.policy.dist_info_sym(self.obs_flat,
                                                   name='dist_info_flat')
        dist_info_flat['log_std'] = self.log_std_var * tf.ones_like(
            dist_info_flat['mean'])

        test_logp_flat = self.policy_dist.log_likelihood_sym(self.action_flat,
                                                             dist_info_flat,
                                                             name='logp_flat')
        test_logp_full = tf.reshape(test_logp_flat,
                                    self.batch_shape)  # (n_paths, H)
        self.test_logps = tf.reduce_sum(test_logp_full * self.valid_var,
                                        axis=1)[None, :]

        self.all_logps = tf.concat(
            (self.train_logps, self.test_logps),
            axis=0)  # (n_train + n_test, n_paths)

        # Prevent exp() overflow by shifting logps
        self.logp_max = tf.reduce_max(self.all_logps, axis=0)  # (n_paths, )

        self.train_logps_0 = self.train_logps - self.logp_max  # (n_train, n_paths)
        self.test_logps_0 = self.test_logps - self.logp_max  # (n_paths, )
        self.train_liks = tf.exp(self.train_logps_0)  # (n_train, n_paths)
        self.test_liks = tf.exp(self.test_logps_0)  # (n_paths, )

        # Mean traj lik for empirical mixture distribution
        self.train_mean_liks = tf.reduce_mean(self.train_liks,
                                              axis=0) + self.eps  # (n_paths, )

        # Compute prediction for all training policies
        train_res = self._compute_prediction_vec(self.train_liks)
        self.J_train = train_res[0]
        self.J2_train = train_res[1]
        self.J_var_train = train_res[2]
        self.J_unc_train = train_res[3]
        self.w_train = train_res[4]
        self.wn_train = train_res[5]
        self.ess_train = train_res[6]

        # Compute prediction for all test policies
        test_res = self._compute_prediction_vec(self.test_liks)
        self.J_test = test_res[0]
        self.J2_test = test_res[1]
        self.J_var_test = test_res[2]
        self.J_unc_test = test_res[3]
        self.w_test = test_res[4]
        self.wn_test = test_res[5]
        self.ess_test = test_res[6]
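The log-probability shift above is a standard overflow guard: subtracting the per-path maximum before exp() changes neither the importance ratios nor their normalization, because the shift cancels between numerator and denominator. The NumPy sketch below illustrates the idea together with the usual effective-sample-size formula; whether _compute_prediction_vec uses exactly these expressions is an assumption.

import numpy as np

def importance_weights(test_logps, train_logps, eps=1e-12):
    """test_logps: [n_paths], train_logps: [n_train, n_paths] path log-likelihoods."""
    # Shift by the per-path maximum so exp() cannot overflow; the shift cancels
    # when the test likelihood is divided by the mixture likelihood.
    logp_max = np.max(np.vstack([train_logps, test_logps]), axis=0)
    test_liks = np.exp(test_logps - logp_max)
    train_mean_liks = np.exp(train_logps - logp_max).mean(axis=0) + eps
    w = test_liks / train_mean_liks        # unnormalized importance weights
    wn = w / w.sum()                       # self-normalized weights
    ess = 1.0 / np.sum(wn ** 2)            # effective sample size
    return w, wn, ess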
Example #8
    def _build_inputs(self):
        """Decalre graph inputs variables."""
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space
        policy_dist = self.policy.distribution

        with tf.name_scope("inputs"):
            obs_var = observation_space.new_tensor_variable(
                name="obs",
                extra_dims=2)   # yapf: disable
            action_var = action_space.new_tensor_variable(
                name="action",
                extra_dims=2)   # yapf: disable
            reward_var = tensor_utils.new_tensor(
                name="reward",
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            valid_var = tensor_utils.new_tensor(
                name="valid",
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            feat_diff = tensor_utils.new_tensor(
                name="feat_diff",
                ndim=2,
                dtype=tf.float32)   # yapf: disable
            param_v = tensor_utils.new_tensor(
                name="param_v",
                ndim=1,
                dtype=tf.float32)   # yapf: disable
            param_eta = tensor_utils.new_tensor(
                name="param_eta",
                ndim=0,
                dtype=tf.float32)   # yapf: disable
            policy_state_info_vars = {
                k: tf.placeholder(
                    tf.float32,
                    shape=[None] * 2 + list(shape),
                    name=k)
                for k, shape in self.policy.state_info_specs
            }   # yapf: disable
            policy_state_info_vars_list = [
                policy_state_info_vars[k]
                for k in self.policy.state_info_keys
            ]   # yapf: disable

            policy_old_dist_info_vars = {
                k: tf.placeholder(tf.float32,
                                  shape=[None] * 2 + list(shape),
                                  name="policy_old_%s" % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            with tf.name_scope("flat"):
                obs_flat = flatten_batch(obs_var, name="obs_flat")
                action_flat = flatten_batch(action_var, name="action_flat")
                reward_flat = flatten_batch(reward_var, name="reward_flat")
                valid_flat = flatten_batch(valid_var, name="valid_flat")
                feat_diff_flat = flatten_batch(
                    feat_diff,
                    name="feat_diff_flat")  # yapf: disable
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars,
                    name="policy_state_info_vars_flat")  # yapf: disable
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name="policy_old_dist_info_vars_flat")

            with tf.name_scope("valid"):
                reward_valid = filter_valids(
                    reward_flat,
                    valid_flat,
                    name="reward_valid")   # yapf: disable
                action_valid = filter_valids(
                    action_flat,
                    valid_flat,
                    name="action_valid")    # yapf: disable
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name="policy_state_info_vars_valid")
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name="policy_old_dist_info_vars_valid")

        pol_flat = graph_inputs(
            "PolicyLossInputsFlat",
            obs_var=obs_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            valid_var=valid_flat,
            feat_diff=feat_diff_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            "PolicyLossInputsValid",
            reward_var=reward_valid,
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            "PolicyLossInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        policy_opt_inputs = graph_inputs(
            "PolicyOptInputs",
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )
        dual_opt_inputs = graph_inputs(
            "DualOptInputs",
            reward_var=reward_var,
            valid_var=valid_var,
            feat_diff=feat_diff,
            param_eta=param_eta,
            param_v=param_v,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs, dual_opt_inputs
Example #9
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution
        policy_entropy = self._build_entropy_term(i)
        rewards = i.reward_var

        if self._maximum_entropy:
            with tf.name_scope('augmented_rewards'):
                rewards = i.reward_var + self.policy_ent_coeff * policy_entropy

        with tf.name_scope('policy_loss'):
            adv = compute_advantages(self.discount,
                                     self.gae_lambda,
                                     self.max_path_length,
                                     i.baseline_var,
                                     rewards,
                                     name='adv')

            adv_flat = flatten_batch(adv, name='adv_flat')
            adv_valid = filter_valids(adv_flat,
                                      i.flat.valid_var,
                                      name='adv_valid')

            if self.policy.recurrent:
                adv = tf.reshape(adv, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                if self.policy.recurrent:
                    adv = center_advs(adv, axes=[0], eps=eps)
                else:
                    adv_valid = center_advs(adv_valid, axes=[0], eps=eps)

            if self.positive_adv:
                if self.policy.recurrent:
                    adv = positive_advs(adv, eps)
                else:
                    adv_valid = positive_advs(adv_valid, eps)

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name='policy_dist_info')
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name='policy_dist_info_flat')

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name='policy_dist_info_valid')

                policy_dist_info = policy_dist_info_valid

            # Calculate loss function and KL divergence
            with tf.name_scope('kl'):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate vanilla loss
            with tf.name_scope('vanilla_loss'):
                if self.policy.recurrent:
                    ll = pol_dist.log_likelihood_sym(i.action_var,
                                                     policy_dist_info,
                                                     name='log_likelihood')

                    vanilla = ll * adv * i.valid_var
                else:
                    ll = pol_dist.log_likelihood_sym(i.valid.action_var,
                                                     policy_dist_info_valid,
                                                     name='log_likelihood')

                    vanilla = ll * adv_valid

            # Calculate surrogate loss
            with tf.name_scope('surrogate_loss'):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name='lr')

                    surrogate = lr * adv * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name='lr')

                    surrogate = lr * adv_valid

            # Finalize objective function
            with tf.name_scope('loss'):
                if self._pg_loss == 'vanilla':
                    # VPG uses the vanilla objective
                    obj = tf.identity(vanilla, name='vanilla_obj')
                elif self._pg_loss == 'surrogate':
                    # TRPO uses the standard surrogate objective
                    obj = tf.identity(surrogate, name='surr_obj')
                elif self._pg_loss == 'surrogate_clip':
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.lr_clip_range,
                                               1 + self.lr_clip_range,
                                               name='lr_clip')
                    if self.policy.recurrent:
                        surr_clip = lr_clip * adv * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    obj = tf.minimum(surrogate, surr_clip, name='surr_obj')

                if self._entropy_regularzied:
                    obj += self.policy_ent_coeff * policy_entropy

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    loss = -tf.reduce_sum(obj) / tf.reduce_sum(i.valid_var)
                else:
                    loss = -tf.reduce_mean(obj)

            # Diagnostic functions
            self.f_policy_kl = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                pol_mean_kl,
                log_name='f_policy_kl')

            self.f_rewards = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                rewards,
                log_name='f_rewards')

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = tensor_utils.compile_function(
                flatten_inputs(self._policy_opt_inputs),
                returns,
                log_name='f_returns')

            return loss, pol_mean_kl
Example #10
    def _build_inputs(self):
        observation_space = self.policy.observation_space
        action_space = self.policy.action_space

        policy_dist = self.policy.distribution

        with tf.name_scope('inputs'):
            if self.flatten_input:
                obs_var = tf.compat.v1.placeholder(
                    tf.float32,
                    shape=[None, None, observation_space.flat_dim],
                    name='obs')
            else:
                obs_var = observation_space.to_tf_placeholder(name='obs',
                                                              batch_dims=2)
            action_var = action_space.to_tf_placeholder(name='action',
                                                        batch_dims=2)
            reward_var = tensor_utils.new_tensor(name='reward',
                                                 ndim=2,
                                                 dtype=tf.float32)
            valid_var = tf.compat.v1.placeholder(tf.float32,
                                                 shape=[None, None],
                                                 name='valid')
            baseline_var = tensor_utils.new_tensor(name='baseline',
                                                   ndim=2,
                                                   dtype=tf.float32)

            policy_state_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name=k)
                for k, shape in self.policy.state_info_specs
            }
            policy_state_info_vars_list = [
                policy_state_info_vars[k] for k in self.policy.state_info_keys
            ]

            # old policy distribution
            policy_old_dist_info_vars = {
                k: tf.compat.v1.placeholder(tf.float32,
                                            shape=[None] * 2 + list(shape),
                                            name='policy_old_%s' % k)
                for k, shape in policy_dist.dist_info_specs
            }
            policy_old_dist_info_vars_list = [
                policy_old_dist_info_vars[k]
                for k in policy_dist.dist_info_keys
            ]

            # flattened view
            with tf.name_scope('flat'):
                obs_flat = flatten_batch(obs_var, name='obs_flat')
                action_flat = flatten_batch(action_var, name='action_flat')
                reward_flat = flatten_batch(reward_var, name='reward_flat')
                valid_flat = flatten_batch(valid_var, name='valid_flat')
                policy_state_info_vars_flat = flatten_batch_dict(
                    policy_state_info_vars, name='policy_state_info_vars_flat')
                policy_old_dist_info_vars_flat = flatten_batch_dict(
                    policy_old_dist_info_vars,
                    name='policy_old_dist_info_vars_flat')

            # valid view
            with tf.name_scope('valid'):
                action_valid = filter_valids(action_flat,
                                             valid_flat,
                                             name='action_valid')
                policy_state_info_vars_valid = filter_valids_dict(
                    policy_state_info_vars_flat,
                    valid_flat,
                    name='policy_state_info_vars_valid')
                policy_old_dist_info_vars_valid = filter_valids_dict(
                    policy_old_dist_info_vars_flat,
                    valid_flat,
                    name='policy_old_dist_info_vars_valid')

        # policy loss and optimizer inputs
        pol_flat = graph_inputs(
            'PolicyLossInputsFlat',
            obs_var=obs_flat,
            action_var=action_flat,
            reward_var=reward_flat,
            valid_var=valid_flat,
            policy_state_info_vars=policy_state_info_vars_flat,
            policy_old_dist_info_vars=policy_old_dist_info_vars_flat,
        )
        pol_valid = graph_inputs(
            'PolicyLossInputsValid',
            action_var=action_valid,
            policy_state_info_vars=policy_state_info_vars_valid,
            policy_old_dist_info_vars=policy_old_dist_info_vars_valid,
        )
        policy_loss_inputs = graph_inputs(
            'PolicyLossInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars=policy_state_info_vars,
            policy_old_dist_info_vars=policy_old_dist_info_vars,
            flat=pol_flat,
            valid=pol_valid,
        )
        policy_opt_inputs = graph_inputs(
            'PolicyOptInputs',
            obs_var=obs_var,
            action_var=action_var,
            reward_var=reward_var,
            baseline_var=baseline_var,
            valid_var=valid_var,
            policy_state_info_vars_list=policy_state_info_vars_list,
            policy_old_dist_info_vars_list=policy_old_dist_info_vars_list,
        )

        return policy_loss_inputs, policy_opt_inputs
Example #11
    def _build_policy_loss(self, i):
        pol_dist = self.policy.distribution

        policy_entropy = self._build_entropy_term(i)

        with tf.name_scope('augmented_rewards'):
            rewards = i.reward_var + (self.policy_ent_coeff * policy_entropy)

        with tf.name_scope('policy_loss'):
            advantages = compute_advantages(self.discount,
                                            self.gae_lambda,
                                            self.max_path_length,
                                            i.baseline_var,
                                            rewards,
                                            name='advantages')

            adv_flat = flatten_batch(advantages, name='adv_flat')
            adv_valid = filter_valids(adv_flat,
                                      i.flat.valid_var,
                                      name='adv_valid')

            if self.policy.recurrent:
                advantages = tf.reshape(advantages, [-1, self.max_path_length])

            # Optionally normalize advantages
            eps = tf.constant(1e-8, dtype=tf.float32)
            if self.center_adv:
                with tf.name_scope('center_adv'):
                    mean, var = tf.nn.moments(adv_valid, axes=[0])
                    adv_valid = tf.nn.batch_normalization(
                        adv_valid, mean, var, 0, 1, eps)
            if self.positive_adv:
                with tf.name_scope('positive_adv'):
                    m = tf.reduce_min(adv_valid)
                    adv_valid = (adv_valid - m) + eps

            if self.policy.recurrent:
                policy_dist_info = self.policy.dist_info_sym(
                    i.obs_var,
                    i.policy_state_info_vars,
                    name='policy_dist_info')
            else:
                policy_dist_info_flat = self.policy.dist_info_sym(
                    i.flat.obs_var,
                    i.flat.policy_state_info_vars,
                    name='policy_dist_info_flat')

                policy_dist_info_valid = filter_valids_dict(
                    policy_dist_info_flat,
                    i.flat.valid_var,
                    name='policy_dist_info_valid')

            # Calculate loss function and KL divergence
            with tf.name_scope('kl'):
                if self.policy.recurrent:
                    kl = pol_dist.kl_sym(
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                    )
                    pol_mean_kl = tf.reduce_sum(
                        kl * i.valid_var) / tf.reduce_sum(i.valid_var)
                else:
                    kl = pol_dist.kl_sym(
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                    )
                    pol_mean_kl = tf.reduce_mean(kl)

            # Calculate surrogate loss
            with tf.name_scope('surr_loss'):
                if self.policy.recurrent:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.action_var,
                        i.policy_old_dist_info_vars,
                        policy_dist_info,
                        name='lr')

                    surr_vanilla = lr * advantages * i.valid_var
                else:
                    lr = pol_dist.likelihood_ratio_sym(
                        i.valid.action_var,
                        i.valid.policy_old_dist_info_vars,
                        policy_dist_info_valid,
                        name='lr')

                    surr_vanilla = lr * adv_valid

                if self._pg_loss == PGLoss.VANILLA:
                    # VPG, TRPO use the standard surrogate objective
                    surr_obj = tf.identity(surr_vanilla, name='surr_obj')
                elif self._pg_loss == PGLoss.CLIP:
                    lr_clip = tf.clip_by_value(lr,
                                               1 - self.clip_range,
                                               1 + self.clip_range,
                                               name='lr_clip')
                    if self.policy.recurrent:
                        surr_clip = lr_clip * advantages * i.valid_var
                    else:
                        surr_clip = lr_clip * adv_valid
                    surr_obj = tf.minimum(surr_vanilla,
                                          surr_clip,
                                          name='surr_obj')
                else:
                    raise NotImplementedError('Unknown PGLoss')

                # Maximize E[surrogate objective] by minimizing
                # -E_t[surrogate objective]
                if self.policy.recurrent:
                    surr_loss = -tf.reduce_sum(surr_obj) / tf.reduce_sum(
                        i.valid_var)
                else:
                    surr_loss = -tf.reduce_mean(surr_obj)

            # Diagnostic functions
            self.f_policy_kl = compile_function(flatten_inputs(
                self._policy_opt_inputs),
                                                pol_mean_kl,
                                                log_name='f_policy_kl')

            self.f_rewards = compile_function(flatten_inputs(
                self._policy_opt_inputs),
                                              rewards,
                                              log_name='f_rewards')

            returns = discounted_returns(self.discount, self.max_path_length,
                                         rewards)
            self.f_returns = compile_function(flatten_inputs(
                self._policy_opt_inputs),
                                              returns,
                                              log_name='f_returns')

            return surr_loss, pol_mean_kl
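Finally, the discounted_returns op used for the f_returns diagnostic computes, per path, R_t = r_t + gamma * R_{t+1} over padded [N, T] rewards. A minimal NumPy sketch of that recurrence for a single reward row (illustrative only, not the in-graph op):

import numpy as np

def discounted_returns_sketch(rewards, gamma):
    """Return R with R[t] = rewards[t] + gamma * R[t + 1] and R[T] = 0."""
    returns = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    # Scan backwards so each step reuses the already-discounted tail.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns_sketch(np.array([1.0, 1.0, 1.0]), 0.9))  # [2.71 1.9  1.  ]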