def predict_on_batch(model, x):
    """Returns predictions for a single batch of samples.

    Arguments:
        model: The model to predict with.
        x: Input data. It could be:
          - A Numpy array (or array-like), or a list of arrays
            (in case the model has multiple inputs).
          - A TensorFlow tensor, or a list of tensors
            (in case the model has multiple inputs).
          - A `tf.data` dataset.

    Returns:
        The result of calling `model` on the standardized, dtype-cast batch,
        evaluated with the learning phase set to 0. NOTE(review): the output
        is returned as produced by `model(...)` (no Numpy conversion happens
        here) -- confirm what callers expect.

    Raises:
        ValueError: In case of mismatch between given number of inputs and
            expectations of the model (presumably raised by
            `model._standardize_user_data`).
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (the very same class on older versions).
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    # TODO(scottzhu): Standardization should happen in the data handlers,
    ## not on a per batch basis in the *_on_batch methods
    # Validate and standardize user data.
    inputs, _, _ = model._standardize_user_data(
        x, extract_tensors_from_dataset=True)

    # If `model._distribution_strategy` is True, then we are in a replica
    # context at this point.
    inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
    if isinstance(inputs, collections_abc.Sequence):
        # Unwrap lists with only one input, as we do when training on batch.
        if len(inputs) == 1:
            inputs = inputs[0]

    # Forward pass with the learning phase set to 0 (not training).
    with backend.eager_learning_phase_scope(0):
        return model(inputs)  # pylint: disable=not-callable
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
    """Computes the loss and metrics for a single input batch.

    NOTE(review): another `test_on_batch` definition follows this one in the
    same source; if both live in one module, the later definition shadows
    this one -- confirm which version is intended.

    Arguments:
        model: Model whose loss has to be calculated.
        inputs: Input batch data.
        targets: Target batch data.
        sample_weights: Sample weight batch data.
        output_loss_metrics: List of metrics that are used to aggregate
            output loss values.

    Returns:
        Dict with three items:
          'total_loss': the overall loss, passed through `nest.flatten`.
          'output_losses': list of tensors for the loss corresponding to
              each of the model outputs. Could be an empty list when the
              model has only one output.
          'metrics': list of tensors for the specified metrics.
    """
    casted_inputs = training_utils.cast_to_model_input_dtypes(inputs, model)

    # Evaluate the loss with the learning phase set to 0 (training=False).
    with backend.eager_learning_phase_scope(0):
        outs, total_loss, output_losses, masks = _model_loss(
            model,
            casted_inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics)

    # The metrics helper is always handed a list of outputs.
    outs_list = outs if isinstance(outs, list) else [outs]
    metrics_results = _eager_metrics_fn(
        model, outs_list, targets, sample_weights=sample_weights, masks=masks)

    return {
        'total_loss': nest.flatten(total_loss),
        'output_losses': output_losses,
        'metrics': metrics_results,
    }
def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
    """Calculates the loss for one input batch.

    NOTE(review): this redefines `test_on_batch` from earlier in the same
    source; if both live in one module, this later definition is the one
    that takes effect -- confirm which version is intended.

    Arguments:
        model: Model whose loss has to be calculated.
        inputs: Input batch data.
        targets: Target batch data.
        sample_weights: Sample weight batch data.
        output_loss_metrics: List of metrics that are used to aggregate
            output loss values.

    Returns:
        A single flat list: the flattened total loss, followed by the
        per-output losses, followed by the metric results (total loss,
        loss and metrics associated with each output). Assumes
        `_model_loss` and `_eager_metrics_fn` return lists for the losses
        and metrics, since they are concatenated with `+`.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (the very same class on older versions).
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

    # Inputs (and, in that case, targets) are only cast when the inputs
    # arrive as a sequence.
    if isinstance(inputs, collections_abc.Sequence):
        inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
        if targets:
            targets = training_utils.cast_if_floating_dtype(targets)
    if sample_weights:
        # Convert each weight to a tensor and cast floating dtypes; `None`
        # entries are kept as `None`.
        sample_weights = [
            training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
            if val is not None else None
            for val in sample_weights
        ]

    # Evaluate the loss with the learning phase set to 0 (training=False).
    with backend.eager_learning_phase_scope(0):
        outs, total_loss, output_losses, masks = _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics)
    if not isinstance(outs, list):
        outs = [outs]
    metrics_results = _eager_metrics_fn(
        model, outs, targets, sample_weights=sample_weights, masks=masks)
    total_loss = nest.flatten(total_loss)
    return total_loss + output_losses + metrics_results