Example #1
 def wrapper(self, *args, **kwargs):
   layer = self.call_collection.layer
   original_losses = _reset_layer_losses(layer)
   with base_layer_utils.call_context().enter(layer, None, True, None):
     ret = method(self, *args, **kwargs)
   _restore_layer_losses(original_losses)
   return ret
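Example #1 passes its arguments to `enter` positionally. Read against the keyword form used in the later examples, the call above is assumed to be equivalent to the sketch below: no inputs are recorded, the call is traced as a graph-building call, and no explicit `training` value is forced onto child layers.

   # Assumed keyword equivalent of `enter(layer, None, True, None)` above:
   with base_layer_utils.call_context().enter(
       layer, inputs=None, build_graph=True, training=None):
     ret = method(self, *args, **kwargs)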
Example #2
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    # inputs = args[0] if len(input_signature) == 1 else list(args)
    inputs = args[0] if len(args) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      return model(*args, training=False)
Example #3
 def wrapper(self, *args, **kwargs):
   """Calls method within call context."""
   layer = self.call_collection.layer
   training = None
   # pylint: disable=protected-access
   if (args or kwargs) and layer._call_arg_was_passed(
       'training', args, kwargs, inputs_in_args=True):
     training = layer._get_call_arg_value(
         'training', args, kwargs, inputs_in_args=True)
   # pylint: enable=protected-access
   original_losses = _reset_layer_losses(layer)
   with base_layer_utils.call_context().enter(layer, None, True, training):
     ret = method(self, *args, **kwargs)
   _restore_layer_losses(original_losses)
   return ret
Example #4
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs_list = nest.flatten(model(inputs=inputs, training=False))

    try:
      output_names = model.output_names
    except AttributeError:
      from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
      output_names = training_utils.generic_output_names(outputs_list)
    return {name: output for name, output in zip(output_names, outputs_list)}
Example #5
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # Outputs always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}
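In Examples #2, #4 and #5, `model` and `input_signature` are free variables, so `_wrapped_model` is evidently defined inside an enclosing factory that traces it as a tf.function. A minimal hypothetical sketch of such a factory follows, assuming the internal `def_function` and `nest` imports; the name `make_wrapped_model` and the exact decoration are assumptions, not taken from the snippets above.

  from tensorflow.python.eager import def_function  # assumed import
  from tensorflow.python.util import nest  # assumed import

  def make_wrapped_model(model, input_signature):
    """Hypothetical factory returning the model call traced as a tf.function."""

    @def_function.function(input_signature=input_signature)
    def _wrapped_model(*args):
      # Keras models expect a single tensor rather than a one-element list.
      inputs = args[0] if len(input_signature) == 1 else list(args)
      with base_layer_utils.call_context().enter(
          model, inputs=inputs, build_graph=False, training=False, saving=True):
        outputs = model(inputs, training=False)
      return nest.flatten(outputs)

    return _wrapped_model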
Example #6
  def test_Bidirectional_updates(self):
    if context.executing_eagerly():
      self.skipTest('layer.updates is only available in graph mode.')

    with self.cached_session():
      x = keras.layers.Input(shape=(3, 2))
      x_reachable_update = x * x
      layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
      _ = layer(x)
      assert not layer.updates
      # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
      with base_layer_utils.call_context().enter(layer, x, True, None):
        layer.forward_layer.add_update(x_reachable_update, inputs=x)
        layer.forward_layer.add_update(1, inputs=None)
        layer.backward_layer.add_update(x_reachable_update, inputs=x)
        layer.backward_layer.add_update(1, inputs=None)
      assert len(layer.updates) == 4
Example #7
 def wrapper(*args, **kwargs):
   """Calls method within call context."""
   layer = call_collection.layer
   training = None
   inputs = call_collection.get_input_arg_value(args, kwargs)
   # pylint: disable=protected-access
   if (args or kwargs) and call_collection.training_arg_was_passed(
       args, kwargs):
     training = call_collection.get_training_arg_value(args, kwargs)
   # pylint: enable=protected-access
   original_losses = _reset_layer_losses(layer)
   with base_layer_utils.call_context().enter(
       layer, inputs=inputs, build_graph=False, training=training,
       saving=True):
     with base_layer_utils.autocast_context_manager(layer._compute_dtype):  # pylint: disable=protected-access
       ret = method(*args, **kwargs)
   _restore_layer_losses(original_losses)
   return ret
Example #8
 def wrapper(*args, **kwargs):
     """Calls method within call context."""
     layer = call_collection.layer
     training = None
     inputs = None
     # pylint: disable=protected-access
     if (args or kwargs) and call_collection.training_arg_was_passed(
             args, kwargs):
         inputs = args[0]
         training = call_collection.get_training_arg_value(args, kwargs)
     # pylint: enable=protected-access
     original_losses = _reset_layer_losses(layer)
     with base_layer_utils.call_context().enter(layer,
                                                inputs=inputs,
                                                build_graph=False,
                                                training=training):
         ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret
Example #9
def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if (isinstance(layer, keras_load.RevivedLayer)
            and not isinstance(layer, keras_load.RevivedSequential)):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in each
    # child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions (__call__,
    # call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        '{}_layer_call_and_return_conditional_losses'.format(layer.name))
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        '{}_layer_call_fn'.format(layer.name))

    fns = {
        'call_and_return_conditional_losses': call_fn_with_losses,
        '__call__': call_fn
    }

    if layer.activity_regularizer is not None:
        fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
        fns['call_and_return_all_conditional_losses'] = (
            call_collection.add_function(
                _append_activity_regularizer_loss(
                    layer, call_fn_with_losses,
                    fns['activity_regularizer_fn']),
                '{}_layer_call_and_return_all_conditional_losses'.format(
                    layer.name)))
    else:
        fns['activity_regularizer_fn'] = None
        fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with base_layer_utils.call_context().enter(layer,
                                               inputs=None,
                                               build_graph=True,
                                               training=None,
                                               saving=True):
        for fn in fns.values():
            if fn is not None and fn.input_signature is not None:
                fn.get_concrete_function()

    # Restore overwritten functions and losses
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

    return fns
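A hypothetical call site for the function above. The dictionary keys are the ones assembled in `fns`; the empty cache passed in is an assumption about how `serialization_cache` is initialised, not something shown in the snippet.

    # Hypothetical usage sketch, not taken from the snippet above.
    serialization_cache = {}
    fns = wrap_layer_functions(layer, serialization_cache)

    call_fn = fns['__call__']
    call_and_losses_fn = fns['call_and_return_conditional_losses']
    all_losses_fn = fns['call_and_return_all_conditional_losses']
    activity_regularizer_fn = fns['activity_regularizer_fn']  # may be None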
Example #10
def _is_building_keras_layer():
    return base_layer_utils.call_context().layer is not None
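Example #10 only reports whether code is currently executing inside a Keras layer call, i.e. whether the active call context has a layer attached. A hedged usage sketch; `maybe_record` and `_pending` are illustrative names, not part of the snippet.

    _pending = []

    def maybe_record(value):
      # Hypothetical caller: skip recording while a Keras layer is being built,
      # since the surrounding graph may not be executed immediately.
      if _is_building_keras_layer():
        return False
      _pending.append(value)
      return True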
Example #11
    def inference(self, inputs, *args, **kwargs):
        """Wraps `self._inference`, applying pre- and post-processing steps."""
        call_context = base_layer_utils.call_context()
        input_list = nest.flatten(inputs)

        # We will attempt to build a TF graph if & only if all inputs are symbolic.
        # This is always the case in graph mode. It can also be the case in eager
        # mode when all inputs can be traced back to `keras.Input()` (when building
        # models using the functional API).
        build_graph = tf_utils.are_all_symbolic_tensors(input_list)

        # Accept NumPy and scalar inputs by converting to Tensors.
        if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
            def _convert_non_tensor(x):
                # Don't call `ops.convert_to_tensor` on all `inputs` because
                # `SparseTensors` can't be converted to `Tensor`.
                if isinstance(x, (np.ndarray, float, int)):
                    return ops.convert_to_tensor(x)
                return x
            inputs = nest.map_structure(_convert_non_tensor, inputs)
            input_list = nest.flatten(inputs)

        # Handle `mask` propagation from previous layer to current layer. Masks can
        # be propagated explicitly via the `mask` argument, or implicitly via
        # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
        # explicitly take priority.
        mask_arg_passed_by_framework = False
        input_masks = self._collect_input_masks(inputs, args, kwargs)
        if (self._expects_mask_arg and input_masks is not None and
                not self._call_arg_was_passed('mask', args, kwargs)):
            mask_arg_passed_by_framework = True
            kwargs['mask'] = input_masks

        # If `training` argument was not explicitly passed, propagate `training`
        # value from this layer's calling layer.
        training_arg_passed_by_framework = False
        # Priority 1: `training` was explicitly passed.
        if self._call_arg_was_passed('training', args, kwargs):
            training_value = self._get_call_arg_value('training', args, kwargs)
            if not self._expects_training_arg:
                kwargs.pop('training')
        else:
            training_value = None
            # Priority 2: `training` was passed to a parent layer.
            if call_context.training is not None:
                training_value = call_context.training
            # Priority 3a: `learning_phase()` has been set.
            elif backend.global_learning_phase_is_set():
                training_value = backend.learning_phase()
            # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
            elif build_graph:
                with backend.get_graph().as_default():
                    if base_layer_utils.is_in_keras_graph():
                        training_value = backend.learning_phase()

            if self._expects_training_arg and training_value is not None:
                # Force the training_value to be bool type which matches to the contract
                # for layer/model call args.
                if tensor_util.is_tensor(training_value):
                    training_value = math_ops.cast(training_value, dtypes.bool)
                else:
                    training_value = bool(training_value)
                kwargs['training'] = training_value
                training_arg_passed_by_framework = True

        # Only create Keras history if at least one tensor originates from a
        # `keras.Input`. Otherwise this Layer may be being used outside the Keras
        # framework.
        if build_graph and base_layer_utils.needs_keras_history(inputs):
            base_layer_utils.create_keras_history(inputs)

        # Clear eager losses on top level model call.
        # We are clearing the losses only on the top level model call and not on
        # every layer/model call because layer/model may be reused.
        if (base_layer_utils.is_in_eager_or_tf_function() and
                not call_context.in_call):
            self._clear_losses()

        with call_context.enter(self, inputs, build_graph, training_value):
            # Check input assumptions set after layer building, e.g. input shape.
            if build_graph:
                # Symbolic execution on symbolic tensors. We will attempt to build
                # the corresponding TF subgraph inside `backend.get_graph()`
                # TODO(reedwm): We should assert input compatibility after the inputs
                # are casted, not before.
                input_spec.assert_input_compatibility(self.input_spec, inputs,
                                                      self.name)
                if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
                        and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
                    raise ValueError('Layer %s does not support RaggedTensors as input. '
                                     'Inputs received: %s. You can try converting your '
                                     'input to a uniform tensor.' % (self.name, inputs))

                graph = backend.get_graph()
                with graph.as_default(), backend.name_scope(self._name_scope()):
                    # Build layer if applicable (if the `build` method has been
                    # overridden).
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)

                    # Wrapping `call` function in autograph to allow for dynamic control
                    # flow and control dependencies in call. We are limiting this to
                    # subclassed layers as autograph is strictly needed only for
                    # subclassed layers and models.
                    # tf_convert will respect the value of autograph setting in the
                    # enclosing tf.function, if any.
                    if (base_layer_utils.is_subclassed(self) and
                            not base_layer_utils.from_saved_model(self)):
                        call_fn = autograph.tf_convert(
                                self._inference, ag_ctx.control_status_ctx())
                    else:
                        call_fn = self._inference

                    if not self.dynamic:
                        try:
                            with base_layer_utils.autocast_context_manager(
                                    self._compute_dtype):
                                # Add auto_control_deps in V2 when they are not already added by
                                # a `tf.function`.
                                if (ops.executing_eagerly_outside_functions() and
                                        not base_layer_utils.is_in_eager_or_tf_function()):
                                    with auto_control_deps.AutomaticControlDependencies() as acd:
                                        outputs = call_fn(cast_inputs, *args, **kwargs)
                                        # Wrap Tensors in `outputs` in `tf.identity` to avoid
                                        # circular dependencies.
                                        outputs = base_layer_utils.mark_as_return(outputs, acd)
                                else:
                                    outputs = call_fn(cast_inputs, *args, **kwargs)

                        except errors.OperatorNotAllowedInGraphError as e:
                            raise TypeError('You are attempting to use Python control '
                                            'flow in a layer that was not declared to be '
                                            'dynamic. Pass `dynamic=True` to the class '
                                            'constructor.\nEncountered error:\n"""\n' +
                                            str(e) + '\n"""')
                    else:
                        # We will use static shape inference to return symbolic tensors
                        # matching the specifications of the layer outputs.
                        # Since `self.dynamic` is True, we will never attempt to
                        # run the underlying TF graph (which is disconnected).
                        # TODO(fchollet): consider py_func as an alternative, which
                        # would enable us to run the underlying graph if needed.
                        outputs = self._symbolic_call(inputs)

                    if outputs is None:
                        raise ValueError('A layer\'s `call` method should return a '
                                         'Tensor or a list of Tensors, not None '
                                         '(layer: ' + self.name + ').')
                    if base_layer_utils.have_all_keras_metadata(inputs):
                        if training_arg_passed_by_framework:
                            kwargs.pop('training')
                        if mask_arg_passed_by_framework:
                            kwargs.pop('mask')
                        inputs, outputs = self._set_connectivity_metadata_(
                                inputs, outputs, args, kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)
                    if hasattr(self, '_set_inputs') and not self.inputs:
                        # Subclassed network: explicitly set metadata normally set by
                        # a call to self._set_inputs().
                        # TODO(b/120997007): This should be done in Eager as well, but
                        # causes garbage collection issues because of the placeholders
                        # created on the default Keras graph.
                        self._set_inputs(inputs, outputs)
            else:
                # Eager execution on data tensors.
                with backend.name_scope(self._name_scope()):
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)
                    with base_layer_utils.autocast_context_manager(
                            self._compute_dtype):
                        outputs = self._inference(cast_inputs, *args, **kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)

        return outputs