Esempio n. 1
0
def _from_definition(fdef, grad_func=None):
    """Builds a _DefinedFunction from an existing FunctionDef proto.

    The result has no Python callable attached. `fdef` already holds the full
    definition, so the callable that would normally be traced to produce it is
    neither needed nor available here.

    Args:
      fdef: a FunctionDef proto describing the function.
      grad_func: a _DefinedFunction to use as the gradient, or None.

    Returns:
      A _DefinedFunction whose `_c_func` is created by importing the serialized
      `fdef` through `c_api.TF_FunctionImportFunctionDef`.
    """
    # TODO(iga): This method does major surgery on _DefinedFunction.
    # Make it a named constructor using @classmethod of _DefinedFunction.
    signature = fdef.signature
    input_args = signature.input_arg

    names_in = [a.name for a in input_args]
    dtypes_in = tuple(dtypes.as_dtype(a.type) for a in input_args)
    names_out = [a.name for a in signature.output_arg]

    # Both the Python callable and the Python gradient function are None.
    # The callable is only needed to build a FunctionDef, which we already
    # have. A FunctionDef cannot carry a Python gradient function, so one
    # attached to the original _DefinedFunction is not reflected here.
    defined = _DefinedFunction(None, names_in, dtypes_in, signature.name,
                               grad_func, None, names_out)

    # pylint: disable=protected-access
    imported = c_api.TF_FunctionImportFunctionDef(fdef.SerializeToString())
    defined._c_func = c_api_util.ScopedTFFunction(imported)
    defined._extra_inputs = []
    # pylint: enable=protected-access
    return defined
Esempio n. 2
0
def _from_definition(fdef, grad_func=None):
    """Builds a _DefinedFunction from an existing FunctionDef proto.

    The result has no Python callable attached, because `fdef` already holds the
    full definition. What gets stored depends on `ops._USE_C_API`:

    * when true, the serialized `fdef` is imported through
      `c_api.TF_FunctionImportFunctionDef` and stored as `_c_func`;
    * otherwise, `fdef` itself is stored as `_definition` and `_hash_str` is
      computed from its signature and node defs.

    Args:
      fdef: a FunctionDef proto describing the function.
      grad_func: a _DefinedFunction to use as the gradient, or None.

    Returns:
      A _DefinedFunction representing `fdef`.
    """
    # TODO(iga): This method does major surgery on _DefinedFunction.
    # Make it a named constructor using @classmethod of _DefinedFunction.
    signature = fdef.signature
    input_args = signature.input_arg

    names_in = [a.name for a in input_args]
    dtypes_in = tuple(dtypes.as_dtype(a.type) for a in input_args)
    names_out = [a.name for a in signature.output_arg]

    # Both the Python callable and the Python gradient function are None.
    # The callable is only needed to build a FunctionDef, which we already
    # have. A FunctionDef cannot carry a Python gradient function, so one
    # attached to the original _DefinedFunction is not reflected here.
    defined = _DefinedFunction(None, names_in, dtypes_in, signature.name,
                               grad_func, None, names_out)

    # pylint: disable=protected-access
    if not ops._USE_C_API:
        defined._definition = fdef
        # Serializing a function turns its captured inputs into regular
        # inputs. Any extra inputs of the original function are therefore
        # already among this function's regular arguments (`_args`).
        defined._extra_inputs = []
        definition = defined._definition
        defined._hash_str = defined._create_hash_str(
            definition.signature.input_arg,
            definition.signature.output_arg,
            definition.node_def)
    else:
        payload = fdef.SerializeToString()
        # The import reports failure through `status`; this context manager
        # presumably converts a non-OK status into a Python exception.
        with errors.raise_exception_on_not_ok_status() as status:
            defined._c_func = c_api.TF_FunctionImportFunctionDef(
                payload, status)
        defined._extra_inputs = []
    # pylint: enable=protected-access
    return defined