Esempio n. 1
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: `None`, a Keras config dict, a string naming a registered
        object, or an already-deserialized function.
      module_objects: Optional dict of built-in objects to look string names
        up in. `None` is treated as an empty dict for string lookups.
      custom_objects: Optional dict of user-provided objects. For string
        identifiers it is consulted before `_GLOBAL_CUSTOM_OBJECTS` and
        `module_objects`.
      printable_module_name: Human-readable object kind used in error
        messages.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`. Classes
      referenced by a string name are instantiated with no arguments;
      functions are returned as-is.

    Raises:
      ValueError: If a string identifier cannot be resolved, or `identifier`
        has an unsupported type.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)
        custom_objects = custom_objects or {}

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if 'custom_objects' in arg_spec.args:
                # Explicitly passed custom objects take precedence over the
                # globally registered ones.
                return cls.from_config(
                    cls_config,
                    custom_objects={**_GLOBAL_CUSTOM_OBJECTS,
                                    **custom_objects})
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)

        # Then `cls` may be a function returning a class.
        # In this case by convention `config` holds
        # the kwargs of the function.
        with CustomObjectScope(custom_objects):
            return cls(**cls_config)

    if isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` is optional; an unknown name must still yield
            # the descriptive ValueError below, not an AttributeError on None.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj

    if tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier

    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
Esempio n. 2
0
 def __init__(self, x, y, image_data_generator,
              batch_size=32,
              shuffle=False,
              sample_weight=None,
              seed=None,
              data_format=None,
              save_to_dir=None,
              save_prefix='',
              save_format='png',
              subset=None,
              dtype=None):
   """Builds the iterator, filling `data_format`/`dtype` from Keras settings.

   `data_format` falls back to `backend.image_data_format()`. `dtype` is
   forwarded only when `image.NumpyArrayIterator.__init__` declares a `dtype`
   parameter, and then falls back to `backend.floatx()`.
   """
   data_format = (backend.image_data_format() if data_format is None
                  else data_format)
   optional = {}
   accepted_names = tf_inspect.getfullargspec(
       image.NumpyArrayIterator.__init__).args
   if 'dtype' in accepted_names:
     optional['dtype'] = backend.floatx() if dtype is None else dtype
   super(NumpyArrayIterator, self).__init__(
       x, y, image_data_generator,
       batch_size=batch_size,
       shuffle=shuffle,
       sample_weight=sample_weight,
       seed=seed,
       data_format=data_format,
       save_to_dir=save_to_dir,
       save_prefix=save_prefix,
       save_format=save_format,
       subset=subset,
       **optional)
Esempio n. 3
0
 def _recreate_layer_from_config(self, layer, go_backwards=False):
     """Rebuilds `layer` from its config, optionally flipping `go_backwards`.

     When the layer class's `from_config` accepts a positional
     `custom_objects` argument, the classes of `layer.cell` and of any
     `cell.cells` entries are passed in it. This lets RNN layers with custom
     cells be recreated; see
     https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
     """
     config = layer.get_config()
     if go_backwards:
         config["go_backwards"] = not config["go_backwards"]

     layer_cls = layer.__class__
     from_config_params = tf_inspect.getfullargspec(layer_cls.from_config).args
     if "custom_objects" not in from_config_params:
         return layer_cls.from_config(config)

     cell_classes = {}
     top_cell = getattr(layer, "cell", None)
     if top_cell is not None:
         # The cell itself first, then any stacked sub-cells (StackedRNNCells).
         for one_cell in [top_cell] + list(getattr(top_cell, "cells", [])):
             cell_classes[one_cell.__class__.__name__] = one_cell.__class__
     return layer_cls.from_config(config, custom_objects=cell_classes)
Esempio n. 4
0
 def __init__(self, cell, *args, **kwargs):
     """Stores `cell` and records whether its `call` can receive `training`."""
     super().__init__(*args, **kwargs)
     self.cell = cell
     spec = tf_inspect.getfullargspec(cell.call)
     # `training` is deliverable if named explicitly or absorbed by **kwargs.
     takes_training = "training" in spec.args or spec.varkw is not None
     self._call_spec.expects_training_arg = takes_training
