def _internal_py_func(func, inp, Tout, stateful=None, eager=False, is_grad_func=False, name=None): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError("Expected func to be callable, got func of type {}".format( type(func))) is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout, is_grad_func) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() while True: current_graph = graph if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access graph = graph._outer_graph # pylint: disable=protected-access elif isinstance(graph, func_graph.FuncGraph): graph = graph.outer_graph if graph is current_graph: break # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_py_funcs_used_in_graph"): graph._py_funcs_used_in_graph = [] # pylint: disable=protected-access # Store a reference to the function in the graph to ensure it stays alive # as long as the graph lives. When the graph is destroyed, the function # is left to the garbage collector for destruction as well. graph._py_funcs_used_in_graph.append(func) # pylint: disable=protected-access if eager: result = gen_script_ops.eager_py_func( input=inp, token=token, is_async=context.is_async(), Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func( input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless( input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, is_grad_func=False, name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout, is_grad_func) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() # pylint: disable=protected-access while isinstance(graph, function._FuncGraph): # If the py_func was declared inside a _FuncGraph, its lifetime should be # bound to that of the outer graph instead. graph = graph._outer_graph # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_py_funcs_used_in_graph"): graph._py_funcs_used_in_graph = [] # Store a reference to the function in the graph to ensure it stays alive # as long as the graph lives. When the graph is destroyed, the function # is left to the garbage collector for destruction as well. graph._py_funcs_used_in_graph.append(func) # pylint: enable=protected-access if eager: result = gen_script_ops.eager_py_func(input=inp, token=token, Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func(input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless(input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() # pylint: disable=protected-access while isinstance(graph, function._FuncGraph): # If the py_func was declared inside a _FuncGraph, its lifetime should be # bound to that of the outer graph instead. graph = graph._outer_graph cleanup = CleanupFunc(token) # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"): graph._cleanup_py_funcs_used_in_graph = [] # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph # will be destroyed and their __del__ will remove the 'token' from # the funcs registry. graph._cleanup_py_funcs_used_in_graph.append(cleanup) # pylint: enable=protected-access if eager: result = gen_script_ops.eager_py_func(input=inp, token=token, Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func(input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless(input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, is_grad_func=False, name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout, is_grad_func) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() # pylint: disable=protected-access while isinstance(graph, function._FuncGraph): # If the py_func was declared inside a _FuncGraph, its lifetime should be # bound to that of the outer graph instead. graph = graph._outer_graph # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_py_funcs_used_in_graph"): graph._py_funcs_used_in_graph = [] # Store a reference to the function in the graph to ensure it stays alive # as long as the graph lives. When the graph is destroyed, the function # is left to the garbage collector for destruction as well. graph._py_funcs_used_in_graph.append(func) # pylint: enable=protected-access if eager: result = gen_script_ops.eager_py_func( input=inp, token=token, Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func( input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless( input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout) token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() # pylint: disable=protected-access while isinstance(graph, function._FuncGraph): # If the py_func was declared inside a _FuncGraph, its lifetime should be # bound to that of the outer graph instead. graph = graph._outer_graph cleanup = CleanupFunc(token) # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"): graph._cleanup_py_funcs_used_in_graph = [] # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph # will be destroyed and their __del__ will remove the 'token' from # the funcs registry. graph._cleanup_py_funcs_used_in_graph.append(cleanup) # pylint: enable=protected-access if eager: result = gen_script_ops.eager_py_func( input=inp, token=token, Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func( input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless( input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, use_eager_py_func=False, is_grad_func=False, name=None): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError( f"Expected func to be callable. Received func={func} of type " f"{type(func)}.") original_func = func func = autograph.do_not_convert(func) inp = list(inp) # Normalize Tout. is_list_or_tuple = isinstance(Tout, (list, tuple)) Tout = Tout if is_list_or_tuple else [Tout] Tout = [_as_dtype_or_type_spec(t) for t in Tout] # Check if we need to handle CompositeTensor inputs or outputs. handle_composite_tensors = (use_eager_py_func and (any( isinstance(v, composite_tensor.CompositeTensor) for v in inp) or any(isinstance(t, type_spec.TypeSpec) for t in Tout))) if handle_composite_tensors: func, inp, Tout, out_structure = _wrap_for_composites(func, inp, Tout) if use_eager_py_func: func = EagerFunc(func, Tout, is_grad_func) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in # between model training end evaluation, via saved_model. Those binaries work # because the original function is global, and break once the registered # function is an anonymous lambda, like the one produced by do_not_convert. # To avoid breaking those cases, we attach the wrapper to the original # function so that their lifetime is connected. # TODO(b/144286616): Remove this. if tf_inspect.isfunction(original_func): # Note: this check is needed because original_func may be a descriptor # (https://docs.python.org/3/howto/descriptor.html) # and we can't attach attributes to those. original_func.ag_dnc_wrapper__ = func token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() while True: current_graph = graph if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access graph = graph._outer_graph # pylint: disable=protected-access elif isinstance(graph, func_graph.FuncGraph): graph = graph.outer_graph if graph is current_graph: break # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_py_funcs_used_in_graph"): graph._py_funcs_used_in_graph = [] # pylint: disable=protected-access # Store a reference to the function in the graph to ensure it stays alive # as long as the graph lives. When the graph is destroyed, the function # is left to the garbage collector for destruction as well. graph._py_funcs_used_in_graph.append(func) # pylint: disable=protected-access if use_eager_py_func: result = gen_script_ops.eager_py_func(input=inp, token=token, is_async=context.is_async(), Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func(input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless(input=inp, token=token, Tout=Tout, name=name) if handle_composite_tensors and Tout: result = nest.pack_sequence_as(out_structure, result, expand_composites=True) return result if is_list_or_tuple else result[0]
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, is_grad_func=False, name=None, use_tape_cache=True): """See documentation for py_func and eager_py_func.""" if not callable(func): raise ValueError( "Expected func to be callable, got func of type {}".format( type(func))) original_func = func func = autograph.do_not_convert(func) is_list_or_tuple = False if isinstance(Tout, (list, tuple)): is_list_or_tuple = True else: Tout = [Tout] if eager: func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache) # Tying the registered function's lifetime with the current default graph is # not reliable. For example, Estimator-based binaries may switch graphs in # between model training end evaluation, via saved_model. Those binaries work # because the original function is global, and break once the registered # function is an anonymous lambda, like the one produced by do_not_convert. # To avoid breaking those cases, we attach the wrapper to the original # function so that their lifetime is connected. # TODO(b/144286616): Remove this. if tf_inspect.isfunction(original_func): # Note: this check is needed because original_func may be a descriptor # (https://docs.python.org/3/howto/descriptor.html) # and we can't attach attributes to those. original_func.ag_dnc_wrapper__ = func token = _py_funcs.insert(func) # We tie the registered function's lifetime with the current default graph, # i.e., when the current graph is destroyed, we remove its py funcs. graph = ops.get_default_graph() while True: current_graph = graph if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access graph = graph._outer_graph # pylint: disable=protected-access elif isinstance(graph, func_graph.FuncGraph): graph = graph.outer_graph if graph is current_graph: break # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. if not hasattr(graph, "_py_funcs_used_in_graph"): graph._py_funcs_used_in_graph = [] # pylint: disable=protected-access # Store a reference to the function in the graph to ensure it stays alive # as long as the graph lives. When the graph is destroyed, the function # is left to the garbage collector for destruction as well. graph._py_funcs_used_in_graph.append(func) # pylint: disable=protected-access if eager: result = gen_script_ops.eager_py_func(input=inp, token=token, is_async=context.is_async(), Tout=Tout, name=name) else: if stateful: result = gen_script_ops.py_func(input=inp, token=token, Tout=Tout, name=name) else: result = gen_script_ops.py_func_stateless(input=inp, token=token, Tout=Tout, name=name) return result if is_list_or_tuple else result[0]