예제 #1
0
    def _wrapped_model(*args, **kwargs):
        """Call the wrapped model with `training=False` and name its outputs.

        Args:
            *args: Positional arguments forwarded to the model call.
            **kwargs: Keyword arguments forwarded to the model call.

        Returns:
            A flat dict mapping each output name to the matching flattened
            model output.
        """
        # Force `training` to False wherever the call spec expects it.
        call_args, call_kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        call_ctx = base_layer_utils.call_context()
        with call_ctx.enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*call_args, **call_kwargs)

        # The result must always be a flat dict keyed by output name.
        names = model.output_names  # Set for a Functional Model.
        if names is None:  # Subclassed Model: derive pseudo names instead.
            from keras.engine import compile_utils

            names = compile_utils.create_pseudo_output_names(outputs)
        flat_outputs = tf.nest.flatten(outputs)
        return dict(zip(names, flat_outputs))
예제 #2
0
  def _wrapped_model(*args, **kwargs):
    """Call the wrapped model with `training=False` and name its outputs.

    Args:
      *args: Positional arguments forwarded to the model call.
      **kwargs: Keyword arguments forwarded to the model call; any `training`
        entry is overridden with False.

    Returns:
      A flat dict mapping each output name to the matching flattened output.
    """
    call_kwargs = {**kwargs, 'training': False}
    call_ctx = base_layer_utils.call_context()
    with call_ctx.enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **call_kwargs)

    # The result must always be a flat dict keyed by output name.
    names = model.output_names  # Set for a Functional Model.
    if names is None:  # Subclassed Model: derive pseudo names instead.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(outputs)
    flat_outputs = tf.nest.flatten(outputs)
    return dict(zip(names, flat_outputs))
예제 #3
0
  def _wrapped_model(*args):
    """Call the wrapped model with `training=False` and name its outputs.

    Args:
      *args: Positional inputs; passed to the model as a single value when
        `input_signature` has exactly one entry, otherwise as a list.

    Returns:
      A flat dict mapping each output name to the matching flattened output.
    """
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    if len(input_signature) == 1:
      model_inputs = args[0]
    else:
      model_inputs = list(args)

    call_ctx = base_layer_utils.call_context()
    with call_ctx.enter(
        model,
        inputs=model_inputs,
        build_graph=False,
        training=False,
        saving=True):
      outputs = model(model_inputs, training=False)

    # The result must always be a flat dict keyed by output name.
    names = model.output_names  # Set for a Functional Model.
    if names is None:  # Subclassed Model: derive pseudo names instead.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(outputs)
    flat_outputs = tf.nest.flatten(outputs)
    return dict(zip(names, flat_outputs))