Example #1
  def _call_and_compute_mask(self, inputs, training=None, mask=None):
    if not self.built and self._is_graph_network:
      self._init_graph_network(self.inputs, self.outputs, name=self.name)

    x = inputs
    for layer in self.layers:
      kwargs = {}
      if 'mask' in tf_inspect.getfullargspec(layer.call).args:
        kwargs['mask'] = mask
      if 'training' in tf_inspect.getfullargspec(layer.call).args:
        kwargs['training'] = training

      if isinstance(layer, Network) and layer._compute_output_and_mask_jointly:
        x, mask = layer._call_and_compute_mask(x, **kwargs)
      else:
        if not layer.built:
          # Build layer if applicable.
          with ops.name_scope(layer._name_scope()):
            layer._maybe_build(x)
          layer.built = True
        x = layer.call(x, **kwargs)
        if layer.supports_masking:
          mask = layer.compute_mask(x, mask)
        else:
          mask = None
      if not context.executing_eagerly():
        x._keras_mask = mask
    return x, mask
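
The loop above forwards `mask` and `training` only to layers whose `call` signature actually declares them. A minimal standalone sketch of that filtering idea, assuming hypothetical functions and using the standard-library `inspect` module (which matches `tf_inspect.getfullargspec` for plain, undecorated functions):

import inspect


def call_with_mask(inputs, mask=None):
  return inputs


def call_plain(inputs):
  return inputs


def filter_kwargs(fn, **candidates):
  """Keep only the keyword arguments that `fn` explicitly declares."""
  accepted = inspect.getfullargspec(fn).args
  return {k: v for k, v in candidates.items() if k in accepted}


assert filter_kwargs(call_with_mask, mask='m', training=True) == {'mask': 'm'}
assert filter_kwargs(call_plain, mask='m') == {}
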
Example #2
  def test_class_alias(self, mock_warning):
    class MyClass(object):
      """My docstring."""

      init_args = []

      def __init__(self, arg):
        MyClass.init_args.append(arg)

    deprecated_cls = deprecation.deprecated_alias("deprecated.cls",
                                                  "real.cls",
                                                  MyClass)

    print(deprecated_cls.__name__)
    print(deprecated_cls.__module__)
    print(deprecated_cls.__doc__)

    MyClass("test")
    self.assertEqual(0, mock_warning.call_count)
    deprecated_cls("deprecated")
    self.assertEqual(1, mock_warning.call_count)
    # Make sure the error points to the right file.
    self.assertRegexpMatches(mock_warning.call_args[0][1],
                             r"deprecation_test\.py:")
    deprecated_cls("deprecated again")
    self.assertEqual(1, mock_warning.call_count)

    self.assertEqual(["test", "deprecated", "deprecated again"],
                     MyClass.init_args)

    # Check __init__ signature matches for doc generation.
    self.assertEqual(
        tf_inspect.getfullargspec(MyClass.__init__),
        tf_inspect.getfullargspec(deprecated_cls.__init__))
Example #3
 def __init__(self, original_op, ragged_op, ragged_args):
   op_arg_names = tf_inspect.getfullargspec(original_op)[0]
   ragged_arg_names = tf_inspect.getfullargspec(ragged_op)[0]
   if op_arg_names != ragged_arg_names:
     raise AssertionError(
         'Signature must exactly match when overriding %s with %s: %s vs %s' %
         (original_op, ragged_op, op_arg_names, ragged_arg_names))
   self._ragged_op = ragged_op
   self._ragged_args = _get_arg_infos(ragged_op, ragged_args)
   if _UPDATE_DOCSTRINGS:
     arg_list = ' and '.join('`%s`' % arg for arg in ragged_args)
     original_op.__doc__ = (
         original_op.__doc__.rstrip() + '\n\n' +
         '    {0} may be a `tf.RaggedTensor`.\n'.format(arg_list))
Example #4
  def test_decorator_preserves_argspec(self):

    class TestClass(object):

      def called_member(self, a):
        if a < 0:
          a = -a
        return a

      called_member_converted = api.convert()(called_member)

    tc = TestClass()
    self.assertListEqual(
        list(tf_inspect.getfullargspec(tc.called_member)),
        list(tf_inspect.getfullargspec(tc.called_member_converted)))
Example #5
  def __init__(self, func, trainable=False, arguments=None, **kwargs):
    # Set self._{non,}_trainable_weights before calling Layer.__init__.
    if hasattr(func, 'trainable_variables'):
      self._trainable_weights = [v for v in func.trainable_variables]
      trainable_variables_set = set(func.trainable_variables)
    else:
      self._trainable_weights = []
      trainable_variables_set = set()
    if hasattr(func, 'variables'):
      self._non_trainable_weights = [v for v in func.variables
                                     if v not in trainable_variables_set]
    else:
      self._non_trainable_weights = []  # TODO(arnoegw): Infer from `func`.

    # TODO(b/124219898): We should be able to get the embedding dimension from
    # the restored model.
    if 'output_shape' in kwargs:
      self._output_shape = tuple(kwargs.pop('output_shape'))

    super(CustomLayer, self).__init__(trainable=trainable, **kwargs)
    # Prepare to call `func`.
    self._func = func
    self._func_fullargspec = tf_inspect.getfullargspec(func.__call__)
    self._func_wants_training = (
        'training' in self._func_fullargspec.args or
        'training' in self._func_fullargspec.kwonlyargs)
    self._arguments = arguments or {}
    # Forward the callable's regularization losses (if any).
    if hasattr(func, 'regularization_losses'):
      for l in func.regularization_losses:
        if not callable(l):
          raise ValueError(
              'CustomLayer(func) expects func.regularization_losses to be an '
              'iterable of callables, each returning a scalar loss term.')
        self.add_loss(l)  # Supports callables.
Example #6
  def decorated(self, **kwargs):
    """A wrapped test method that treats some arguments in a special way."""
    mode = kwargs.pop("mode", "graph")

    distribution = kwargs.get("distribution", None)
    required_tpu = kwargs.pop("required_tpu", False)
    required_gpus = kwargs.pop("required_gpus", None)

    if distribution:
      assert required_gpus is None, (
          "Do not use `required_gpus` and `distribution` together.")
      assert required_tpu is False, (
          "Do not use `required_tpu` and `distribution` together.")
      required_gpus = distribution.required_gpus
      required_tpu = distribution.required_tpu

    if required_tpu and not TPU_TEST:
      self.skipTest("Test requires a TPU, but it's not available.")
    if not required_tpu and TPU_TEST:
      self.skipTest("Test that doesn't require a TPU.")

    if not required_gpus:
      if GPU_TEST:
        self.skipTest("Test that doesn't require GPUs.")
    elif context.num_gpus() < required_gpus:
      self.skipTest(
          "{} GPUs are not available for this test. {} GPUs are available".
          format(required_gpus, context.num_gpus()))

    # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
    # that the user might have specified.  `kwargs` still has `mode`, which
    # the test is allowed to accept or ignore.
    requested_arguments = tf_inspect.getfullargspec(test_method).args
    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
        set(requested_arguments + ["mode"]))
    if missing_arguments:
      raise ValueError("The test is missing arguments {} .".format(
          missing_arguments))

    kwargs_to_pass = {}
    for arg in requested_arguments:
      if arg == "self":
        kwargs_to_pass[arg] = self
      else:
        kwargs_to_pass[arg] = kwargs[arg]

    if mode == "eager":
      with ops.Graph().as_default(), context.eager_mode():
        if distribution:
          kwargs_to_pass["distribution"] = distribution.strategy
        test_method(**kwargs_to_pass)
    elif mode == "graph":
      with ops.Graph().as_default(), context.graph_mode():
        if distribution:
          kwargs_to_pass["distribution"] = distribution.strategy
        test_method(**kwargs_to_pass)
    else:
      raise ValueError(
          "'mode' has to be either 'eager' or 'graph' and not {}".format(
              mode))
Example #7
 def __init__(self, x, y, image_data_generator,
              batch_size=32,
              shuffle=False,
              sample_weight=None,
              seed=None,
              data_format=None,
              save_to_dir=None,
              save_prefix='',
              save_format='png',
              subset=None,
              dtype=None):
   if data_format is None:
     data_format = backend.image_data_format()
   kwargs = {}
   if 'dtype' in tf_inspect.getfullargspec(
       image.NumpyArrayIterator.__init__)[0]:
     if dtype is None:
       dtype = backend.floatx()
     kwargs['dtype'] = dtype
   super(NumpyArrayIterator, self).__init__(
       x, y, image_data_generator,
       batch_size=batch_size,
       shuffle=shuffle,
       sample_weight=sample_weight,
       seed=seed,
       data_format=data_format,
       save_to_dir=save_to_dir,
       save_prefix=save_prefix,
       save_format=save_format,
       subset=subset,
       **kwargs)
Example #8
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance.

  Arguments:
      x: Input Numpy array.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      scale: Whether to rescale image values
          to be within `[0, 255]`.
      dtype: Dtype to use.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
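
The `dtype` handling above is a feature-detection pattern: the keyword is forwarded only when the wrapped `image.array_to_img` declares it, presumably so the wrapper keeps working across versions of the underlying preprocessing module. A hedged sketch of the same idea against a hypothetical target function, with the standard-library `inspect` standing in for `tf_inspect`:

import inspect


def newer_array_to_img(x, scale=True, dtype=None):
  """Hypothetical stand-in for an API that gained a `dtype` parameter."""
  return x, scale, dtype


kwargs = {}
if 'dtype' in inspect.getfullargspec(newer_array_to_img)[0]:
  # Only reached when the target declares `dtype`; older signatures skip it.
  kwargs['dtype'] = 'float32'
result = newer_array_to_img([1, 2, 3], **kwargs)
assert result[2] == 'float32'
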
Example #9
 def __init__(self, original_op, arg_is_list=False):
   self._original_op = original_op
   self._arg_is_list = arg_is_list
   arg_names = tf_inspect.getfullargspec(original_op)[0]
   self._x = arg_names[0]
   if _UPDATE_DOCSTRINGS:
     original_op.__doc__ = (
         original_op.__doc__.rstrip() + '\n\n' +
         '    `{x}` may be a `tf.RaggedTensor`.\n'.format(x=self._x))
Example #10
 def __init__(self, original_op):
   self._original_op = original_op
   arg_names = tf_inspect.getfullargspec(original_op)[0]
   self._x = arg_names[0]
   self._y = arg_names[1]
   if _UPDATE_DOCSTRINGS:
     original_op.__doc__ = (
         original_op.__doc__.rstrip() + '\n\n' +
         '    `{x}` and `{y}` may be a `tf.RaggedTensor`.\n'.format(
             x=self._x, y=self._y))
Example #11
 def __init__(self,
              featurewise_center=False,
              samplewise_center=False,
              featurewise_std_normalization=False,
              samplewise_std_normalization=False,
              zca_whitening=False,
              zca_epsilon=1e-6,
              rotation_range=0,
              width_shift_range=0.,
              height_shift_range=0.,
              brightness_range=None,
              shear_range=0.,
              zoom_range=0.,
              channel_shift_range=0.,
              fill_mode='nearest',
              cval=0.,
              horizontal_flip=False,
              vertical_flip=False,
              rescale=None,
              preprocessing_function=None,
              data_format=None,
              validation_split=0.0,
              dtype=None):
   if data_format is None:
     data_format = backend.image_data_format()
   kwargs = {}
   if 'dtype' in tf_inspect.getfullargspec(
       image.ImageDataGenerator.__init__)[0]:
     if dtype is None:
       dtype = backend.floatx()
     kwargs['dtype'] = dtype
   super(ImageDataGenerator, self).__init__(
       featurewise_center=featurewise_center,
       samplewise_center=samplewise_center,
       featurewise_std_normalization=featurewise_std_normalization,
       samplewise_std_normalization=samplewise_std_normalization,
       zca_whitening=zca_whitening,
       zca_epsilon=zca_epsilon,
       rotation_range=rotation_range,
       width_shift_range=width_shift_range,
       height_shift_range=height_shift_range,
       brightness_range=brightness_range,
       shear_range=shear_range,
       zoom_range=zoom_range,
       channel_shift_range=channel_shift_range,
       fill_mode=fill_mode,
       cval=cval,
       horizontal_flip=horizontal_flip,
       vertical_flip=vertical_flip,
       rescale=rescale,
       preprocessing_function=preprocessing_function,
       data_format=data_format,
       validation_split=validation_split,
       **kwargs)
Example #12
  def testGetFullArgsSpecForPartial(self):

    def func(a, b):
      del a, b

    partial_function = functools.partial(func, 1)
    argspec = tf_inspect.FullArgSpec(
        args=['b'], varargs=None, varkw=None, defaults=None,
        kwonlyargs=[], kwonlydefaults=None, annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
Example #13
 def testPositionsMatchArgGiven(self):
   full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
   method_names = full_dict.keys()
   for method in method_names:
     # doesn't test methods on objects
     if not method.startswith("*."):
       args = full_dict[method].keys()
       method = get_symbol_for_name(tf, method)
       arg_spec = tf_inspect.getfullargspec(method)
       for (arg, pos) in args:
         self.assertEqual(arg_spec[0][pos], arg)
Example #14
  def testGetFullArgSpecOnDecoratorThatChangesFullArgSpec(self):
    argspec = tf_inspect.FullArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
                                         argspec)
    self.assertEqual(argspec, tf_inspect.getfullargspec(decorator))
Example #15
  def testGetFullArgSpecIgnoresDecoratorsThatDontProvideFullArgSpec(self):
    argspec = tf_inspect.FullArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
                                               '', argspec)
    outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
    self.assertEqual(argspec, tf_inspect.getfullargspec(outer_decorator))
Example #16
  def _call_and_compute_mask(self, inputs, training=None, mask=None):
    if not self.built:
      self.build(inputs.shape)

    x = inputs
    for layer in self.layers:
      kwargs = {}
      if 'mask' in tf_inspect.getfullargspec(layer.call).args:
        kwargs['mask'] = mask
      if 'training' in tf_inspect.getfullargspec(layer.call).args:
        kwargs['training'] = training

      if isinstance(layer, Network) and layer._compute_output_and_mask_jointly:
        x, mask = layer._call_and_compute_mask(x, **kwargs)
      else:
        x = layer.call(x, **kwargs)
        if layer.supports_masking:
          mask = layer.compute_mask(x, mask)
        else:
          mask = None
      if not context.executing_eagerly():
        x._keras_mask = mask
    return x, mask
Example #17
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        return cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      with CustomObjectScope(custom_objects):
        return cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        return cls(**cls_config)
  elif isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      obj = module_objects.get(object_name)
      if obj is None:
        raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  else:
    raise ValueError('Could not interpret serialized ' + printable_module_name +
                     ': ' + identifier)
Example #18
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Arguments:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: What to return if there is no parameter called `name`
                  but the function accepts a `**kwargs` argument.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  arg_spec = tf_inspect.getfullargspec(fn)
  if accept_all and arg_spec.varkw is not None:
    return True
  return name in arg_spec.args
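
To make the behaviour of `has_arg` concrete, here is a small sketch with toy callables (names are hypothetical; for plain functions `tf_inspect.getfullargspec` reports the same fields as the standard-library `inspect.getfullargspec` used here):

import inspect


def loss_fn(y_true, y_pred, sample_weight=None):
  return 0.0


def metric_fn(y_true, y_pred, **extra):
  return 0.0


spec = inspect.getfullargspec(loss_fn)
assert 'sample_weight' in spec.args      # has_arg(loss_fn, 'sample_weight') -> True
assert spec.varkw is None                # no **kwargs fallback

spec = inspect.getfullargspec(metric_fn)
assert 'sample_weight' not in spec.args  # not an explicit parameter...
assert spec.varkw == 'extra'             # ...so only accept_all=True would return True
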
Example #19
  def __init__(self, function, output_shape=None, mask=None, arguments=None,
               **kwargs):
    super(Lambda, self).__init__(**kwargs)
    self.function = function
    self.arguments = arguments if arguments else {}
    if mask is not None:
      self.supports_masking = True
    self.mask = mask
    self._output_shape = output_shape
    self._variable_dict = {}
    # These attributes are inherited from `Layer`.
    self._trainable_weights = []
    self._non_trainable_weights = []

    function_args = tf_inspect.getfullargspec(self.function).args
    self._fn_expects_training_arg = 'training' in function_args
    self._fn_expects_mask_arg = 'mask' in function_args
Example #20
  def testGetFullArgSpecOnPartialNoArgumentsLeft(self):
    """Tests getfullargspec on partial function that prunes all arguments."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, 7, 10)
    argspec = tf_inspect.FullArgSpec(
        args=[],
        varargs=None,
        varkw=None,
        defaults=None,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #21
 def testPositionsMatchArgGiven(self):
   full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
   method_names = full_dict.keys()
   for method_name in method_names:
     args = full_dict[method_name].keys()
     # special case for optimizer methods
     if method_name.startswith("*."):
       method = method_name.replace("*", "tf.train.Optimizer")
     else:
       method = method_name
     method = get_symbol_for_name(tf, method)
     arg_spec = tf_inspect.getfullargspec(method)
     for (arg, pos) in args:
       # to deal with the self argument on methods on objects
       if method_name.startswith("*."):
         pos += 1
       self.assertEqual(arg_spec[0][pos], arg)
Example #22
  def testGetFullArgSpecOnNewClass(self):

    class NewClass(object):

      def __new__(cls, a, b=1, c='hello'):
        pass

    argspec = tf_inspect.FullArgSpec(
        args=['cls', 'a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass))
Example #23
  def testGetFullArgSpecOnPartialWithVarargs(self):
    """Tests getfullargspec on partial function with variable arguments."""

    def func(m, *arg):
      return m + len(arg)

    partial_func = functools.partial(func, 7, 8)
    argspec = tf_inspect.FullArgSpec(
        args=[],
        varargs='arg',
        varkw=None,
        defaults=None,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #24
  def testGetFullArgSpecOnPartialNoArgumentsLeft(self):
    """Tests getfullargspec on partial function that prunes all arguments."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, 7, 10)
    argspec = tf_inspect.FullArgSpec(
        args=[],
        varargs=None,
        varkw=None,
        defaults=None,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #25
 def testPositionsMatchArgGiven(self):
     full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
     method_names = full_dict.keys()
     for method_name in method_names:
         args = full_dict[method_name].keys()
         # special case for optimizer methods
         if method_name.startswith("*."):
             method = method_name.replace("*", "tf.train.Optimizer")
         else:
             method = method_name
         method = get_symbol_for_name(tf, method)
         arg_spec = tf_inspect.getfullargspec(method)
         for (arg, pos) in args:
             # to deal with the self argument on methods on objects
             if method_name.startswith("*."):
                 pos += 1
             self.assertEqual(arg_spec[0][pos], arg)
Example #26
  def __init__(self, function, output_shape=None, mask=None, arguments=None,
               **kwargs):
    super(Lambda, self).__init__(**kwargs)
    self.function = function
    self.arguments = arguments if arguments else {}
    if mask is not None:
      self.supports_masking = True
    self.mask = mask
    self._output_shape = output_shape
    self._variable_dict = {}
    # These attributes are inherited from `Layer`.
    self._trainable_weights = []
    self._non_trainable_weights = []

    function_args = tf_inspect.getfullargspec(self.function).args
    self._fn_expects_training_arg = 'training' in function_args
    self._fn_expects_mask_arg = 'mask' in function_args
Example #27
  def testGetFullArgSpecOnNewClass(self):

    class NewClass(object):

      def __new__(cls, a, b=1, c='hello'):
        pass

    argspec = tf_inspect.FullArgSpec(
        args=['cls', 'a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass))
Example #28
  def testGetFullArgSpecOnPartialWithVarargs(self):
    """Tests getfullargspec on partial function with variable arguments."""

    def func(m, *arg):
      return m + len(arg)

    partial_func = functools.partial(func, 7, 8)
    argspec = tf_inspect.FullArgSpec(
        args=[],
        varargs='arg',
        varkw=None,
        defaults=None,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #29
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        function_name = identifier
        if custom_objects and function_name in custom_objects:
            fn = custom_objects.get(function_name)
        elif function_name in _GLOBAL_CUSTOM_OBJECTS:
            fn = _GLOBAL_CUSTOM_OBJECTS[function_name]
        else:
            fn = module_objects.get(function_name)
            if fn is None:
                raise ValueError('Unknown ' + printable_module_name + ':' +
                                 function_name)
        return fn
    else:
        raise ValueError('Could not interpret serialized ' +
                         printable_module_name + ': ' + identifier)
Example #30
  def __init__(self, handle, output_shape, trainable=False, arguments=None,
               **kwargs):
    # Resolve the handle to a callable `func`.
    if callable(handle):
      self._func = handle
    else:
      self._func = module.load(handle)
      if not callable(self._func):
        raise ValueError("Non-callable result from hub.load('%s')" %
                         str(handle))

    # Set self._{non,}_trainable_weights and then call Layer.__init__.
    # This together with @no_automatic_dependency_tracking above preserves
    # func.trainable_variables independent of tf.Variable(..., trainable=...).
    if hasattr(self._func, "trainable_variables"):
      self._trainable_weights = [v for v in self._func.trainable_variables]
      trainable_variables_set = set(self._func.trainable_variables)
    else:
      self._trainable_weights = []
      trainable_variables_set = set()
    if hasattr(self._func, "variables"):
      self._non_trainable_weights = [v for v in self._func.variables
                                     if v not in trainable_variables_set]
    else:
      self._non_trainable_weights = []
    super(KerasLayer, self).__init__(trainable=trainable, **kwargs)

    # Prepare to call `func`.
    self._func_fullargspec = tf_inspect.getfullargspec(self._func.__call__)
    self._func_wants_training = (
        "training" in self._func_fullargspec.args or
        "training" in self._func_fullargspec.kwonlyargs)
    self._arguments = arguments or {}
    # TODO(b/124219898): We should be able to get the embedding dimension from
    # the restored model.
    self._output_shape = tuple(output_shape)

    # Forward the callable's regularization losses (if any).
    if hasattr(self._func, "regularization_losses"):
      for l in self._func.regularization_losses:
        if not callable(l):
          raise ValueError(
              "hub.KerasLayer(obj) expects obj.regularization_losses to be an "
              "iterable of callables, each returning a scalar loss term.")
        self.add_loss(l)  # Supports callables.
Example #31
    def decorated(self, **kwargs):
      """A wrapped test method that sets up `test_function`."""
      assert "mode" in kwargs
      mode = kwargs["mode"]

      if "distribution" in kwargs:
        distribution = kwargs["distribution"]
        kwargs["distribution"] = distribution.strategy
        if distribution.required_tpu and not TPU_TEST:
          self.skipTest("Test requires a TPU, but it's not available.")
        if not distribution.required_tpu and TPU_TEST:
          self.skipTest("Test that doesn't require a TPU.")

        if not distribution.required_gpus:
          if GPU_TEST:
            self.skipTest("Test that doesn't require GPUs.")
        elif context.num_gpus() < distribution.required_gpus:
          self.skipTest(
              "{} GPUs are not available for this test. {} GPUs are available".
              format(distribution.required_gpus, context.num_gpus()))

      requested_arguments = tf_inspect.getfullargspec(test_function).args
      missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
          set(requested_arguments + ["mode"]))
      if missing_arguments:
        raise ValueError("The test is missing arguments {} .".format(
            missing_arguments))

      kwargs_to_pass = {}
      for arg in requested_arguments:
        if arg == "self":
          kwargs_to_pass[arg] = self
        else:
          kwargs_to_pass[arg] = kwargs[arg]

      if mode == "eager":
        with context.eager_mode(), ops.Graph().as_default():
          test_function(**kwargs_to_pass)
      elif mode == "graph":
        with context.graph_mode(), ops.Graph().as_default():
          test_function(**kwargs_to_pass)
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph' and not {}".format(
                mode))
Example #32
    def decorated(self, **kwargs):
      """A wrapped test method that sets up `test_function`."""
      assert "mode" in kwargs
      mode = kwargs["mode"]

      if "distribution" in kwargs:
        distribution = kwargs["distribution"]
        kwargs["distribution"] = distribution.strategy
        if distribution.required_tpu and not TPU_TEST:
          self.skipTest("Test requires a TPU, but it's not available.")
        if not distribution.required_tpu and TPU_TEST:
          self.skipTest("Test that doesn't require a TPU.")

        if not distribution.required_gpus:
          if GPU_TEST:
            self.skipTest("Test that doesn't require GPUs.")
        elif context.num_gpus() < distribution.required_gpus:
          self.skipTest(
              "{} GPUs are not available for this test. {} GPUs are available".
              format(distribution.required_gpus, context.num_gpus()))

      requested_arguments = tf_inspect.getfullargspec(test_function).args
      missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
          set(requested_arguments + ["mode"]))
      if missing_arguments:
        raise ValueError("The test is missing arguments {} .".format(
            missing_arguments))

      kwargs_to_pass = {}
      for arg in requested_arguments:
        if arg == "self":
          kwargs_to_pass[arg] = self
        else:
          kwargs_to_pass[arg] = kwargs[arg]

      if mode == "eager":
        with context.eager_mode(), ops.Graph().as_default():
          test_function(**kwargs_to_pass)
      elif mode == "graph":
        with context.graph_mode(), ops.Graph().as_default():
          test_function(**kwargs_to_pass)
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph' and not {}".format(
                mode))
Example #33
  def _argspec_matches(self, node):
    arg_spec = tf_inspect.getfullargspec(self.fn)

    node_args = tuple(self._arg_name(arg) for arg in node.args.args)
    if node_args != tuple(arg_spec.args):
      return False

    if arg_spec.varargs != self._arg_name(node.args.vararg):
      return False

    if arg_spec.varkw != self._arg_name(node.args.kwarg):
      return False

    node_kwonlyargs = tuple(self._arg_name(arg) for arg in node.args.kwonlyargs)
    if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
      return False

    return True
Example #34
    def testGetFullArgSpecOnPartialWithVarkwargs(self):
        """Tests getfullargspec.

    Tests on partial function with variable keyword arguments.
    """
        def func(m, n, **kwarg):
            return m * n + len(kwarg)

        partial_func = functools.partial(func, 7)
        argspec = tf_inspect.FullArgSpec(args=['n'],
                                         varargs=None,
                                         varkw='kwarg',
                                         defaults=None,
                                         kwonlyargs=[],
                                         kwonlydefaults=None,
                                         annotations={})

        self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #35
  def testGetFullArgSpecOnCallableObject(self):

    class Callable(object):

      def __call__(self, a, b=1, c='hello'):
        pass

    argspec = tf_inspect.FullArgSpec(
        args=['self', 'a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    test_obj = Callable()
    self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj))
Example #36
  def _argspec_matches(self, node):
    arg_spec = tf_inspect.getfullargspec(self.fn)

    node_args = tuple(self._arg_name(arg) for arg in node.args.args)
    if node_args != tuple(arg_spec.args):
      return False

    if arg_spec.varargs != self._arg_name(node.args.vararg):
      return False

    if arg_spec.varkw != self._arg_name(node.args.kwarg):
      return False

    node_kwonlyargs = tuple(self._arg_name(arg) for arg in node.args.kwonlyargs)
    if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
      return False

    return True
Example #37
  def testGetFullArgSpecOnCallableObject(self):

    class Callable(object):

      def __call__(self, a, b=1, c='hello'):
        pass

    argspec = tf_inspect.FullArgSpec(
        args=['self', 'a', 'b', 'c'],
        varargs=None,
        varkw=None,
        defaults=(1, 'hello'),
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    test_obj = Callable()
    self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj))
Example #38
  def create_exception(self, source_error):
    preferred_type = type(source_error)
    if issubclass(preferred_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      init_argspec = tf_inspect.getfullargspec(preferred_type.__init__)
      message = self.get_message()
      init_args = tuple(init_argspec.args)
      # At the time of this writing, TF errors either take 3 or 4 arguments,
      # with the fourth being error_code.
      if init_args == ('self', 'node_def', 'op', 'message', 'error_code'):
        return preferred_type(
            node_def=source_error.node_def,
            op=source_error.op,
            message=message,
            error_code=self.error_code)
      elif init_args == ('self', 'node_def', 'op', 'message'):
        if 'error_code' in init_argspec.kwonlyargs:
          return preferred_type(
              node_def=source_error.node_def,
              op=source_error.op,
              message=message,
              error_code=self.error_code)
        else:
          return preferred_type(
              node_def=source_error.node_def,
              op=source_error.op,
              message=message)

    elif preferred_type in (errors.PyCTError, AutoGraphError, ConversionError,
                            StagingError, errors_impl.InaccessibleTensorError,
                            errors_impl.OperatorNotAllowedInGraphError):
      return preferred_type(self.get_message())

    exc = super(_ErrorMetadata, self).create_exception(source_error)
    if exc is not None:
      return exc

    # Note: While changing an error's message property to change the message it
    # displays will probably work a lot of times, there is no standard way in
    # Python to do that. The safest way is therefore to create a new exception.
    # For user defined exceptions, we could define an interface that allowed
    # them to work under this mechanism.
    return StagingError(self.get_message())
Example #39
 def __init__(self,
              directory,
              image_data_generator,
              target_size=(256, 256),
              color_mode='rgb',
              classes=None,
              class_mode='categorical',
              batch_size=32,
              shuffle=True,
              seed=None,
              data_format=None,
              save_to_dir=None,
              save_prefix='',
              save_format='png',
              follow_links=False,
              subset=None,
              interpolation='nearest',
              dtype=None):
     if data_format is None:
         data_format = backend.image_data_format()
     kwargs = {}
     if 'dtype' in tf_inspect.getfullargspec(
             image.ImageDataGenerator.__init__)[0]:
         if dtype is None:
             dtype = backend.floatx()
         kwargs['dtype'] = dtype
     super(DirectoryIterator, self).__init__(directory,
                                             image_data_generator,
                                             target_size=target_size,
                                             color_mode=color_mode,
                                             classes=classes,
                                             class_mode=class_mode,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             seed=seed,
                                             data_format=data_format,
                                             save_to_dir=save_to_dir,
                                             save_prefix=save_prefix,
                                             save_format=save_format,
                                             follow_links=follow_links,
                                             subset=subset,
                                             interpolation=interpolation,
                                             **kwargs)
Example #40
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D Numpy array to a PIL Image instance.

  Usage:

  ```python
  from PIL import Image
  img = np.random.random(size=(100, 100, 3))
  pil_img = tf.keras.preprocessing.image.array_to_img(img)
  ```


  Arguments:
      x: Input Numpy array.
      data_format: Image data format, can be either "channels_first" or
        "channels_last". Defaults to `None`, in which case the global setting
        `tf.keras.backend.image_data_format()` is used (unless you changed it,
        it defaults to "channels_last").
      scale: Whether to rescale image values to be within `[0, 255]`. Defaults
        to `True`.
      dtype: Dtype to use. Defaults to `None`, in which case the global setting
        `tf.keras.backend.floatx()` is used (unless you changed it, it defaults
        to "float32").

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """

    if data_format is None:
        data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
        if dtype is None:
            dtype = backend.floatx()
        kwargs['dtype'] = dtype
    return image.array_to_img(x,
                              data_format=data_format,
                              scale=scale,
                              **kwargs)
Example #41
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

  Args:
    call_fn: Call function.

  Returns:
    - n: index of 'training' in the call function arguments.
    - -1: if 'training' is not found in the arguments, but layer.call accepts
          variable keyword arguments
    - None: if layer doesn't expect a training argument.
  """
    arg_list = tf_inspect.getfullargspec(call_fn).args
    if tf_inspect.ismethod(call_fn):
        arg_list = arg_list[1:]
    if 'training' in arg_list:
        return arg_list.index('training')
    else:
        return -1
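
For context, a small sketch of the index lookup above with a hypothetical Keras-style call signature; the `ismethod` branch exists because bound methods carry an implicit `self` that a free function like this one does not:

import inspect


def call(inputs, training=None, mask=None):
  """Hypothetical layer call signature."""
  return inputs


arg_list = inspect.getfullargspec(call).args
# 'training' is the second declared argument, so the reported index is 1.
assert arg_list.index('training') == 1
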
Example #42
    def __init__(self, layer):
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = layer_uses_training_bool(layer)
        self._training_arg_index = utils.get_training_arg_index(
            self.layer_call_method)

        # If the layer call function has kwargs, then the traced function cannot
        # have an input signature.
        arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
        self._has_kwargs = bool(self._expects_training_arg or arg_spec.defaults
                                or arg_spec.kwonlyargs or arg_spec.varkw)

        self._input_signature = self._generate_input_signature(layer)
        self._functions = weakref.WeakValueDictionary()
        # Bool indicating whether this object is currently tracing the layer call
        # functions.
        self.tracing = False
Example #43
  def __init__(self, function, output_shape=None, mask=None, arguments=None,
               **kwargs):
    super(Lambda, self).__init__(**kwargs)

    self.arguments = arguments or {}
    self.function = function

    if mask is not None:
      self.supports_masking = True
    self.mask = mask
    self._supports_ragged_inputs = True
    self._output_shape = output_shape

    # Warning on every invocation will be quite irksome in Eager mode.
    self._already_warned = False

    function_args = tf_inspect.getfullargspec(function).args
    self._fn_expects_training_arg = 'training' in function_args
    self._fn_expects_mask_arg = 'mask' in function_args
Example #44
def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)
  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return", arg_count,
          "gradients but returned", len(flat_grads), "instead.")
    return nest.flatten(input_grads) + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)
Example #45
def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)
  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return", arg_count,
          "gradients but returned", len(flat_grads), "instead.")
    return nest.flatten(input_grads) + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)
Example #46
def _node_matches_argspec(node, func):
  """Returns True is node fits the argspec of func."""
  # TODO(mdan): Use just inspect once support for Python 2 is dropped.
  arg_spec = tf_inspect.getfullargspec(func)

  node_args = tuple(_arg_name(arg) for arg in node.args.args)
  if node_args != tuple(arg_spec.args):
    return False

  if arg_spec.varargs != _arg_name(node.args.vararg):
    return False

  if arg_spec.varkw != _arg_name(node.args.kwarg):
    return False

  node_kwonlyargs = tuple(_arg_name(arg) for arg in node.args.kwonlyargs)
  if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
    return False

  return True
Example #47
    def _argspec_compatible(self, node):
        arg_spec = tf_inspect.getfullargspec(self.fn)

        node_args = tuple(self._arg_name(arg) for arg in node.args.args)
        if len(node_args) != len(arg_spec.args) and node.args.vararg is None:
            return False

        if arg_spec.varargs is not None and node.args.vararg is None:
            return False

        if arg_spec.varkw is not None and node.args.kwarg is None:
            return False

        node_kwonlyargs = tuple(
            self._arg_name(arg) for arg in node.args.kwonlyargs)
        if (len(node_kwonlyargs) != len(arg_spec.kwonlyargs)
                and node.args.kwarg is None):
            return False

        return True
Example #48
def has_kwargs(fn):
    """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
     `TypeError`: If fn is not a Function, or function-like object.
  """
    if isinstance(fn, functools.partial):
        fn = fn.func
    elif _is_callable_object(fn):
        fn = fn.__call__
    elif not callable(fn):
        raise TypeError('Argument `fn` should be a callable. '
                        f'Received: fn={fn} (of type {type(fn)})')
    return tf_inspect.getfullargspec(fn).varkw is not None
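
A minimal check of the `varkw` test above, assuming two toy functions and using the standard-library `inspect` in place of `tf_inspect`:

import inspect


def with_kwargs(a, **options):
  return a, options


def without_kwargs(a, b=1):
  return a, b


assert inspect.getfullargspec(with_kwargs).varkw == 'options'   # has_kwargs -> True
assert inspect.getfullargspec(without_kwargs).varkw is None     # has_kwargs -> False
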
Example #49
  def testGetFullArgSpecOnPartialWithVarkwargs(self):
    """Tests getfullargspec.

    Tests on partial function with variable keyword arguments.
    """

    def func(m, n, **kwarg):
      return m * n + len(kwarg)

    partial_func = functools.partial(func, 7)
    argspec = tf_inspect.FullArgSpec(
        args=['n'],
        varargs=None,
        varkw='kwarg',
        defaults=None,
        kwonlyargs=[],
        kwonlydefaults=None,
        annotations={})

    self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
Example #50
def get_explicit_name_for_component(d):
    """Returns the explicitly-passed `name` of a Distribution, or None."""
    name = d.parameters.get('name', None)
    if name and d.__class__.__name__ in name:
        name = None

    if name and hasattr(d, '__init__'):
        spec = tf_inspect.getfullargspec(d.__init__)
        default_name = dict(
            zip(spec.args[len(spec.args) - len(spec.defaults or ()):],
                spec.defaults or ())).get('name', None)
        if name == default_name:
            name = None

    if name in FORBIDDEN_COMPONENT_NAMES:
        raise ValueError(
            'Distribution name "{}" is not allowed as a '
            'JointDistribution component; please choose a different '
            'name.'.format(name))
    return name
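
The `default_name` lookup above pairs the trailing entries of `spec.args` with `spec.defaults`, since defaults always align with the last positional arguments. A hedged sketch of that zip with a hypothetical constructor:

import inspect


def make_distribution(loc, scale=1.0, name='Normal'):
  """Hypothetical constructor used only to illustrate the lookup."""
  return loc, scale, name


spec = inspect.getfullargspec(make_distribution)
defaults = dict(
    zip(spec.args[len(spec.args) - len(spec.defaults or ()):],
        spec.defaults or ()))
# Only the defaulted arguments appear in the mapping.
assert defaults == {'scale': 1.0, 'name': 'Normal'}
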
Example #51
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D Numpy array to a PIL Image instance.

  Usage:

  >>> img = np.random.random(size=(100, 100, 3))
  >>> try:
  ...   from PIL import Image
  ...   pil_img = tf.keras.preprocessing.image.array_to_img(img)
  ... except ImportError:
  ...   pass

  Arguments:
      x: Input Numpy array.
      data_format: Image data format, can be either "channels_first" or
        "channels_last". Defaults to `None`, which gets data format from Keras
        backend.
      scale: Whether to rescale image values to be within `[0, 255]`. Defaults
        to `True`.
      dtype: Dtype to use. Defaults to `None`, which gets float type from Keras
        backend.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """

    if data_format is None:
        data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
        if dtype is None:
            dtype = backend.floatx()
        kwargs['dtype'] = dtype
    return image.array_to_img(x,
                              data_format=data_format,
                              scale=scale,
                              **kwargs)
Example #52
 def __init__(self, directory, image_data_generator,
              target_size=(256, 256),
              color_mode='rgb',
              classes=None,
              class_mode='categorical',
              batch_size=32,
              shuffle=True,
              seed=None,
              data_format=None,
              save_to_dir=None,
              save_prefix='',
              save_format='png',
              follow_links=False,
              subset=None,
              interpolation='nearest',
              dtype=None):
   if data_format is None:
     data_format = backend.image_data_format()
   kwargs = {}
   if 'dtype' in tf_inspect.getfullargspec(
       image.ImageDataGenerator.__init__)[0]:
     if dtype is None:
       dtype = backend.floatx()
     kwargs['dtype'] = dtype
   super(DirectoryIterator, self).__init__(
       directory, image_data_generator,
       target_size=target_size,
       color_mode=color_mode,
       classes=classes,
       class_mode=class_mode,
       batch_size=batch_size,
       shuffle=shuffle,
       seed=seed,
       data_format=data_format,
       save_to_dir=save_to_dir,
       save_prefix=save_prefix,
       save_format=save_format,
       follow_links=follow_links,
       subset=subset,
       interpolation=interpolation,
       **kwargs)
Example #53
def has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
     `TypeError`: If fn is not a Function, or function-like object.
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        'fn should be a function-like object, but is of type {}.'.format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None
Example #54
def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold."""
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults.
  for pos in range(-1, -len(argspec.defaults) - 1, -1):
    arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
Example #55
def _ragged_op_signature(op, ragged_args):
    """Returns a signature for the given op, marking ragged args in bold."""
    op_name = tf_export.get_canonical_name_for_symbol(op)
    argspec = tf_inspect.getfullargspec(op)
    arg_names = argspec.args

    # Mark ragged arguments in bold.
    for pos in ragged_args:
        arg_names[pos] = '**' + arg_names[pos] + '**'

    # Add argument defaults.
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
        arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

    # Add varargs and keyword args
    if argspec.varargs:
        arg_names.append('*' + argspec.varargs)
    if argspec.varkw:
        arg_names.append('**' + argspec.varkw)

    return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
Example #56
        def init_shard_fn(shard_index):
            if not init_from_fn:
                logging.log_if(
                    logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                    shard_index == 0
                    and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
                return initial_value[offsets[shard_index]:offsets[shard_index +
                                                                  1]]
            partition_shape = (offsets[shard_index + 1] -
                               offsets[shard_index], ) + shape[1:]
            partition_offset = (
                offsets[shard_index], ) + (0, ) * len(shape[1:])
            arg_spec = tf_inspect.getfullargspec(initial_value)
            if ("shard_info" not in arg_spec.args
                    and "shard_info" not in arg_spec.kwonlyargs):
                try:
                    value = initial_value(partition_shape=partition_shape,
                                          partition_offset=partition_offset)
                except (TypeError, ValueError):
                    # TypeError: Initializer doesn't accept kwargs
                    # ValueError: Initializer doesn't accept partition kwargs
                    # In both cases we go ahead creating the full value and then slice.
                    value = initial_value()

                if value.shape == partition_shape:
                    # Initializer supports partition: value is the partition value.
                    return value
                else:
                    # Initializer doesn't support partition: value is the full value
                    # and needs to be sliced to get the partition value.
                    logging.log_if(
                        logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                        shard_index == 0 and
                        shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
                    return value[offsets[shard_index]:offsets[shard_index + 1]]
            else:
                # For compatibility with `CheckpointInitialValueCallable`.
                return initial_value(shard_info=trackable.ShardInfo(
                    shape=tensor_shape.as_shape(partition_shape),
                    offset=partition_offset))
Example #57
def get_training_arg_index(layer):
    """Returns the index of 'training' in the layer call function arguments.

  Args:
    layer: Keras layer

  Returns:
    - n: index of 'training' in the call function arguments.
    - -1: if 'training' is not found in the arguments, but layer.call accepts
          variable keyword arguments
    - None: if layer doesn't expect a training argument.
  """
    if not layer._expects_training_arg:  # pylint: disable=protected-access
        return None

    arg_list = tf_inspect.getfullargspec(layer.call).args
    if tf_inspect.ismethod(layer.call):
        arg_list = arg_list[1:]
    if 'training' in arg_list:
        return arg_list.index('training')
    else:
        return -1
Example #58
  def visit_Lambda(self, node):
    self.generic_visit(node)

    arg_spec = tf_inspect.getfullargspec(self.lambda_fn)

    node_args = tuple(arg.id for arg in node.args.args)
    if node_args != tuple(arg_spec.args):
      return

    node_varargs = None if node.args.vararg is None else node.args.vararg.arg
    if arg_spec.varargs != node_varargs:
      return

    node_varkw = None if node.args.kwarg is None else node.args.kwarg.arg
    if arg_spec.varkw != node_varkw:
      return

    node_kwonlyargs = tuple(arg.id for arg in node.args.kwonlyargs)
    if node_kwonlyargs != tuple(arg_spec.kwonlyargs):
      return

    self.matching_nodes.append(node)
Example #59
def fn_args(fn):
    """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
    if isinstance(fn, functools.partial):
        args = fn_args(fn.func)
        args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
    else:
        if _is_callable_object(fn):
            fn = fn.__call__
        args = tf_inspect.getfullargspec(fn).args
        if _is_bounded_method(fn):
            args.remove('self')
    return tuple(args)
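
A short sketch of how `fn_args` trims arguments already bound by `functools.partial`, using a hypothetical function: positional bindings consume the leading names, and keyword bindings are filtered out by name.

import functools
import inspect


def model_fn(features, labels, mode, params=None):
  return features, labels, mode, params


args = inspect.getfullargspec(model_fn).args   # ['features', 'labels', 'mode', 'params']
partial_fn = functools.partial(model_fn, 'feats', mode='train')
remaining = [a for a in args[len(partial_fn.args):]
             if a not in (partial_fn.keywords or {})]
assert remaining == ['labels', 'params']
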
Example #60
 def _recreate_layer_from_config(self, layer, go_backwards=False):
   # When recreating the layer from its config, it is possible that the layer
   # is a RNN layer that contains custom cells. In this case we inspect the
   # layer and pass the custom cell class as part of the `custom_objects`
   # argument when calling `from_config`.
   # See https://github.com/tensorflow/tensorflow/issues/26581 for more detail.
   config = layer.get_config()
   if go_backwards:
     config['go_backwards'] = not config['go_backwards']
   if 'custom_objects' in tf_inspect.getfullargspec(
       layer.__class__.from_config).args:
     custom_objects = {}
     cell = getattr(layer, 'cell', None)
     if cell is not None:
       custom_objects[cell.__class__.__name__] = cell.__class__
       # For StackedRNNCells
       stacked_cells = getattr(cell, 'cells', [])
       for c in stacked_cells:
         custom_objects[c.__class__.__name__] = c.__class__
     return layer.__class__.from_config(config, custom_objects=custom_objects)
   else:
     return layer.__class__.from_config(config)