Example #1
 def call(self, y_true, y_pred):
     sample_weight = tf.gather_nd(self.class_weight,
                                  tf.cast(y_true, tf.int32))
     losses = super(SparseCategoricalCrossentropy,
                    self).call(y_true, y_pred)
     return losses_utils.compute_weighted_loss(
         losses, sample_weight, reduction=self._get_reduction())
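The subclass above relies on Keras internals (`losses_utils`, `_get_reduction`). A minimal sketch of the same class-weighting idea using only the public API, with made-up weights and labels, might look like this:

import tensorflow as tf

# Hypothetical per-class weights and a toy batch (my assumptions).
class_weight = tf.constant([1.0, 2.0, 0.5])   # one weight per class
y_true = tf.constant([[0], [2], [1]])         # sparse integer labels, shape [3, 1]
y_pred = tf.random.uniform((3, 3))            # logits, shape [3, num_classes]

# Gather one weight per example from its class id, exactly as in the snippet.
sample_weight = tf.gather_nd(class_weight, tf.cast(y_true, tf.int32))

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)  # weighted scalar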
Example #2
  def __call__(self, y_true, y_pred, sample_weight=None):
    """Invokes the `Loss` instance.

    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
        coefficient for the loss. If a scalar is provided, then the loss is
        simply scaled by the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the total loss for each sample of the batch is
        rescaled by the corresponding element in the `sample_weight` vector. If
        the shape of `sample_weight` matches the shape of `y_pred`, then the
        loss of each measurable element of `y_pred` is scaled by the
        corresponding value of `sample_weight`.

    Returns:
      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
        shape as `y_true`; otherwise, it is scalar.

    Raises:
      ValueError: If the shape of `sample_weight` is invalid.
    """
    with ops.name_scope(self.name, format(self.__class__.__name__),
                        (y_pred, y_true, sample_weight)):
      losses = self.call(y_true, y_pred)
      return compute_weighted_loss(
          losses, sample_weight, reduction=self.reduction)
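To make the `sample_weight` rules in the docstring concrete, here is a small demonstration (my addition, public Keras API) of the scalar and per-sample cases, plus the `NONE` reduction shape:

import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()
y_true = tf.ones((4, 2))
y_pred = tf.zeros((4, 2))

mse(y_true, y_pred, sample_weight=0.5)               # scalar: loss scaled by 0.5
mse(y_true, y_pred, sample_weight=tf.ones((4,)))     # one weight per sample

mse_none = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
per_sample = mse_none(y_true, y_pred)                # shape [4], not a scalar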
Example #3
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs.pop("lr", None)
        self.history.setdefault("lr",
                                []).append(float(self.model.optimizer.lr))
        for k, v in logs.items():
            k = "accuracy" if "accuracy" in k else k
            self.history.setdefault(k, []).append(float(v))
        for ee in self.evals:
            self.history.setdefault(ee.test_names,
                                    []).append(float(ee.cur_acc))
            self.history.setdefault(ee.test_names + "_thresh",
                                    []).append(float(ee.acc_thresh))
        for kk, vv in self.custom_obj.items():
            # Convert to a Python float so json.dump can serialize the history.
            tt = float(losses_utils.compute_weighted_loss(vv()))
            self.history.setdefault(kk, []).append(tt)
        if len(self.model.losses) != 0:
            regular_loss = K.sum(self.model.losses).numpy()
            self.history.setdefault("regular_loss",
                                    []).append(float(regular_loss))
            self.history["loss"][-1] -= regular_loss

        if self.initial_file:
            with open(self.initial_file, "w") as ff:
                json.dump(self.history, ff)
Example #4
    def loss(self,
             labels,
             logits,
             features=None,
             mode=None,
             regularization_losses=None):
        """Returns regularized training loss. See `base_head.Head` for details."""
        del mode  # Unused for this head.
        with tf.compat.v1.name_scope('losses',
                                     values=(logits, labels,
                                             regularization_losses, features)):
            logits = base_head.check_logits_final_dim(logits,
                                                      self.logits_dimension)
            labels = self._processed_labels(logits, labels)
            unweighted_loss, weights = self._unweighted_loss_and_weights(
                logits, labels, features)
            vector_training_loss = losses_utils.compute_weighted_loss(
                unweighted_loss,
                sample_weight=weights,
                reduction=tf.keras.losses.Reduction.NONE)
            regularization_loss = tf.math.add_n(
                regularization_losses
            ) if regularization_losses is not None else None
            vector_regularized_training_loss = (
                tf.add(vector_training_loss, regularization_loss)
                if regularization_loss is not None else vector_training_loss)

        return vector_regularized_training_loss
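The `tf.add` at the end broadcasts a scalar regularization term onto the per-example loss vector; a toy illustration (my numbers):

import tensorflow as tf

vector_training_loss = tf.constant([0.3, 0.7, 0.1])      # Reduction.NONE output
regularization_losses = [tf.constant(0.01), tf.constant(0.02)]

regularization_loss = tf.math.add_n(regularization_losses)      # scalar 0.03
vector_regularized = tf.add(vector_training_loss, regularization_loss)
# -> [0.33, 0.73, 0.13]: every example carries the full regularization term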
Example #5
  def __call__(self, y_true, y_pred, sample_weight=None):
    """Invokes the `Loss` instance.

    Args:
      y_true: Ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
        as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
        coefficient for the loss. If a scalar is provided, then the loss is
        simply scaled by the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the total loss for each sample of the batch is
        rescaled by the corresponding element in the `sample_weight` vector. If
        the shape of `sample_weight` matches the shape of `y_pred`, then the
        loss of each measurable element of `y_pred` is scaled by the
        corresponding value of `sample_weight`.

    Returns:
      Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
        shape as `y_true`; otherwise, it is scalar.

    Raises:
      ValueError: If the shape of `sample_weight` is invalid.
    """
    # If we are wrapping a lambda function strip '<>' from the name as it is not
    # accepted in scope name.
    scope_name = 'lambda' if self.name == '<lambda>' else self.name
    graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
        y_true, y_pred, sample_weight)
    with K.name_scope(scope_name or self.__class__.__name__), graph_ctx:
      losses = self.call(y_true, y_pred)
      return losses_utils.compute_weighted_loss(
          losses, sample_weight, reduction=self._get_reduction())
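The `'<lambda>'` special case exists because a wrapped lambda reports `'<lambda>'` as its name, and angle brackets are not accepted in a TF name scope; a quick check:

fn = lambda y_true, y_pred: abs(y_true - y_pred)
print(fn.__name__)  # '<lambda>' -- hence the substitution with 'lambda'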
Example #6
    def __call__(self, y_true, y_pred, sample_weight=None):
        """Invokes the `Loss` instance.

        Args:
          y_true: Ground truth values.
          y_pred: The predicted values.
          sample_weight: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, or is broadcastable to `y_true`. `sample_weight`
            acts as a coefficient for the loss. If a scalar is provided, then
            the loss is simply scaled by the given value. If `sample_weight` is
            a tensor of size `[batch_size]`, then the total loss for each sample
            of the batch is rescaled by the corresponding element in the
            `sample_weight` vector. If the shape of `sample_weight` matches the
            shape of `y_pred`, then the loss of each measurable element of
            `y_pred` is scaled by the corresponding value of `sample_weight`.

        Returns:
          Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the
            same shape as `y_true`; otherwise, it is scalar.

        Raises:
          ValueError: If the shape of `sample_weight` is invalid.
        """
        # If we are wrapping a lambda function strip '<>' from the name as it is not
        # accepted in scope name.
        scope_name = 'lambda' if self.name == '<lambda>' else self.name
        with ops.name_scope(scope_name, format(self.__class__.__name__),
                            (y_pred, y_true, sample_weight)):
            losses = self.call(y_true, y_pred)
            return compute_weighted_loss(losses,
                                         sample_weight,
                                         reduction=self.reduction)
Example #7
    def __call__(self, y_true, y_pred, sample_weight=None):
        """Invokes the `Loss` instance.

        Args:
          y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
          y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
          sample_weight: Optional `sample_weight` acts as a coefficient for the
            loss. If a scalar is provided, then the loss is simply scaled by the
            given value. If `sample_weight` is a tensor of size `[batch_size]`,
            then the total loss for each sample of the batch is rescaled by the
            corresponding element in the `sample_weight` vector. If the shape of
            `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
            broadcasted to this shape), then each loss element of `y_pred` is
            scaled by the corresponding value of `sample_weight`. (Note on
            `dN-1`: all loss functions reduce by 1 dimension, usually axis=-1.)

        Returns:
          Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
            shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note
            `dN-1` because all loss functions reduce by 1 dimension, usually
            axis=-1.)

        Raises:
          ValueError: If the shape of `sample_weight` is invalid.
        """
        # If we are wrapping a lambda function strip '<>' from the name as it is not
        # accepted in scope name.
        scope_name = 'lambda' if self.name == '<lambda>' else self.name
        graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
            y_true, y_pred, sample_weight)
        with K.name_scope(scope_name or self.__class__.__name__), graph_ctx:
            losses = self.call(y_true, y_pred)
            return losses_utils.compute_weighted_loss(
                losses, sample_weight, reduction=self._get_reduction())
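A quick shape check (my addition) of the `dN-1` note: a built-in loss on rank-3 inputs returns a rank-2 tensor under `Reduction.NONE`, because the loss function itself reduces the last axis:

import tensorflow as tf

loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
y_true = tf.ones((2, 3, 5))    # [batch_size, d0, d1]
y_pred = tf.zeros((2, 3, 5))
print(loss_fn(y_true, y_pred).shape)  # (2, 3): the d1 axis was reduced away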
Example #8
 def loss(self,
          labels,
          logits,
          features=None,
          mode=None,
          regularization_losses=None):
     """Returns regularized training loss. See `base_head.Head` for details."""
     del mode  # Unused for this head.
     with ops.name_scope('losses',
                         values=(logits, labels, regularization_losses,
                                 features)):
         logits = base_head.check_logits_final_dim(logits,
                                                   self.logits_dimension)
         processed_labels = self._processed_labels(logits, labels)
         unweighted_loss, weights = self._unweighted_loss_and_weights(
             logits, processed_labels, features)
         training_loss = losses_utils.compute_weighted_loss(
             unweighted_loss,
             sample_weight=weights,
             reduction=self._loss_reduction)
         regularization_loss = math_ops.add_n(
             regularization_losses
         ) if regularization_losses is not None else None
         regularized_training_loss = (training_loss + regularization_loss
                                      if regularization_loss is not None
                                      else training_loss)
     return regularized_training_loss
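For reference, a hand-rolled sketch (my approximation, ignoring edge cases such as zero-size inputs) of what `compute_weighted_loss` does under the default `SUM_OVER_BATCH_SIZE` reduction:

import tensorflow as tf

def weighted_mean_loss(losses, sample_weight=1.0):
    losses = tf.cast(losses, tf.float32)
    weighted = losses * tf.cast(sample_weight, tf.float32)
    # SUM_OVER_BATCH_SIZE divides by the number of loss elements,
    # not by the sum of the weights.
    return tf.reduce_sum(weighted) / tf.cast(tf.size(weighted), tf.float32)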
Example #9
 def __call__(self, y_true, y_pred, w, sample_weight=None):
   scope_name = 'lambda' if self.name == '<lambda>' else self.name
   graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
       y_true, y_pred, w, sample_weight)
   with K.name_scope(scope_name or self.__class__.__name__), graph_ctx:
     losses = self.call(y_true, y_pred, w)
     return losses_utils.compute_weighted_loss(
         losses, sample_weight, reduction=self._get_reduction())
Example #10
 def __call__(self,
              log_probs,
              log_probs_anchor,
              advantages,
              sample_weight=None):
     losses = clipped_policy_gradient(log_probs, log_probs_anchor,
                                      advantages)
     return losses_utils.compute_weighted_loss(losses,
                                               sample_weight,
                                               reduction=self.reduction)
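`clipped_policy_gradient` comes from the external `pyrl` package; a common PPO-style formulation (my assumption, not necessarily pyrl's exact code) is:

import tensorflow as tf

def clipped_policy_gradient(log_probs, log_probs_anchor, advantages,
                            epsilon_clipping=0.2):
    # Importance ratio between the current and the anchor (old) policy.
    ratio = tf.exp(log_probs - log_probs_anchor)
    clipped = tf.clip_by_value(ratio, 1.0 - epsilon_clipping,
                               1.0 + epsilon_clipping)
    # Negated so that minimizing the loss maximizes the clipped surrogate.
    return -tf.minimum(ratio * advantages, clipped * advantages)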
Example #11
 def on_epoch_end(self, epoch, logs=None):
     logs = logs or {}
     for k, v in logs.items():
         k = "accuracy" if "accuracy" in k else k
         self.history.setdefault(k, []).append(float(v))
     for ee in self.evals:
         self.history.setdefault(ee.test_names,
                                 []).append(float(ee.cur_acc))
     for kk, vv in self.custom_obj.items():
         # Convert to a Python float so json.dump can serialize the history.
         tt = float(losses_utils.compute_weighted_loss(vv()))
         self.history.setdefault(kk, []).append(tt)
     if self.initial_file:
         with open(self.initial_file, "w") as ff:
             json.dump(self.history, ff)
Example #12
def compute_weighted_loss(losses, sample_weight=None, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE):
    if distribution_strategy_context.has_strategy() and \
            reduction in {tf.keras.losses.Reduction.AUTO, tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE}:
        raise ValueError(
            'Please use `tf.keras.losses.Reduction.SUM` or `tf.keras.losses.Reduction.NONE` for loss reduction when '
            'losses are used with `tf.distribute.Strategy` outside of the built-in training loops. You can implement '
            '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch size like:\n'
            '```\n'
            'with strategy.scope():\n'
            '    loss_obj = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)\n'
            '....\n'
            '    loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)\n'
            '```\n'
            'Please see https://www.tensorflow.org/tutorials/distribute/custom_training for more details.')

    return losses_utils.compute_weighted_loss(losses, sample_weight=sample_weight, reduction=reduction)
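Spelled out, the pattern the error message recommends (a sketch assuming a `MirroredStrategy` and a known `global_batch_size`):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 64  # assumed: per-replica batch size * number of replicas

with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    per_example = loss_obj(labels, predictions)      # shape [per_replica_batch]
    return tf.reduce_sum(per_example) * (1.0 / global_batch_size)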
Example #13
 def __call__(self, y_true, y_pred, sample_weight=None):
     """See _RankingLoss."""
     losses, sample_weight = self._loss.compute_per_list(
         y_true, y_pred, sample_weight)
     return losses_utils.compute_weighted_loss(
         losses, sample_weight, reduction=self._get_reduction())
Example #14
 def __call__(self, entropy, sample_weight=None):
     losses = policy_entropy(entropy)
     return losses_utils.compute_weighted_loss(losses,
                                               sample_weight,
                                               reduction=self.reduction)
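`policy_entropy` is also from `pyrl`; a plausible definition (my guess, consistent with how the result is weighted above) simply negates the entropy so that minimizing the loss encourages exploration:

def policy_entropy(entropy):
    # Higher entropy -> lower loss, so gradient descent raises entropy.
    return -entropy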
Example #15
def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
    """Calculates the loss for a given model.

    Args:
        model: The model on which metrics are being calculated.
        inputs: Either a dictionary of inputs to the model or a list of input
          arrays.
        targets: List of target arrays.
        output_loss_metrics: List of metrics that are used to aggregate output
          loss values.
        sample_weights: Optional list of sample weight arrays.
        training: Whether the model should be run in inference or training mode.

    Returns:
        Returns the model output, total loss, loss value calculated using the
        specified loss function and masks for each output. The total loss
        includes regularization losses and applies masking and sample weighting
        to the loss value.
    """
    # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
    # Used to keep track of the total loss value (stateless).
    # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
    #                   loss_weight_2 * output_2_loss_fn(...) +
    #                   layer losses.
    total_loss = 0
    kwargs = {}
    if model._expects_training_arg:
        kwargs['training'] = training
    if len(inputs) == 1 and not isinstance(inputs, dict):
        inputs = inputs[0]

    # Allow mixed `NumPy` and `EagerTensor` input here.
    if any(
            isinstance(input_t, (np.ndarray, float, int))
            for input_t in nest.flatten(inputs)):
        inputs = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                    inputs)

    outs = model(inputs, **kwargs)
    outs = nest.flatten(outs)

    if targets:
        targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
            targets, outs)
    # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
    if sample_weights:
        sample_weights = [
            training_utils_v1.cast_if_floating_dtype(
                ops.convert_to_tensor_v2_with_dispatch(val))
            if val is not None else None for val in sample_weights
        ]

    masks = [getattr(t, '_keras_mask', None) for t in outs]
    targets = nest.flatten(targets)

    # Used to keep track of individual output losses.
    output_losses = []

    with backend.name_scope('loss'):
        loss_fns = [
            loss_fn for loss_fn in model.loss_functions if loss_fn is not None
        ]
        custom_losses = model.losses  # Regularization losses

        if not loss_fns and not custom_losses:
            if training:
                raise ValueError('The model cannot be trained '
                                 'because it has no loss to optimize.')
            else:
                raise ValueError('The model cannot be evaluated '
                                 'because it has no loss to compute.')

        for i, loss_fn in enumerate(loss_fns):
            weights = sample_weights[i] if sample_weights else None
            mask = masks[i]
            with backend.name_scope(model.output_names[i] + '_loss'):
                if mask is not None:
                    mask = math_ops.cast(mask, outs[i].dtype)
                    # Update weights with mask.
                    if weights is None:
                        weights = mask
                    else:
                        # Update dimensions of weights to match with mask if possible.
                        weights = math_ops.cast(weights, outs[i].dtype)
                        mask, _, weights = (
                            losses_utils.squeeze_or_expand_dimensions(
                                mask, sample_weight=weights))
                        weights *= mask

                if hasattr(loss_fn, 'reduction'):
                    per_sample_losses = loss_fn.call(targets[i], outs[i])
                    weighted_losses = losses_utils.compute_weighted_loss(
                        per_sample_losses,
                        sample_weight=weights,
                        reduction=losses_utils.ReductionV2.NONE)
                    loss_reduction = loss_fn.reduction

                    # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
                    # compile use cases.
                    if loss_reduction == losses_utils.ReductionV2.AUTO:
                        loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

                    # Compute the stateless loss value.
                    output_loss = losses_utils.reduce_weighted_loss(
                        weighted_losses, reduction=loss_reduction)
                else:
                    # Compute the stateless loss value for a custom loss class.
                    # Here we assume that the class takes care of loss reduction
                    # because if this class returns a vector value we cannot
                    # differentiate between use case where a custom optimizer
                    # expects a vector loss value vs unreduced per-sample loss value.
                    output_loss = loss_fn(targets[i],
                                          outs[i],
                                          sample_weight=weights)
                    loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

            # If the number of outputs is 1 then we don't append the loss metric
            # associated with each model output. When there are multiple outputs
            # associated with a model, each output's loss is calculated and returned
            # as part of the loss_metrics.
            if len(model.outputs) > 1:
                # Keep track of the stateful output loss result.
                output_losses.append(output_loss_metrics[i](output_loss))

            # Scale output loss for distribution. For custom losses we assume
            # reduction was mean.
            if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
                output_loss = losses_utils.scale_loss_for_distribution(
                    output_loss)
            total_loss += model._loss_weights_list[i] * output_loss

        # Add regularization losses
        if custom_losses:
            total_loss += losses_utils.scale_loss_for_distribution(
                math_ops.add_n(custom_losses))
    return outs, total_loss, output_losses, masks
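The mask handling in the loop above, distilled into a toy example (my numbers): a Keras mask is cast to the loss dtype and folded into the sample weights by multiplication, so masked positions contribute zero loss:

import tensorflow as tf

per_sample_losses = tf.constant([0.5, 0.2, 0.9, 0.1])
mask = tf.constant([1.0, 1.0, 0.0, 1.0])       # 0 marks a padded entry
weights = tf.constant([1.0, 2.0, 1.0, 1.0])

weights = weights * mask                        # third example is zeroed out
weighted = per_sample_losses * weights          # [0.5, 0.4, 0.0, 0.1]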
Example #16
    def _batch_train(self, batch):
        with tf.GradientTape() as tape:
            # forward passes
            outputs = self.model.get_training_outputs(inputs=batch,
                                                      training=True,
                                                      reset_state=True)
            log_probs = outputs["log_probs"]
            entropy = outputs["entropy"]
            values = outputs["values"]

            move_hot = tf.one_hot(batch["actions"][..., 0],
                                  depth=self.model.action_space.nvec[0])
            grasp_hot = tf.one_hot(batch["actions"][..., 1],
                                   depth=self.model.action_space.nvec[1])

            # losses
            policy_loss = self.policy_loss_fn(
                log_probs=log_probs,
                log_probs_anchor=batch["log_probs_anchor"],
                advantages=batch["advantages"],
                sample_weight=batch["weights"],
            )
            value_loss = self.value_loss_fn(
                y_pred=values[..., None],
                y_true=batch["returns"][..., None],
                sample_weight=batch["weights"] * self.params.value_coef,
            )
            entropy_loss = self.entropy_loss_fn(
                entropy=entropy,
                sample_weight=batch["weights"] * self.params.entropy_coef,
            )
            forward_loss = self.forward_loss_fn(
                y_pred=outputs["embedding_next_pred"],
                y_true=outputs["embedding_next"],
                sample_weight=batch["weights"] * self.params.forward_coef,
            )
            inverse_move_loss = self.inverse_move_loss_fn(
                y_pred=outputs["move_pred"],
                y_true=move_hot,
                sample_weight=batch["weights"],
            )
            inverse_grasp_loss = self.inverse_grasp_loss_fn(
                y_pred=outputs["grasp_pred"],
                y_true=grasp_hot,
                sample_weight=batch["weights"],
            )
            inverse_loss = (inverse_move_loss +
                            inverse_grasp_loss) * self.params.inverse_coef
            intrinsic_loss = (forward_loss +
                              inverse_loss) * self.params.intrinsic_coef
            regularization_loss = tf.add_n([
                tf.nn.l2_loss(tvar) * self.params.l2_coef
                for tvar in self.model.trainable_variables
            ])
            loss = self._compute_loss(
                policy_loss=policy_loss,
                value_loss=value_loss,
                entropy_loss=entropy_loss,
                intrinsic_loss=intrinsic_loss,
                regularization_loss=regularization_loss,
            )

        # compute gradients
        grads = tape.gradient(loss, self.model.trainable_variables)
        if self.params.grad_clipping is not None:
            grads_clipped, _ = tf.clip_by_global_norm(
                grads, self.params.grad_clipping)
        else:
            grads_clipped = grads
        grads_and_vars = zip(grads_clipped, self.model.trainable_variables)

        # optimization
        self.optimizer.apply_gradients(grads_and_vars)

        # summaries
        entropy_mean = losses_utils.compute_weighted_loss(
            losses=entropy, sample_weight=batch["weights"])
        gradient_norm = tf.linalg.global_norm(grads)
        clipped_gradient_norm = tf.linalg.global_norm(grads_clipped)

        tf.summary.scalar("loss/policy",
                          policy_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/value",
                          value_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/entropy",
                          entropy_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/forward",
                          forward_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/inverse",
                          inverse_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/inverse/move",
                          inverse_move_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/inverse/grasp",
                          inverse_grasp_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/intrinsic",
                          intrinsic_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("loss/regularization",
                          regularization_loss,
                          step=self.optimizer.iterations)
        tf.summary.scalar("gradient_norm",
                          gradient_norm,
                          step=self.optimizer.iterations)
        tf.summary.scalar(
            "gradient_norm/clipped",
            clipped_gradient_norm,
            step=self.optimizer.iterations,
        )
        tf.summary.scalar("entropy",
                          entropy_mean,
                          step=self.optimizer.iterations)
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--job-dir", required=True, help="Job directory")
    parser.add_argument(
        "--render", action="store_true", help="Enable evaluation render"
    )
    parser.add_argument("--seed", default=42, type=int, help="Random seed")
    parser.add_argument("--env", default="Pendulum-v0", help="Env name")
    args, _ = parser.parse_known_args()
    print("args:", args)

    # make job dir
    os.makedirs(args.job_dir, exist_ok=True)

    # params
    params = HyperParams()
    params_path = os.path.join(args.job_dir, "params.json")
    params.save(params_path)
    print("params:", params)

    # environment
    env = pyrl.wrappers.Batch(
        lambda batch_id: gym.make(args.env), batch_size=params.episodes
    )
    atexit.register(env.close)

    # seeding
    env.seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # optimization
    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)

    # models
    policy = Policy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        scale=params.scale,
    )
    baseline = Value(observation_space=env.observation_space)

    # strategies
    exploration_strategy = pyrl.strategies.Sample(policy)
    inference_strategy = pyrl.strategies.Mode(policy)

    # normalization
    rewards_moments = pynr.moments.ExponentialMovingMoments(
        shape=(), rate=params.reward_decay
    )

    # checkpoints
    checkpoint = tf.train.Checkpoint(
        optimizer=optimizer,
        policy=policy,
        baseline=baseline,
        rewards_moments=rewards_moments,
    )
    checkpoint_path = tf.train.latest_checkpoint(args.job_dir)
    if checkpoint_path is not None:
        checkpoint.restore(checkpoint_path)

    # summaries
    summary_writer = tf.summary.create_file_writer(
        args.job_dir, max_queue=100, flush_millis=5 * 60 * 1000
    )
    summary_writer.set_as_default()

    # rollouts
    rollout = BatchRollout(env, max_episode_steps=env.spec.max_episode_steps)

    # prime models
    # NOTE: TF eager does not initialize weights until they're called
    mock_states = tf.zeros(
        shape=(1, 1, env.observation_space.shape[0]), dtype=np.float32
    )
    policy(mock_states, training=False)

    advantages_fn = pyrl.targets.GeneralizedAdvantages(
        discount_factor=params.discount_factor,
        lambda_factor=params.lambda_factor,
        normalize=True,
    )
    returns_fn = pyrl.targets.DiscountedReturns(discount_factor=params.discount_factor)

    value_loss_fn = tf.losses.MeanSquaredError()
    policy_loss_fn = pyrl.losses.ClippedPolicyGradient(
        epsilon_clipping=params.epsilon_clipping
    )
    entropy_loss_fn = pyrl.losses.PolicyEntropy()

    # training iterations
    with trange(params.train_iters) as pbar:
        for it in pbar:
            # sample training transitions
            states, actions, rewards, weights = rollout(
                policy=exploration_strategy, episodes=params.episodes
            )
            episodic_reward = tf.reduce_mean(tf.reduce_sum(rewards, axis=-1))

            rewards_moments(rewards, sample_weight=weights, training=True)
            rewards_norm = pynr.math.normalize(
                rewards,
                loc=rewards_moments.mean,
                scale=rewards_moments.std,
                sample_weight=weights,
            )

            values = baseline(states, training=False)

            # targets
            advantages = advantages_fn(
                rewards=rewards_norm, values=values, sample_weight=weights
            )
            returns = returns_fn(rewards=rewards_norm, sample_weight=weights)

            policy_anchor_dist = policy(states, training=False)
            log_probs_anchor = policy_anchor_dist.log_prob(actions)

            tf.summary.scalar(
                "rewards/train/mean", rewards_moments.mean, step=optimizer.iterations
            )
            tf.summary.scalar(
                "rewards/train/std", rewards_moments.std, step=optimizer.iterations
            )
            tf.summary.scalar(
                "rewards/train", episodic_reward, step=optimizer.iterations
            )

            tf.summary.histogram("states", states, step=optimizer.iterations)
            tf.summary.histogram("actions", actions, step=optimizer.iterations)
            tf.summary.histogram("rewards", rewards, step=optimizer.iterations)
            tf.summary.histogram(
                "rewards/norm", rewards_norm, step=optimizer.iterations
            )
            tf.summary.histogram("advantages", advantages, step=optimizer.iterations)
            tf.summary.histogram("returns", returns, step=optimizer.iterations)
            tf.summary.histogram("values", values, step=optimizer.iterations)

            # training epochs
            for epoch in range(params.epochs):
                with tf.GradientTape() as tape:
                    # forward passes
                    policy_dist = policy(states, training=True)
                    values = baseline(states, training=True)

                    entropy = policy_dist.entropy()
                    log_probs = policy_dist.log_prob(actions)

                    # losses
                    policy_loss = policy_loss_fn(
                        log_probs=log_probs,
                        log_probs_anchor=log_probs_anchor,
                        advantages=advantages,
                        sample_weight=weights,
                    )
                    value_loss = value_loss_fn(
                        y_pred=values[..., None],
                        y_true=returns[..., None],
                        sample_weight=weights[..., None] * params.value_coef,
                    )
                    entropy_loss = entropy_loss_fn(
                        entropy=entropy, sample_weight=weights * params.entropy_coef
                    )
                    loss = policy_loss + value_loss + entropy_loss

                # optimization
                trainable_variables = (
                    policy.trainable_variables + baseline.trainable_variables
                )
                grads = tape.gradient(loss, trainable_variables)
                if params.grad_clipping is not None:
                    grads, _ = tf.clip_by_global_norm(grads, params.grad_clipping)
                grads_and_vars = zip(grads, trainable_variables)
                optimizer.apply_gradients(grads_and_vars)

                # summaries
                kl = tf.reduce_mean(
                    tfp.distributions.kl_divergence(policy_dist, policy_anchor_dist)
                )
                entropy_mean = losses_utils.compute_weighted_loss(
                    losses=entropy, sample_weight=weights
                )
                gradient_norm = tf.linalg.global_norm(grads)

                tf.summary.scalar(
                    "losses/entropy", entropy_loss, step=optimizer.iterations
                )
                tf.summary.scalar(
                    "losses/policy", policy_loss, step=optimizer.iterations
                )
                tf.summary.scalar("losses/value", value_loss, step=optimizer.iterations)
                tf.summary.scalar("losses/loss", loss, step=optimizer.iterations)
                tf.summary.scalar("entropy", entropy_mean, step=optimizer.iterations)
                tf.summary.scalar("kl", kl, step=optimizer.iterations)
                tf.summary.scalar(
                    "gradient_norm", gradient_norm, step=optimizer.iterations
                )

            # evaluation
            if it % params.eval_interval == params.eval_interval - 1:
                states, actions, rewards, weights = rollout(
                    policy=inference_strategy,
                    episodes=params.episodes,
                    render=args.render,
                )
                episodic_reward = tf.reduce_mean(tf.reduce_sum(rewards, axis=-1))
                pbar.set_description("reward: {:.4f}".format(episodic_reward.numpy()))

                tf.summary.scalar(
                    "rewards/eval", episodic_reward, step=optimizer.iterations
                )

            # save checkpoint
            checkpoint_prefix = os.path.join(args.job_dir, "ckpt")
            checkpoint.save(file_prefix=checkpoint_prefix)
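`pyrl.targets.DiscountedReturns` is external; a reference sketch (my implementation of the standard recurrence G_t = r_t + gamma * G_{t+1}) over an `[episodes, steps]` reward tensor:

import tensorflow as tf

def discounted_returns(rewards, discount_factor=0.99):
    # Scan right-to-left over time so each step accumulates the future return.
    reversed_rewards = tf.transpose(tf.reverse(rewards, axis=[-1]))  # [steps, episodes]
    returns = tf.scan(lambda acc, r: r + discount_factor * acc,
                      reversed_rewards,
                      initializer=tf.zeros_like(reversed_rewards[0]))
    return tf.reverse(tf.transpose(returns), axis=[-1])              # [episodes, steps]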
Example #18
def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Arguments:
      model: The model on which metrics are being calculated.
      inputs: Either a dictionary of inputs to the model or a list of input
        arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the model should be run in inference or training mode.

  Returns:
     Returns the model output, total loss, loss value calculated using the
     specified loss function and masks for each output. The total loss includes
     regularization losses and applies masking and sample weighting
     to the loss value.
  """
  # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
  # Used to keep track of the total loss value (stateless).
  # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                   loss_weight_2 * output_2_loss_fn(...) +
  #                   layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(ops.convert_to_tensor, inputs)

  outs = model(inputs, **kwargs)

  outs = nest.flatten(outs)
  masks = [getattr(t, '_keras_mask', None) for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    loss_fns = [
        loss_fn for loss_fn in model.loss_functions if loss_fn is not None
    ]
    for i, loss_fn in enumerate(loss_fns):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(mask, None, weights))
            weights *= mask

        weighted_losses = None
        if hasattr(loss_fn, 'reduction'):
          per_sample_losses = loss_fn.call(targets[i], outs[i])
          weighted_losses = losses_utils.compute_weighted_loss(
              per_sample_losses,
              sample_weight=weights,
              reduction=losses_utils.ReductionV2.NONE)
          loss_reduction = loss_fn.reduction

          # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
          # compile use cases.
          if loss_reduction == losses_utils.ReductionV2.AUTO:
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(
              weighted_losses, reduction=loss_reduction)
          if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
            output_loss = losses_utils.scale_loss_for_distribution(output_loss)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction
          # because if this class returns a vector value we cannot
          # differentiate between use case where a custom optimizer
          # expects a vector loss value vs unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
          # For custom losses we assume reduction was mean.
          output_loss = losses_utils.scale_loss_for_distribution(output_loss)

      # If the number of outputs is 1 then we don't append the loss metric
      # associated with each model output. When there are multiple outputs
      # associated with a model, each output's loss is calculated and returned
      # as part of the loss_metrics.
      if len(model.outputs) > 1:
        # Keep track of the stateful output loss result.
        output_losses.append(output_loss_metrics[i](output_loss))

      total_loss += model._loss_weights_list[i] * output_loss

    # Add regularization losses
    custom_losses = model.losses
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))

  return outs, total_loss, output_losses, masks