def __call__(self, *args):
  """Front end that lets this class act as a decorator or inline wrapper.

  Three usage patterns are recognized, based on the positional arguments:

    * `@xyz()`: no arguments. A callable is returned that takes the function
      (or defun) to wrap. No parameter type is specified.
    * `@xyz` applied directly, or `xyz(fn, ...)` inline: the first argument
      is a Python function or a defun. It is wrapped immediately. Any
      trailing arguments form the parameter type specification: a single
      trailing argument is used as-is, and several are grouped into a tuple.
    * `@xyz(<type spec>)`: the first argument is anything else. All
      arguments are converted into a parameter type (several are grouped into
      a tuple). A callable is returned that takes the function to wrap.

  The actual wrapping is delegated to the private module function `_wrap`.

  Args:
    *args: Positional arguments only. Keyword arguments are not currently
      accepted, although that may change in the future.

  Returns:
    Either the result of `_wrap`, or a one-argument callable that performs
    the wrapping once given a function or defun, depending on the usage
    pattern.

  Raises:
    TypeError: if the arguments are of the wrong types (presumably surfaced
      by `computation_types.to_type` or `_wrap`; not raised directly here).
  """

  def _defer(parameter_type):
    # A lambda taking exactly one argument, so that callers cannot
    # accidentally pass a type specification as a second argument to the
    # returned decorator. `self._wrapper_fn` is looked up only when the
    # lambda is eventually invoked.
    return lambda fn: _wrap(fn, parameter_type, self._wrapper_fn)

  if not args:
    # "@xyz()": the function to decorate arrives in the next call, and the
    # parameter type is left unspecified (possibly yielding a polymorphic
    # callable).
    return _defer(None)

  target = args[0]
  if not (isinstance(target, types.FunctionType) or
          function_utils.is_defun(target)):
    # "@xyz(<type spec>)": every argument is part of the type specification.
    # Several arguments are treated as one tuple-shaped specification.
    spec = args if len(args) > 1 else target
    return _defer(computation_types.to_type(spec))

  # "@xyz" on a definition, or "... = xyz(lambda ...)" inline: the first
  # argument is what gets wrapped. Whatever follows it (if anything) is the
  # type specification.
  trailing = args[1:]
  if not trailing:
    return _wrap(target, None, self._wrapper_fn)
  spec = trailing if len(trailing) > 1 else trailing[0]
  return _wrap(target, computation_types.to_type(spec), self._wrapper_fn)
def test_is_defun(self):
  """Checks which callables `function_utils.is_defun` recognizes."""
  # A tf.function qualifies both without and with an explicit input signature.
  expected_defuns = [
      tf.function(lambda x: None),
      tf.function(lambda x: None, (tf.TensorSpec(None, tf.int32),)),
  ]
  for candidate in expected_defuns:
    self.assertTrue(function_utils.is_defun(candidate))

  # A plain Python lambda and a non-callable do not qualify.
  for candidate in (lambda x: None, None):
    self.assertFalse(function_utils.is_defun(candidate))