def _from_definition(fdef, grad_func=None):
  """Builds a _DefinedFunction that wraps an existing FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.
  signature = fdef.signature
  input_args = signature.input_arg

  # Positional arguments, in order: the Python callable, the input names,
  # the input dtypes, the function name, the gradient function, the Python
  # gradient function and the output names.
  #  - The callable is None: a Python callable is only needed to *create* a
  #    FunctionDef, and here we already have one (and no callable is
  #    available anyway).
  #  - The Python gradient function is None: FunctionDefs carry no Python
  #    gradient functions, so any the original function had is lost.
  result = _DefinedFunction(
      None,
      [arg.name for arg in input_args],
      tuple(dtypes.as_dtype(arg.type) for arg in input_args),
      signature.name,
      grad_func,
      None,
      [arg.name for arg in signature.output_arg])

  # pylint: disable=protected-access
  # Import the serialized proto through the C API and keep the returned handle
  # inside a ScopedTFFunction wrapper.
  result._c_func = c_api_util.ScopedTFFunction(
      c_api.TF_FunctionImportFunctionDef(fdef.SerializeToString()))
  result._extra_inputs = []
  # pylint: enable=protected-access
  return result
def _from_definition(fdef, grad_func=None):
  """Builds a _DefinedFunction that wraps an existing FunctionDef proto.

  When `ops._USE_C_API` is set, the proto is imported through the C API.
  Otherwise the proto itself is stored on the result and its hash string is
  computed from the signature and node definitions.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.
  signature = fdef.signature
  input_args = signature.input_arg

  # Positional arguments, in order: the Python callable, the input names,
  # the input dtypes, the function name, the gradient function, the Python
  # gradient function and the output names.
  #  - The callable is None: a Python callable is only needed to *create* a
  #    FunctionDef, and here we already have one (and no callable is
  #    available anyway).
  #  - The Python gradient function is None: FunctionDefs carry no Python
  #    gradient functions, so any the original function had is lost.
  result = _DefinedFunction(
      None,
      [arg.name for arg in input_args],
      tuple(dtypes.as_dtype(arg.type) for arg in input_args),
      signature.name,
      grad_func,
      None,
      [arg.name for arg in signature.output_arg])

  # pylint: disable=protected-access
  if not ops._USE_C_API:
    # Pure-Python path: keep the proto itself as the definition.
    result._definition = fdef
    # Captured inputs were turned into regular inputs when the original
    # function was serialized, so they already appear among `result`'s args.
    result._extra_inputs = []
    definition = result._definition
    result._hash_str = result._create_hash_str(
        definition.signature.input_arg,
        definition.signature.output_arg,
        definition.node_def)
    return result

  # C API path: import the serialized proto; a non-OK status raises on exit
  # from the context manager.
  payload = fdef.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as status:
    result._c_func = c_api.TF_FunctionImportFunctionDef(payload, status)
  result._extra_inputs = []
  # pylint: enable=protected-access
  return result