Esempio n. 5
0
  def __init__(self, layer):
    """Gathers call-signature metadata needed to trace `layer`'s call functions.

    Args:
      layer: The layer whose call method is inspected.
    """
    self.layer = layer

    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # A traced function cannot carry an input signature when the call method
    # has defaults, keyword-only params, **kwargs, or expects `training`.
    spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = any((self._expects_training_arg,
                            spec.defaults,
                            spec.kwonlyargs,
                            spec.varkw))

    self._input_signature = self._generate_input_signature(layer)
    self._functions = weakref.WeakValueDictionary()
    # True while this object is tracing the layer call functions.
    self.tracing = False

    # The first positional parameter (after self/cls for bound methods) names
    # the layer input; fall back to 'inputs' when there is none.
    if tf_inspect.ismethod(self.layer_call_method):
      positional = spec.args[1:]
    else:
      positional = spec.args
    self._input_arg_name = positional[0] if positional else 'inputs'
Esempio n. 6
0
def get_training_arg_index(call_fn):
    """Locates the 'training' parameter of a layer call function.

    Args:
      call_fn: Call function.

    Returns:
      - n: positional index of 'training' (excluding self/cls for methods);
        only possible when `call_fn` has no `*args`.
      - -1: 'training' is not positional, but may be passed by keyword
        (it is keyword-only or `call_fn` accepts `**kwargs`).
      - None: the call function cannot receive a training argument.
    """
    spec = tf_inspect.getfullargspec(call_fn)
    keyword_capable = 'training' in spec.kwonlyargs or spec.varkw
    if not spec.varargs:
        positional = spec.args[1:] if call_is_method(call_fn) else spec.args
        if 'training' in positional:
            return positional.index('training')
    # With *args present, training can only arrive as a keyword.
    return -1 if keyword_capable else None
Esempio n. 7
0
def fn_args(fn):
    """Get argument names for function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of
        `functools.partial`, or an object with a `__call__` method).

    Returns:
      `tuple` of string argument names. For a `functools.partial`, arguments
      already bound positionally or by keyword are omitted. For a bound method,
      the leading positional argument (self/cls), if present, is omitted.
    """
    if isinstance(fn, functools.partial):
        args = fn_args(fn.func)
        bound_keywords = fn.keywords or {}
        args = [a for a in args[len(fn.args):] if a not in bound_keywords]
    else:
        if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
            fn = fn.__call__
        # Copy the list: the arg spec returned for a decorated callable may be
        # a stored object, and popping from it in place would corrupt it for
        # every later caller.
        args = list(tf_inspect.getfullargspec(fn).args)
        if is_bound_method(fn) and args:
            # If it's a bound method, it may or may not have a self/cls first
            # argument; for example, self could be captured in *args.
            # If it does have a positional argument, it is self/cls.
            args = args[1:]
    return tuple(args)
Esempio n. 8
0
 def __init__(self,
              featurewise_center=False,
              samplewise_center=False,
              featurewise_std_normalization=False,
              samplewise_std_normalization=False,
              zca_whitening=False,
              zca_epsilon=1e-6,
              rotation_range=0,
              width_shift_range=0.,
              height_shift_range=0.,
              brightness_range=None,
              shear_range=0.,
              zoom_range=0.,
              channel_shift_range=0.,
              fill_mode='nearest',
              cval=0.,
              horizontal_flip=False,
              vertical_flip=False,
              rescale=None,
              preprocessing_function=None,
              data_format=None,
              validation_split=0.0,
              dtype=None):
   """Builds the generator, filling `data_format`/`dtype` from Keras settings.

   `data_format` falls back to `backend.image_data_format()`. `dtype` is
   forwarded only when `image.ImageDataGenerator.__init__` declares a `dtype`
   parameter, and then falls back to `backend.floatx()`.
   """
   data_format = (backend.image_data_format() if data_format is None
                  else data_format)
   optional = {}
   accepted_names = tf_inspect.getfullargspec(
       image.ImageDataGenerator.__init__).args
   if 'dtype' in accepted_names:
     optional['dtype'] = backend.floatx() if dtype is None else dtype
   super(ImageDataGenerator, self).__init__(
       featurewise_center=featurewise_center,
       samplewise_center=samplewise_center,
       featurewise_std_normalization=featurewise_std_normalization,
       samplewise_std_normalization=samplewise_std_normalization,
       zca_whitening=zca_whitening,
       zca_epsilon=zca_epsilon,
       rotation_range=rotation_range,
       width_shift_range=width_shift_range,
       height_shift_range=height_shift_range,
       brightness_range=brightness_range,
       shear_range=shear_range,
       zoom_range=zoom_range,
       channel_shift_range=channel_shift_range,
       fill_mode=fill_mode,
       cval=cval,
       horizontal_flip=horizontal_flip,
       vertical_flip=vertical_flip,
       rescale=rescale,
       preprocessing_function=preprocessing_function,
       data_format=data_format,
       validation_split=validation_split,
       **optional)
