Example #1
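
The snippet below appears to be the eager-execution training step from TensorFlow's internal Keras engine (tensorflow/python/keras/engine/training_eager.py): it casts the input batch to the model's floating dtypes, runs one forward and backward pass on the batch, and returns the total loss, per-output losses and metric results as plain Python values.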
import collections.abc

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.util import nest

# _process_single_batch, _eager_metrics_fn and _non_none_constant_value are
# private helpers defined elsewhere in the same module
# (tensorflow/python/keras/engine/training_eager.py).


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Arguments:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      total loss and the loss associated with each output.
  """
  # Cast inputs (and targets) to the model's floating dtypes: tensors can be
  # cast directly, while other values are converted to tensors first.
  if isinstance(inputs, collections.abc.Sequence):
    if len(inputs) and tensor_util.is_tensor(inputs[0]):
      inputs = training_utils.cast_if_floating_to_model_input_dtypes(inputs,
                                                                     model)
      if targets:
        targets = training_utils.cast_if_floating_dtype(targets)
    else:
      inputs = training_utils.cast_if_floating_to_model_input_dtypes(
          [ops.convert_to_tensor(val) for val in inputs], model)
      if targets:
        targets = training_utils.cast_if_floating_dtype(
            [ops.convert_to_tensor(val) for val in targets])
  # Convert and cast any provided sample weights in the same way.
  if sample_weights:
    sample_weights = [
        training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
        if val is not None else None for val in sample_weights
    ]

  # Run the forward pass, compute the loss and apply the gradient updates.
  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  # Update the compiled metrics with this batch's outputs.
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  results = total_loss + output_losses + metrics_results

  # Extract plain Python values from the result tensors where possible.
  return [_non_none_constant_value(v) for v in results]
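
In normal use this function is not called directly: it is the eager-mode backend for the public Model.train_on_batch. A minimal sketch of the public entry point (the model, data and shapes below are illustrative, not part of the original snippet):

import numpy as np
import tensorflow as tf

# A tiny regression model; any compiled Keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

x = np.random.rand(8, 4).astype('float32')
y = np.random.rand(8, 1).astype('float32')

# One gradient update on a single batch; returns [loss, mae].
results = model.train_on_batch(x, y)
print(results)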