コード例 #1
0
    def __call__(self, *args):
        """Dispatch between the supported decorator/wrapper calling styles.

        This method is only a frontend; the wrapping itself is delegated to
        the module-level `_wrap`, which always receives `self._wrapper_fn` as
        its third argument. Three calling styles are recognized:

        * No arguments (`@xyz()`): returns a one-argument callable that wraps
          the function it is later given, with the parameter type `None`.
        * First argument is a Python function or a defun (`@xyz` or
          `xyz(fn, ...)`): wraps it immediately. Remaining arguments, if any,
          form the parameter type specification; a single one is used as is,
          several are grouped into a tuple.
        * Anything else (`@xyz(type_spec, ...)`): the arguments are the
          parameter type specification (a single one as is, several as a
          tuple); returns a one-argument callable that wraps the function it
          is later given with that type.

        Args:
          *args: Positional arguments only; keyword arguments are not
            accepted.

        Returns:
          Either the result of `_wrap`, or a callable that expects the
          function or defun to wrap and returns the result of `_wrap`.

        Raises:
          This method raises nothing itself; errors from
          `computation_types.to_type` or `_wrap` propagate (documented
          upstream as TypeError for arguments of the wrong types).
        """
        # "@xyz()" form: the decorated function arrives in the follow-up call.
        # The lambda deliberately takes exactly one argument so that a caller
        # cannot accidentally pass a parameter type as a second argument.
        if not args:
            return lambda fn: _wrap(fn, None, self._wrapper_fn)

        target = args[0]
        type_parts = args[1:]

        # "@xyz" or "xyz(fn, ...)" form: the first argument is the thing being
        # wrapped and whatever follows it is the type specification.
        if (isinstance(target, types.FunctionType)
                or function_utils.is_defun(target)):
            if not type_parts:
                parameter_type = None
            elif len(type_parts) == 1:
                parameter_type = computation_types.to_type(type_parts[0])
            else:
                # Several trailing arguments are treated as one tuple-shaped
                # type specification.
                parameter_type = computation_types.to_type(type_parts)
            return _wrap(target, parameter_type, self._wrapper_fn)

        # "@xyz(type_spec, ...)" form: every argument belongs to the type
        # specification; more than one is treated as a tuple.
        type_spec = target if len(args) == 1 else args
        arg_type = computation_types.to_type(type_spec)
        return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
コード例 #2
0
 def test_is_defun(self):
     """Check `function_utils.is_defun` on tf.functions and non-defuns.

     Objects built with `tf.function` (with or without a second positional
     argument giving the signature) must be recognized; a plain lambda and
     `None` must not.
     """
     # tf.function built without a signature.
     plain_defun = tf.function(lambda x: None)
     self.assertTrue(function_utils.is_defun(plain_defun))

     # tf.function built with a one-element signature tuple.
     signature = (tf.TensorSpec(None, tf.int32), )
     fn = tf.function(lambda x: None, signature)
     self.assertTrue(function_utils.is_defun(fn))

     # Things that are not defuns.
     plain_lambda = lambda x: None
     self.assertFalse(function_utils.is_defun(plain_lambda))
     self.assertFalse(function_utils.is_defun(None))