Esempio n. 9
0
  def testTrainingDefaults(self):
    """Checks `training` defaults on call functions restored from a SavedModel.

    A model with three sublayers is saved in 'tf' format and reloaded with
    `tf.saved_model.load`. The restored `__call__` functions must advertise
    the expected default for `training`. A required `training` arg must have
    no default at all.
    """
    def assert_training_default(fn, default_value):
      # Positional defaults align with the tail of `arg_spec.args`, so the
      # distance of 'training' from the end of the list selects its default.
      arg_spec = tf_inspect.getfullargspec(fn)
      index = len(arg_spec.args) - arg_spec.args.index('training')
      self.assertEqual(arg_spec.defaults[-index], default_value)

    # `training` has no default here, so callers must always pass it.
    class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):

      # Returns zeros when `training` is true, the inputs unchanged otherwise.
      def call(self, inputs, training):
        return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                            lambda: tf.identity(inputs))

    # Same behavior, but `training` defaults to True.
    class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):

      def call(self, inputs, training=True):
        return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                            lambda: tf.identity(inputs))

    class Model(keras.models.Model):

      def __init__(self):
        super(Model, self).__init__()
        # NOTE(review): `LayerWithLearningPhase` is defined outside this
        # snippet; presumably its `call` declares `training=None` — verify.
        self.layer_with_training_default_none = LayerWithLearningPhase()
        self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
        self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()

      # Note: this `call` declares no `training` argument itself.
      def call(self, inputs):
        x = self.layer_with_training_default_none(inputs)
        x += self.layer_with_training_default_true(inputs)
        x += self.layer_with_required_training_arg(inputs, False)
        return x

    model = Model()
    # Build and set model inputs
    model.predict(np.ones([1, 3]).astype('float32'))
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    load = tf.saved_model.load(saved_model_dir)

    # Ensure that the Keras loader is able to load and build the model.
    _ = keras_load.load(saved_model_dir)

    # The restored model call exposes `training` defaulting to False, even
    # though `Model.call` above does not declare it.
    assert_training_default(load.__call__, False)
    assert_training_default(
        load.layer_with_training_default_none.__call__, False)
    assert_training_default(
        load.layer_with_training_default_true.__call__, True)

    # Assert that there are no defaults for layer with required training arg
    arg_spec = tf_inspect.getfullargspec(
        load.layer_with_required_training_arg.__call__)
    self.assertFalse(arg_spec.defaults)  # defaults is None or empty
Esempio n. 10
0
def has_arg(fn, name, accept_all=False):
    """Reports whether `fn` can be called with `name` as a keyword argument.

    Args:
        fn: Callable to inspect.
        name: Parameter name to look for among the positional-or-keyword and
          keyword-only parameters of `fn`.
        accept_all: If true, a `**kwargs` parameter on `fn` also counts as
          accepting `name`.

    Returns:
        bool, whether `fn` accepts a `name` keyword argument.
    """
    spec = tf_inspect.getfullargspec(fn)
    if name in spec.args or name in spec.kwonlyargs:
        return True
    return bool(accept_all) and spec.varkw is not None
Esempio n. 11
0
    def __init__(self, layer):
        """Gathers call-signature metadata used to trace `layer`'s call.

        Args:
          layer: The layer whose call method is inspected.
        """
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._training_arg_index = utils.get_training_arg_index(
            self.layer_call_method)

        self._layer_inputs = self._get_layer_inputs(layer)
        self._functions = weakref.WeakValueDictionary()

        # The first positional parameter (after self/cls for bound methods)
        # names the layer input; fall back to 'inputs' when there is none.
        positional = tf_inspect.getfullargspec(self.layer_call_method).args
        if tf_inspect.ismethod(self.layer_call_method):
            positional = positional[1:]
        self._input_arg_name = next(iter(positional), 'inputs')
