예제 #1
0
def _is_building_keras_layer():
  """Returns whether the Keras call context currently has a layer set.

  Returns False when `keras_deps.get_call_context_function()` yields a falsy
  value; otherwise calls it and checks whether the resulting context's
  `layer` attribute is not None.
  """
  # TODO(srbs): Remove this function when we no longer support session with
  # Keras.
  call_context_fn = keras_deps.get_call_context_function()
  if not call_context_fn:
    return False
  return call_context_fn().layer is not None
예제 #2
0
  def _wrapped_model(*args):
    """Calls `model` with training=False inside a saving call context.

    Args:
      *args: Positional inputs; when the enclosing `input_signature` has a
        single entry, only `args[0]` is used.

    Returns:
      Whatever `model(inputs, training=False)` returns.
    """
    # Keras calls a single-input model on the tensor itself, not on a
    # one-element list, so mirror that here.
    if len(input_signature) == 1:
      inputs = args[0]
    else:
      inputs = list(args)

    call_context = keras_deps.get_call_context_function()()
    with call_context.enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    return outputs
예제 #3
0
    def _wrapped_model(*args):
        """Calls `model` with training=False and returns a flat output dict.

        The call runs inside the context returned by
        `keras_deps.get_call_context_function()()`, entered with
        build_graph=False, training=False and saving=True.

        Args:
            *args: Positional inputs; when the enclosing `input_signature`
                has a single entry, only `args[0]` is used.

        Returns:
            A dict pairing output names with the flattened model outputs.
            Names come from `model.output_names` (functional model) or, when
            that is None, from `create_pseudo_output_names(outputs)`
            (subclassed model).
        """
        # Keras calls a single-input model on the tensor itself, not on a
        # one-element list, so mirror that here.
        if len(input_signature) == 1:
            inputs = args[0]
        else:
            inputs = list(args)

        call_context = keras_deps.get_call_context_function()()
        with call_context.enter(model,
                                inputs=inputs,
                                build_graph=False,
                                training=False,
                                saving=True):
            outputs = model(inputs, training=False)

        # Outputs always have to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            output_names = create_pseudo_output_names(outputs)
        # NOTE(review): zip stops at the shorter sequence, so a mismatch
        # between name count and flattened output count silently drops
        # entries; duplicate names keep the last output.
        return dict(zip(output_names, nest.flatten(outputs)))