Exemple #1
0
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
    """Calculates the loss for one input batch.

    Args:
        model: Model whose loss has to be calculated.
        inputs: Input batch data.
        targets: Target batch data.
        sample_weights: Sample weight batch data.
        output_loss_metrics: List of metrics that are used to aggregated output
          loss values.

    Returns:
        Dict with three items:
          'total_loss': single tensor for overall loss,
          'output_losses': list of tensors for loss corresponding to each of the
            model output. Could be a empty list when model has only one output.
          'metrics': list of tensors for metric specified.
    """
    inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

    with backend.eager_learning_phase_scope(0):
        outs, total_loss, output_losses, masks = _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics,
        )
    if not isinstance(outs, list):
        outs = [outs]
    metrics_results = _eager_metrics_fn(model,
                                        outs,
                                        targets,
                                        sample_weights=sample_weights,
                                        masks=masks)
    total_loss = tf.nest.flatten(total_loss)

    return {
        "total_loss": total_loss,
        "output_losses": output_losses,
        "metrics": metrics_results,
    }
Exemple #2
0
def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
    """Calculates the loss and gradient updates for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregated output
        loss values.

  Returns:
      Dict with three items:
        'total_loss': list with a single tensor for overall loss,
        'output_losses': list of tensors for loss corresponding to each of the
          model output. Could be a empty list when model has only one output.
        'metrics': list of tensors for metric specified.
  """
    inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
    outs, total_loss, output_losses, masks = (_process_single_batch(
        model,
        inputs,
        targets,
        sample_weights=sample_weights,
        training=True,
        output_loss_metrics=output_loss_metrics))
    if not isinstance(outs, list):
        outs = [outs]
    metrics_results = _eager_metrics_fn(model,
                                        outs,
                                        targets,
                                        sample_weights=sample_weights,
                                        masks=masks)
    total_loss = tf.nest.flatten(total_loss)
    return {
        'total_loss': total_loss,
        'output_losses': output_losses,
        'metrics': metrics_results
    }