Esempio n. 12
0
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Only the positional-or-keyword parameters of `call_fn` are inspected.
    Keyword-only parameters and `**kwargs` are not considered.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' among the positional arguments of `call_fn`,
        excluding the first (self/cls) argument when `call_fn` is a method.
      - -1: if 'training' is not a positional argument. This is returned
        regardless of whether `call_fn` could receive `training` by keyword.
        This function never returns `None`.
    """
    arg_list = tf_inspect.getfullargspec(call_fn).args
    if tf_inspect.ismethod(call_fn):
        arg_list = arg_list[1:]
    try:
        return arg_list.index('training')
    except ValueError:
        return -1
Esempio n. 13
0
    def __init__(
        self, function, output_shape=None, mask=None, arguments=None, **kwargs
    ):
        """Wraps `function` as a layer.

        Args:
          function: The callable to wrap.
          output_shape: Optional output shape, stored as `_output_shape`.
          mask: Optional mask; when not `None`, masking support is enabled.
          arguments: Optional dict of extra arguments, stored as
            `self.arguments` (empty dict when `None`).
          **kwargs: Forwarded to the base layer constructor.
        """
        super().__init__(**kwargs)

        self.arguments = arguments or {}
        self.function = function

        if mask is not None:
            self.supports_masking = True
        self.mask = mask
        self._output_shape = output_shape

        # Warning on every invocation will be quite irksome in Eager mode.
        self._already_warned = False

        # `training` and `mask` may be declared either as ordinary parameters
        # or as keyword-only parameters (after `*`); both can be passed by
        # keyword, so both count as "expected".
        fn_spec = tf_inspect.getfullargspec(function)
        function_args = list(fn_spec.args) + list(fn_spec.kwonlyargs or ())
        self._fn_expects_training_arg = "training" in function_args
        self._fn_expects_mask_arg = "mask" in function_args
Esempio n. 14
0
 def __init__(self, directory, image_data_generator,
              target_size=(256, 256),
              color_mode='rgb',
              classes=None,
              class_mode='categorical',
              batch_size=32,
              shuffle=True,
              seed=None,
              data_format=None,
              save_to_dir=None,
              save_prefix='',
              save_format='png',
              follow_links=False,
              subset=None,
              interpolation='nearest',
              dtype=None):
   """Builds the iterator, filling `data_format`/`dtype` from Keras settings.

   `data_format` falls back to `backend.image_data_format()`. `dtype` is
   forwarded only when `image.DirectoryIterator.__init__` declares a `dtype`
   parameter, and then falls back to `backend.floatx()`.
   """
   if data_format is None:
     data_format = backend.image_data_format()
   kwargs = {}
   # Probe the iterator constructor itself (as NumpyArrayIterator does), not
   # ImageDataGenerator, whose signature can differ across library versions.
   # NOTE(review): assumes `image.DirectoryIterator` is the `super()` target
   # below; the class bases are not visible here — confirm.
   if 'dtype' in tf_inspect.getfullargspec(
       image.DirectoryIterator.__init__)[0]:
     if dtype is None:
       dtype = backend.floatx()
     kwargs['dtype'] = dtype
   super(DirectoryIterator, self).__init__(
       directory, image_data_generator,
       target_size=target_size,
       color_mode=color_mode,
       classes=classes,
       class_mode=class_mode,
       batch_size=batch_size,
       shuffle=shuffle,
       seed=seed,
       data_format=data_format,
       save_to_dir=save_to_dir,
       save_prefix=save_prefix,
       save_format=save_format,
       follow_links=follow_links,
       subset=subset,
       interpolation=interpolation,
       **kwargs)
Esempio n. 15
0
def _has_kwargs(fn):
    """Reports whether a callable declares a `**kwargs` parameter.

    For a `functools.partial` the wrapped function is inspected. For an object
    recognized by `_is_callable_object` its `__call__` is inspected.

    Args:
      fn: Function, or function-like object (e.g., result of `functools.partial`).

    Returns:
      `bool`: if `fn` has **kwargs in its signature.

    Raises:
       `TypeError`: If fn is not a Function, or function-like object.
    """
    if isinstance(fn, functools.partial):
        target = fn.func
    elif _is_callable_object(fn):
        target = fn.__call__
    elif callable(fn):
        target = fn
    else:
        raise TypeError(
            "fn should be a function-like object, but is of type {}.".format(
                type(fn)))
    return tf_inspect.getfullargspec(target).varkw is not None
Esempio n. 16
0
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance.

  Usage:

  ```python
  img = np.random.random(size=(100, 100, 3))
  pil_img = tf.keras.preprocessing.image.array_to_img(img)
  ```

  Arguments:
      x: Input Numpy array.
      data_format: Image data format, either "channels_first" or
        "channels_last". When `None`, `tf.keras.backend.image_data_format()`
        is used.
      scale: Whether to rescale image values to be within `[0, 255]`. Defaults
        to `True`.
      dtype: Dtype to use. When `None`, `tf.keras.backend.floatx()` is used.
        Only forwarded if the underlying `image.array_to_img` accepts a
        `dtype` parameter.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """
  data_format = (backend.image_data_format() if data_format is None
                 else data_format)
  optional = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img).args:
    optional['dtype'] = backend.floatx() if dtype is None else dtype
  return image.array_to_img(
      x, data_format=data_format, scale=scale, **optional)
