Example #1
    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        try:
            training = call_spec.get_arg_value("training",
                                               args,
                                               kwargs,
                                               inputs_in_args=True)
        except KeyError:
            training = None

        if training is None:
            # Resolve in priority order: the layer's default value, the
            # enclosing call context, then the global learning phase.
            training = (default_training_value
                        or base_layer_utils.call_context().training
                        or backend.learning_phase())

        args = list(args)
        kwargs = kwargs.copy()

        def replace_training_and_call(training):
            new_args, new_kwargs = call_spec.set_arg_value("training",
                                                           training,
                                                           args,
                                                           kwargs,
                                                           inputs_in_args=True)
            return wrapped_call(*new_args, **new_kwargs)

        return control_flow_util.smart_cond(
            training,
            lambda: replace_training_and_call(True),
            lambda: replace_training_and_call(False),
        )
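
The dispatch at the end hinges on training possibly being a symbolic tensor rather than a plain bool. A minimal sketch of that smart_cond-style dispatch, using only the public tf.cond API (the real control_flow_util.smart_cond is a Keras internal):

import tensorflow as tf

def smart_cond_sketch(pred, true_fn, false_fn):
    # Statically known predicate: pick a branch at trace time.
    if isinstance(pred, bool):
        return true_fn() if pred else false_fn()
    # Symbolic predicate: defer to tf.cond, which traces both branches.
    return tf.cond(pred, true_fn, false_fn)

x = tf.ones([2, 3])
y = smart_cond_sketch(True, lambda: x * 2.0, lambda: x)
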
Example #2
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        (
            args,
            kwargs,
        ) = model._call_spec.set_arg_value("training",
                                           False,
                                           args,
                                           kwargs,
                                           inputs_in_args=True)

        with base_layer_utils.call_context().enter(model,
                                                   inputs=None,
                                                   build_graph=False,
                                                   training=False,
                                                   saving=True):
            outputs = model(*args, **kwargs)

        # Outputs always have to be a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from keras.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return {name: output for name, output in zip(output_names, outputs)}
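
The name-to-tensor zip at the end works because tf.nest.flatten traverses dict outputs in sorted-key order. A self-contained sketch of the flattening step; the output_N naming is an assumption about what compile_utils.create_pseudo_output_names produces:

import tensorflow as tf

# Hypothetical nested outputs, as a subclassed model might return them.
outputs = {"probs": tf.zeros([1, 10]), "logits": tf.zeros([1, 2])}
flat = tf.nest.flatten(outputs)  # sorted-key order: logits, then probs
names = ["output_%d" % (i + 1) for i in range(len(flat))]
serving_outputs = dict(zip(names, flat))
# {'output_1': logits tensor, 'output_2': probs tensor}
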
Example #3
    def wrapper(*args, **kwargs):
        """Calls method within call context."""
        layer = call_collection.layer
        training = None
        inputs = _filtered_inputs([args, kwargs])

        if (args or kwargs) and call_collection.training_arg_was_passed(
            args, kwargs
        ):
            training = call_collection.get_training_arg_value(args, kwargs)

        original_losses = _reset_layer_losses(layer)
        with base_layer_utils.call_context().enter(
            layer,
            inputs=inputs,
            build_graph=False,
            training=training,
            saving=True,
        ):
            with autocast_variable.enable_auto_cast_variables(
                layer._compute_dtype_object
            ):
                ret = method(*args, **kwargs)
        _restore_layer_losses(original_losses)
        return ret
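
The _reset_layer_losses / _restore_layer_losses pair brackets the traced call so that losses recorded during tracing do not leak into the live layer. A generic sketch of that snapshot-and-restore pattern; Recorder and its losses list are hypothetical stand-ins for the layer's internal loss state:

class Recorder:
    def __init__(self):
        self.losses = []

def call_without_leaking(recorder, fn, *args, **kwargs):
    saved = recorder.losses      # snapshot the pre-trace state
    recorder.losses = []         # reset so the traced call starts clean
    try:
        return fn(*args, **kwargs)
    finally:
        recorder.losses = saved  # restore, discarding trace-time losses
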
Example #4
  def _wrapped_model(*args, **kwargs):
    """A concrete tf.function that wraps the model's call function."""
    kwargs['training'] = False
    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **kwargs)

    # Outputs always have to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = tf.nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}
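
This older variant forces inference mode by writing training=False straight into kwargs. The surrounding export machinery typically turns such a wrapper into a concrete serving function; a hedged sketch of that outer step using only public APIs (the model and input signature here are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build([None, 8])

@tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
def serving_fn(x):
    # Mirrors kwargs['training'] = False above: always run in inference mode.
    return model(x, training=False)

concrete = serving_fn.get_concrete_function()
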
Example #5
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # With a single input, the model is called on the bare tensor rather
    # than on a one-element list.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # Outputs always have to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = tf.nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}
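
The args[0]-versus-list(args) branch mirrors Keras's input convention: single-input models are called on the bare tensor, multi-input models on a list. A small illustration with a functional model (shapes are arbitrary):

import tensorflow as tf

a = tf.keras.Input(shape=(3,))
b = tf.keras.Input(shape=(3,))
model = tf.keras.Model([a, b], tf.keras.layers.Add()([a, b]))

# Multi-input: pass a list. A single-input model would take the bare tensor.
out = model([tf.ones([1, 3]), tf.ones([1, 3])], training=False)
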
Example #6
    def test_Bidirectional_updates(self):
        if tf.executing_eagerly():
            self.skipTest("layer.updates is only available in graph mode.")

        with self.cached_session():
            x = keras.layers.Input(shape=(3, 2))
            x_reachable_update = x * x
            layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
            _ = layer(x)
            assert not layer.updates
            # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
            # Positional args: (layer, inputs=x, build_graph=True, training=None).
            with base_layer_utils.call_context().enter(layer, x, True, None):
                layer.forward_layer.add_update(x_reachable_update)
                layer.forward_layer.add_update(1)
                layer.backward_layer.add_update(x_reachable_update)
                layer.backward_layer.add_update(1)
            assert len(layer.updates) == 4
Example #7
def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)

  # Wrap all the layer call and activity regularizer functions.

  # Use LayerCallCollection to ensure that all layer call functions (__call__,
  # call with losses) are traced with the same inputs.
  call_collection = LayerCallCollection(layer)
  call_fn_with_losses = call_collection.add_function(
      _wrap_call_and_conditional_losses(layer),
      '{}_layer_call_and_return_conditional_losses'.format(layer.name))
  call_fn = call_collection.add_function(
      _extract_outputs_from_fn(layer, call_fn_with_losses),
      '{}_layer_call_fn'.format(layer.name))

  fns = {'call_and_return_conditional_losses': call_fn_with_losses,
         '__call__': call_fn}

  if layer._activity_regularizer is not None:  # pylint: disable=protected-access
    fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
    fns['call_and_return_all_conditional_losses'] = (
        call_collection.add_function(
            _append_activity_regularizer_loss(layer,
                                              call_fn_with_losses,
                                              fns['activity_regularizer_fn']),
            '{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
            ))
  else:
    fns['activity_regularizer_fn'] = None
    fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

  # Manually trigger traces before restoring the overwritten functions. The
  # functions are traced within the layer call context to ensure that layer
  # functions (e.g. add_loss) behave as though running in graph mode.
  with tracing_scope():
    call_collection.trace_with_input_signature()
    with base_layer_utils.call_context().enter(
        layer, inputs=None, build_graph=True, training=None, saving=True):
      for fn in fns.values():
        if fn is not None and fn.input_signature is not None:
          if isinstance(fn, LayerCall):
            fn = fn.wrapped_call
          fn.get_concrete_function()

  # Restore overwritten functions and losses
  _restore_child_layer_functions(original_fns)
  _restore_layer_losses(original_losses)

  return fns
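
The final with-block exists because tf.function tracing is lazy. A minimal sketch of forcing a trace with public APIs, which is what the loop over fns.values() does for every saved function:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def call_fn(x):
    return x * 2.0

# get_concrete_function() triggers the trace immediately instead of waiting
# for the first call, so the traced graph ends up in the SavedModel.
concrete = call_fn.get_concrete_function()
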
Example #8
def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if isinstance(layer, keras_load.RevivedLayer) and not isinstance(
            layer, sequential_lib.Sequential):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in each
    # child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions (__call__,
    # call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        "{}_layer_call_and_return_conditional_losses".format(layer.name),
        # If any of this layer's child layers use the training arg, the traced
        # call functions of this layer will have a training keyword argument. If
        # the original layer does not expect the training arg, then it will have
        # to be removed (by setting `match_layer_training_arg`).
        match_layer_training_arg=True,
    )
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        "{}_layer_call_fn".format(layer.name),
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False,
    )

    fns = {
        "call_and_return_conditional_losses": call_fn_with_losses,
        "__call__": call_fn,
    }

    if layer._activity_regularizer is not None:  # pylint: disable=protected-access
        fns["activity_regularizer_fn"] = _wrap_activity_regularizer(layer)
        fns["call_and_return_all_conditional_losses"] = call_collection.add_function(
            _append_activity_regularizer_loss(layer, call_fn_with_losses,
                                              fns["activity_regularizer_fn"]),
            "{}_layer_call_and_return_all_conditional_losses".format(
                layer.name),
            match_layer_training_arg=False,
        )
    else:
        fns["activity_regularizer_fn"] = None
        fns["call_and_return_all_conditional_losses"] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
        call_collection.trace_with_input_signature()
        with base_layer_utils.call_context().enter(layer,
                                                   inputs=None,
                                                   build_graph=True,
                                                   training=None,
                                                   saving=True):
            for fn in fns.values():
                if fn is not None and not isinstance(fn, LayerCall):
                    fn.get_concrete_function()

    # Restore overwritten functions and losses
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

    return fns
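
The activity-regularizer branch composes one extra loss onto the conditional-loss function. A self-contained sketch of that composition; all three function names are hypothetical stand-ins for the wrapped Keras internals:

import tensorflow as tf

def call_and_return_conditional_losses(x):
    return x * 2.0, []  # (outputs, losses) as the wrapped call returns them

def activity_regularizer_fn(out):
    return 0.01 * tf.reduce_sum(tf.square(out))

def call_and_return_all_conditional_losses(x):
    out, losses = call_and_return_conditional_losses(x)
    # Append the regularizer's loss, as _append_activity_regularizer_loss
    # does above.
    return out, losses + [activity_regularizer_fn(out)]
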