def from_function_and_signature(cls,
                                python_function,
                                input_signature,
                                is_pure=False,
                                experimental_follow_type_hints=False,
                                jit_compile=None):
  """Creates a FunctionSpec instance given a python function and signature.

  Args:
    python_function: a function to inspect
    input_signature: a signature of the function (None, if variable)
    is_pure: if True all input arguments (including variables and constants)
      will be converted to tensors and no variable changes allowed.
    experimental_follow_type_hints: see `tf.function`
    jit_compile: see `tf.function`

  Returns:
    instance of FunctionSpec

  Raises:
    ValueError: if `input_signature` is provided and `python_function` has
      keyword-only arguments without default values.
  """
  # Get the function's name. Remove functools.partial wrappers if necessary.
  # This is resolved before any validation: `functools.partial` objects (and
  # arbitrary callable instances) have no `__name__`, so reading
  # `python_function.__name__` directly would turn the ValueError below into
  # an AttributeError.
  name_source = python_function
  while isinstance(name_source, functools.partial):
    name_source = name_source.func
  name = getattr(name_source, "__name__", "f")

  fullargspec = tf_inspect.getfullargspec(python_function)
  if input_signature is not None:
    # `set()` of the kwonlydefaults dict yields its keys, i.e. the
    # keyword-only arguments that do have defaults.
    nodefault_kwonlyargs = (
        set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()))
    if nodefault_kwonlyargs:
      raise ValueError("Cannot build TF function from "
                       f"{name}: keyword-only arguments "
                       "must have default values when input_signature is "
                       "provided. Got keyword-only arguments without default "
                       f"values: {sorted(nodefault_kwonlyargs)}.")

  # Checks if the `fullargspec` contains self or cls as its first argument.
  is_method = tf_inspect.isanytargetmethod(python_function)

  # Treat a wrapped partial function as a special case. For all arguments that
  # were overridden with keywords in the partial:
  #   - remove the corresponding arguments,
  #   - remove the corresponding keywords.
  _, unwrapped = tf_decorator.unwrap(python_function)

  if isinstance(unwrapped, functools.partial):
    # Also consider the Python3 case with kwonlydefaults.
    if fullargspec.defaults or fullargspec.kwonlydefaults:
      new_defaults = fullargspec.defaults
      new_args = fullargspec.args
      if fullargspec.defaults:
        # To be able to canonicalize the function properly, we want to ignore
        # default values that are overridden via a partial kwarg. For example:
        #
        #   def func(a, b, c, d=5, e=7):
        #     return a, b, c, d, e
        #   p_func = tf.function(functools.partial(func, 10, e=9))
        #
        # Here we want to drop from the defaults the parameter `e`. If we
        # forwarded the call to the partial function with a default for `e`
        # we would get an error for passing two values for one parameter.
        #
        # Note that this has a limitation: we can only override parameters at
        # the end of the parameter list.
        #
        # In this case we want to end up with 3 arguments (b, c, d) and 1
        # default value (5). We do this by constructing a mask where 0 stands
        # for a value that was overridden by a partial kwarg. The seemingly
        # complicated logic below does just that - for arguments (b, c, d, e)
        # we would get a mask (1, 1, 1, 0).
        old_args = fullargspec.args
        old_defaults = fullargspec.defaults

        # Sentinel marking "this positional arg has no default value".
        no_default = object()
        num_args_without_defaults = len(old_args) - len(old_defaults)
        left_padding = tuple([no_default] * num_args_without_defaults)

        args_with_defaults = zip(old_args, left_padding + old_defaults)

        # Create a mask where 0 stands for args that had a partial kwarg
        # defined.
        non_keyword_defaults_mask = [
            0 if key in unwrapped.keywords else 1 for key in old_args
        ]
        # Keep only arguments and defaults that were not kwargs of partial.
        new_args_with_defaults = list(
            itertools.compress(args_with_defaults, non_keyword_defaults_mask))
        # Keep all args.
        new_args = [arg for arg, _ in new_args_with_defaults]
        # Keep only real default values.
        new_defaults = [
            default for _, default in new_args_with_defaults
            if default is not no_default
        ]
      # Keyword-only arguments are dropped entirely from the rebuilt spec.
      fullargspec = tf_inspect.FullArgSpec(
          args=new_args,
          varargs=fullargspec.varargs,
          varkw=fullargspec.varkw,
          defaults=new_defaults,
          kwonlyargs=[],
          kwonlydefaults={},
          annotations=fullargspec.annotations)

  return FunctionSpec(
      fullargspec,
      is_method,
      input_signature,
      is_pure=is_pure,
      jit_compile=jit_compile,
      experimental_follow_type_hints=experimental_follow_type_hints,
      name=name)
def testIsAnyTargetMethod(self):
  """Checks isanytargetmethod through partial and tf_decorator wrapping."""

  def progressively_wrapped(fn):
    # Yields `fn`, then `fn` bound via functools.partial, then that partial
    # wrapped by one and then two test decorators. Each layer builds on the
    # previous one, and wrapping happens lazily between assertions.
    yield fn
    fn = functools.partial(fn, 1)
    yield fn
    fn = test_decorator('tf_decorator1')(fn)
    yield fn
    fn = test_decorator('tf_decorator2')(fn)
    yield fn

  class MyModule:

    def f(self, a):
      pass

    def __call__(self):
      pass

  module = MyModule()
  self.assertTrue(tf_inspect.isanytargetmethod(module))
  for candidate in progressively_wrapped(module.f):
    self.assertTrue(tf_inspect.isanytargetmethod(candidate))

  class MyModule2:
    pass

  module = MyModule2()
  self.assertFalse(tf_inspect.isanytargetmethod(module))

  def f2():
    pass

  for candidate in progressively_wrapped(f2):
    self.assertFalse(tf_inspect.isanytargetmethod(candidate))

  # Plain callables and non-callables are never target methods.
  for not_a_method in (lambda: None, None, 1):
    self.assertFalse(tf_inspect.isanytargetmethod(not_a_method))