Esempio n. 17
0
    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary.

        A wrapper is built only when `self._expects_training_arg` is truthy
        and `self.layer._expects_training_arg` is not. The wrapper's
        advertised arg spec is `call_fn`'s arg spec plus a trailing positional
        `training` argument with default `False`.

        Args:
          call_fn: The call function to (possibly) wrap.
          match_layer_training_arg: If truthy, the wrapper removes the
            `training` value from the arguments before invoking `call_fn`.
            Otherwise the arguments are forwarded to `call_fn` unchanged.

        Returns:
          `call_fn` itself, or a decorated wrapper around it.

        Side effects:
          When wrapping, sets `self._training_arg_index` to the position of
          the added `training` argument (excluding self/cls for methods).
        """
        if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
            # Add training arg to wrapper function.
            arg_spec = tf_inspect.getfullargspec(call_fn)
            args = arg_spec.args + ['training']
            # Defaults align with the tail of `args`, so appending one default
            # gives the new trailing `training` arg the default False.
            defaults = list(arg_spec.defaults or [])
            defaults.append(False)
            new_arg_spec = tf_inspect.FullArgSpec(
                args=args,
                varargs=arg_spec.varargs,
                varkw=arg_spec.varkw,
                defaults=defaults,
                kwonlyargs=arg_spec.kwonlyargs,
                kwonlydefaults=arg_spec.kwonlydefaults,
                annotations=arg_spec.annotations)

            # Set new training arg index
            self._training_arg_index = len(args) - 1
            if tf_inspect.ismethod(call_fn):
                # Bound methods are called without self/cls, so shift by one.
                self._training_arg_index -= 1

            # Note: reads `self._training_arg_index` at call time, not at
            # wrap time.
            def wrap_with_training_arg(*args, **kwargs):
                if match_layer_training_arg:
                    # Remove the training value, since the original call_fn does not
                    # expect a training arg. Instead, the training value will be
                    # propagated using the call context created in LayerCall.
                    args = list(args)
                    kwargs = kwargs.copy()
                    utils.remove_training_arg(self._training_arg_index, args,
                                              kwargs)
                return call_fn(*args, **kwargs)

            # Attach `new_arg_spec` so that arg-spec inspection of the wrapper
            # reports the added `training` argument.
            return tf.__internal__.decorator.make_decorator(
                target=call_fn,
                decorator_func=wrap_with_training_arg,
                decorator_argspec=new_arg_spec)

        return call_fn
Esempio n. 18
0
def img_to_array(img, data_format=None, dtype=None):
  """Converts a PIL Image instance to a Numpy array.

  Usage:

  ```python
  img_data = np.random.random(size=(100, 100, 3))
  img = tf.keras.preprocessing.image.array_to_img(img_data)
  array = tf.keras.preprocessing.image.img_to_array(img)
  ```

  Arguments:
      img: Input PIL Image instance.
      data_format: Image data format, either "channels_first" or
        "channels_last". When `None`, `tf.keras.backend.image_data_format()`
        is used.
      dtype: Dtype to use. When `None`, `tf.keras.backend.floatx()` is used.
        Only forwarded if the underlying `image.img_to_array` accepts a
        `dtype` parameter.

  Returns:
      A 3D Numpy array.

  Raises:
      ValueError: if invalid `img` or `data_format` is passed.
  """
  data_format = (backend.image_data_format() if data_format is None
                 else data_format)
  optional = {}
  if 'dtype' in tf_inspect.getfullargspec(image.img_to_array).args:
    optional['dtype'] = backend.floatx() if dtype is None else dtype
  return image.img_to_array(img, data_format=data_format, **optional)
Esempio n. 19
0
  def add(self, layer):
    """Adds a layer instance on top of the layer stack.

    A `tf.Module` that is not a Keras `Layer` is wrapped in
    `functional.ModuleWrapper` before being added. After adding, the arg spec
    of the layer's `call` is recorded in `self._layer_call_argspecs`.

    Args:
        layer: layer instance.

    Raises:
        TypeError: If `layer` is not a layer instance.
        ValueError: If a layer with the same name is already in the model.
        ValueError: In case the `layer` argument does not
            know its input shape.
        ValueError: In case the `layer` argument has
            multiple output tensors, or is already connected
            somewhere else (forbidden in `Sequential` models).
    """
    # If we are passed a Keras tensor created by keras.Input(), we can extract
    # the input layer from its keras history and use that without any loss of
    # generality.
    if hasattr(layer, '_keras_history'):
      origin_layer = layer._keras_history[0]
      if isinstance(origin_layer, input_layer.InputLayer):
        layer = origin_layer
        logging.warning(
            'Please add `keras.layers.InputLayer` instead of `keras.Input` to '
            'Sequential model. `keras.Input` is intended to be used by '
            'Functional model.')

    # Keras layers are used as-is; other tf.Modules are wrapped; anything
    # else is rejected.
    if isinstance(layer, tf.Module):
      if not isinstance(layer, base_layer.Layer):
        layer = functional.ModuleWrapper(layer)
    else:
      raise TypeError('The added layer must be '
                      'an instance of class Layer. '
                      'Found: ' + str(layer))

    tf_utils.assert_no_legacy_layers([layer])
    if not self._is_layer_name_unique(layer):
      raise ValueError('All layers added to a Sequential model '
                       'should have unique names. Name "%s" is already the name'
                       ' of a layer in this model. Update the `name` argument '
                       'to pass a unique name.' % (layer.name,))

    # Adding a layer marks the model as not built; it is re-marked as built
    # below only if inputs/outputs can be established now.
    self.built = False
    set_inputs = False
    self._maybe_create_attribute('_self_tracked_trackables', [])
    if not self._self_tracked_trackables:
      # First layer in the model: try to derive the model inputs from it.
      if isinstance(layer, input_layer.InputLayer):
        # Case where the user passes an Input or InputLayer layer via `add`.
        set_inputs = True
      else:
        batch_shape, dtype = training_utils.get_input_shape_and_dtype(layer)
        if batch_shape:
          # Instantiate an input layer.
          x = input_layer.Input(
              batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input')
          # This will build the current layer
          # and create the node connecting the current layer
          # to the input layer we just created.
          layer(x)
          set_inputs = True

      if set_inputs:
        # Model outputs are the outputs of the layer's most recent inbound
        # node, which must be a single tensor.
        outputs = tf.nest.flatten(layer._inbound_nodes[-1].outputs)
        if len(outputs) != 1:
          raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
        self.outputs = outputs
        self.inputs = layer_utils.get_source_inputs(self.outputs[0])
        self.built = True
        self._has_explicit_input_shape = True

    elif self.outputs:
      # If the model is being built continuously on top of an input layer:
      # refresh its output.
      output_tensor = layer(self.outputs[0])
      if len(tf.nest.flatten(output_tensor)) != 1:
        raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
      self.outputs = [output_tensor]
      self.built = True

    if set_inputs or self._graph_initialized:
      # Inputs are known: (re)initialize the graph network from the current
      # inputs/outputs.
      self._init_graph_network(self.inputs, self.outputs)
      self._graph_initialized = True
    else:
      # No graph network yet: just track the layer (presumably built later,
      # once inputs are known).
      self._self_tracked_trackables.append(layer)
      self._handle_deferred_layer_dependencies([layer])

    # Record the layer's call signature; not read within this method.
    self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
Esempio n. 20
0
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object: `None`, a config dict, a
        string name, or an already-deserialized function.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. `None` is treated as an empty dict for string lookups.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object, or `None` if `identifier` is `None`.

    Raises:
      ValueError: If a string identifier cannot be resolved, or `identifier`
        has an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between multiple
        # objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if 'custom_objects' in arg_spec.args:
                # Explicitly passed custom objects take precedence over the
                # globally registered ones.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={**_GLOBAL_CUSTOM_OBJECTS,
                                    **custom_objects})
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # In this case by convention `config` holds
            # the kwargs of the function.
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    if isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` is optional; an unknown name must still yield
            # the descriptive ValueError below, not an AttributeError on None.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj

    if tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier

    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
Esempio n. 21
0
 def assert_training_default(fn, default_value):
   """Asserts that `fn`'s positional `training` parameter defaults to `default_value`.

   NOTE(review): `self` is a free name here; presumably this is a nested
   helper inside a test method and `self` is the enclosing test case.
   """
   spec = tf_inspect.getfullargspec(fn)
   # Defaults line up with the tail of `spec.args`; the distance of
   # 'training' from the end of the list picks out its default.
   offset_from_end = len(spec.args) - spec.args.index('training')
   expected_slot = spec.defaults[-offset_from_end]
   self.assertEqual(expected_slot, default_value)
Esempio n. 22
0
 def _call_full_argspec(self):
     """Returns the full arg spec of `self.forward_pass`.

     NOTE(review): the arg spec is meant to be cached because inspection is
     expensive and this is used often. This body recomputes it on every call,
     so caching presumably comes from a decorator not visible here — confirm.
     """
     forward = self.forward_pass
     return tf_inspect.getfullargspec(forward)
Esempio n. 23
0
def maybe_add_training_arg(original_call, wrapped_call, expects_training_arg,
                           default_training_value):
    """Decorate call and optionally adds training argument.

    If a layer expects a training argument, this function ensures that
    'training' is present in the layer args or kwonly args, with the default
    training value.

    Args:
      original_call: Original call function.
      wrapped_call: Wrapped call function.
      expects_training_arg: Whether to include 'training' argument.
      default_training_value: Default value of the training kwarg to include in
        the arg spec. If `None`, the default is
        `tf.keras.backend.learning_phase()`.

    Returns:
      Tuple of (
        function that calls `wrapped_call` and sets the training arg,
        Argspec of returned function or `None` if the argspec is unchanged)
    """
    if not expects_training_arg:
        return wrapped_call, None

    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        training_arg_index = get_training_arg_index(original_call)
        training = get_training_arg(training_arg_index, args, kwargs)
        if training is None:
            # NOTE(review): because of `or`, any falsy default (e.g. False),
            # not only None, falls through to the learning phase — confirm
            # this is intended.
            training = default_training_value or backend.learning_phase()

        args = list(args)
        kwargs = kwargs.copy()

        def replace_training_and_call(training):
            set_training_arg(training, training_arg_index, args, kwargs)
            return wrapped_call(*args, **kwargs)

        # Dispatch on the (possibly symbolic) training value.
        return control_flow_util.smart_cond(
            training, lambda: replace_training_and_call(True),
            lambda: replace_training_and_call(False))

    # Create arg spec for decorated function. If 'training' is not defined in the
    # args of the original arg spec, then add it to kwonlyargs.
    arg_spec = tf_inspect.getfullargspec(original_call)
    # Work on copies. `getfullargspec` may return an arg spec object stored on
    # a decorator (cf. `decorator_argspec`). Mutating its lists/dicts in place
    # would leak 'training' into every later lookup, and repeated calls would
    # append it again.
    defaults = list(arg_spec.defaults or [])
    kwonlyargs = list(arg_spec.kwonlyargs or [])
    kwonlydefaults = dict(arg_spec.kwonlydefaults or {})
    # Add training arg if it does not exist, or set the default training value.
    if 'training' not in arg_spec.args:
        kwonlyargs.append('training')
        kwonlydefaults['training'] = default_training_value
    else:
        # Positional defaults align with the tail of `args`; only overwrite a
        # default that exists and is None.
        index = arg_spec.args.index('training')
        training_default_index = len(arg_spec.args) - index
        if (arg_spec.defaults
                and len(arg_spec.defaults) >= training_default_index
                and defaults[-training_default_index] is None):
            defaults[-training_default_index] = default_training_value

    decorator_argspec = tf_inspect.FullArgSpec(
        args=arg_spec.args,
        varargs=arg_spec.varargs,
        varkw=arg_spec.varkw,
        defaults=defaults,
        kwonlyargs=kwonlyargs,
        kwonlydefaults=kwonlydefaults,
        annotations=arg_spec.annotations)
    return wrap_with_training_arg, decorator_argspec