def test_is_defun(self):
   """Checks `function_utils.is_defun` on `function.Defun` objects and others.

   Functions wrapped by an instantiated `function.Defun` decorator must be
   recognized; the bare `Defun` class, a plain lambda and `None` must not.
   """
   # Decorator instantiated without any arguments.
   untyped_defun = function.Defun()(lambda x: None)
   self.assertTrue(function_utils.is_defun(untyped_defun))
   # Decorator instantiated with `tf.int32`.
   int32_defun = function.Defun(tf.int32)(lambda x: None)
   self.assertTrue(function_utils.is_defun(int32_defun))
   # The unapplied decorator class itself, a plain Python function, and None
   # are all expected to be rejected.
   for not_a_defun in (function.Defun, lambda x: None, None):
     self.assertFalse(function_utils.is_defun(not_a_defun))
  def __call__(self, *args):
    """Dispatches on the calling pattern used with this decorator/wrapper.

    Supported forms are `@xyz`, `@xyz()`, `@xyz(<type spec>)`, and inline
    `xyz(fn)` / `xyz(fn, <type spec>)`. This method only decides which form is
    in use; the wrapping itself is delegated to `_wrap` (called as a free
    function, not a method) together with `self._wrapper_fn`.

    Args:
      *args: Positional arguments only (keyword arguments are not accepted at
        this point, though that may change in the future).

    Returns:
      Either the result of wrapping, or a callable that expects a function or
      a defun and performs the wrapping on it, depending on the usage pattern.

    Raises:
      TypeError: if the arguments are of the wrong types. Nothing in this
        method raises it directly; presumably it comes from
        `computation_types.to_type` or `_wrap` (TODO confirm).
    """
    if not args:
      # Decorator with an empty argument list, "@xyz()": the function to wrap
      # arrives in the next call and no parameter type is given. Returning a
      # one-argument lambda keeps the caller from accidentally slipping a type
      # in as a second argument.
      return lambda fn: _wrap(fn, None, self._wrapper_fn)

    head = args[0]
    if isinstance(head, types.FunctionType) or function_utils.is_defun(head):
      # The first argument is the thing being wrapped: either "@xyz" applied
      # directly to a definition, or an inline "... = xyz(lambda ...)". Any
      # remaining arguments form the parameter type specification; a single
      # one is converted as-is, several are bundled into one tuple first.
      rest = args[1:]
      if not rest:
        param_type = None
      elif len(rest) == 1:
        param_type = computation_types.to_type(rest[0])
      else:
        param_type = computation_types.to_type(rest)
      return _wrap(head, param_type, self._wrapper_fn)

    # Otherwise every argument belongs to the type specification and the
    # function to wrap is supplied in a later call. A lone argument is used
    # as-is; several are bundled into a tuple.
    spec = args if len(args) > 1 else args[0]
    param_type = computation_types.to_type(spec)
    return lambda fn: _wrap(fn, param_type, self._wrapper_fn)
Example #3
0
 def test_is_defun(self):
     """Checks `function_utils.is_defun` on `tf.function` objects and others.

     Both an unconstrained `tf.function` and one with an explicit input
     signature must be recognized; a plain lambda and `None` must not.
     """
     unconstrained = tf.function(lambda x: None)
     self.assertTrue(function_utils.is_defun(unconstrained))
     int32_signature = (tf.TensorSpec(None, tf.int32),)
     with_signature = tf.function(lambda x: None, input_signature=int32_signature)
     self.assertTrue(function_utils.is_defun(with_signature))
     # Plain Python callables and None are expected to be rejected.
     for candidate in (lambda x: None, None):
       self.assertFalse(function_utils.is_defun(candidate))