コード例 #1
0
ファイル: function.py プロジェクト: zxie/tensorflow
def graph_to_function_def(graph, name, inputs, outputs):
    """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops present in the graph.  The
  graph effectively becomes the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  The returned protocol buffer can be passed to the
  [`Graph.add_function()`](#Graph.add_function) method of a
  different graph to make it available there.

  Args:
    graph: Graph.
    name: string. The name to use for the function.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.

  Returns:
    A FunctionDef protocol buffer.
  """
    func = function_pb2.FunctionDef()
    func.signature.name = name
    func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
    func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
    func_arg_placeholders = set([i.name for i in inputs])
    for op in graph.get_operations():
        tensor_name = op.values()[0].name
        if tensor_name not in func_arg_placeholders:
            _add_op_node(graph, op, func)
    return func
コード例 #2
0
  def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):

    @function.Defun(dtypes.float32)
    def F1(x):
      return math_ops.exp(x) - math_ops.exp(-x)

    f1_def = F1.definition

    library = function_pb2.FunctionDefLibrary()
    library.function.extend([f1_def])

    graph_def1 = graph_pb2.GraphDef()
    graph_def1.library.CopyFrom(library)

    reversed_function = function_pb2.FunctionDef()
    reversed_function.CopyFrom(f1_def)
    # Clear the node_def attribute.
    del reversed_function.node_def[:]
    reversed_function.node_def.extend(reversed(f1_def.node_def))
    reversed_library = function_pb2.FunctionDefLibrary()
    reversed_library.function.extend([reversed_function])
    graph_def2 = graph_pb2.GraphDef()
    graph_def2.library.CopyFrom(reversed_library)

    self.assertTrue(graph_util.graph_defs_equal(graph_def1, graph_def2))
コード例 #3
0
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
    fdef = function_pb2.FunctionDef()
    fdef.CopyFrom(orig_fdef)
    contains_custom_gradients = False

    for node_def in fdef.node_def:
        fix_node_def(node_def, functions, shared_name_suffix)
        if not contains_custom_gradients:
            contains_custom_gradients = _check_op_has_custom_gradients(
                node_def)
    if contains_custom_gradients:
        logging.warning(
            "Importing a function (%s) with ops with custom gradients. Will likely "
            "fail if a gradient is requested.", fdef.signature.name)

    fdef.signature.name = _clean_function_name(fdef.signature.name)
    return fdef
コード例 #4
0
def _fix_fdef(orig_fdef, functions):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  for node_def in fdef.node_def:
    if "_gradient_op_type" in node_def.attr:
      if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]:
        # TODO(andresp): This code assumes that the gradient registered for this
        # function call is the default gradient for the function and not a
        # custom one.
        fname = node_def.attr["f"].func.name
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            functions[fname]._gradient_name)  # pylint: disable=protected-access
      else:
        logging.warning("Importing a function (%s) with ops with custom "
                        "gradients. Will likely fail if a gradient is "
                        "requested.", fdef.signature.name)
    for _, attr_value in node_def.attr.items():
      if attr_value.func.name:
        attr_value.func.name = functions[attr_value.func.name].name

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef
コード例 #5
0
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
    """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops present in the graph.  The
  graph effectively becomes the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if out_names is specified and the wrong length.
  """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"
    used_names = set()
    func.signature.input_arg.extend(
        [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
    if out_names is None:
        used_names = set()
        func.signature.output_arg.extend(
            [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
    elif len(outputs) != len(out_names):
        raise ValueError(
            "Length of out_names (%d) does not match number of outputs (%d): %s"
            % (len(out_names), len(outputs), ", ".join(out_names)))
    elif len(out_names) != len(set(out_names)):
        raise ValueError("Must not have duplicates in out_names: %s" %
                         ", ".join(out_names))
    else:
        func.signature.output_arg.extend(
            [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
    func_arg_placeholders = set([i.name for i in inputs])
    input_dict = _create_input_dict(graph, func_arg_placeholders)

    for op in graph.get_operations():
        if _is_in_placeholders(op, func_arg_placeholders):
            continue
        _add_op_node(op, func, input_dict)

    if out_names is None:
        for index, o in enumerate(outputs):
            k = func.signature.output_arg[index].name
            func.ret[k] = input_dict[o.name]
    else:
        for o, n in zip(outputs, out_names):
            func.ret[n] = input_dict[o.name]

    return func
コード例 #6
0
def _graph_to_function_def(graph, inputs, outputs):
    """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops present in the graph.  The
  graph effectively becomes the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.

  Returns:
    A FunctionDef protocol buffer.
  """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"
    func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
    func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
    func_arg_placeholders = set([i.name for i in inputs])
    for op in graph.get_operations():
        if op.values() and (op.values()[0].name in func_arg_placeholders):
            continue
        _add_op_node(op, func)
    return func
コード例 #7
0
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
    fdef = function_pb2.FunctionDef()
    fdef.CopyFrom(orig_fdef)
    for node_def in fdef.node_def:
        fix_node_def(node_def, functions, shared_name_suffix,
                     fdef.signature.name)

    fdef.signature.name = _clean_function_name(fdef.signature.name)
    return fdef
コード例 #8
0
    def testNoOutputs(self):
        with session_lib.Session() as sess:
            # Build a function with a single Const node, whose output is ignored.
            fdef = function_pb2.FunctionDef()
            fdef.signature.name = "KernelWithNoOutputs"
            node = node_def_pb2.NodeDef()
            node.op = "Const"
            node.name = "ignored"
            node.attr["dtype"].type = dtypes.int32.as_datatype_enum
            tensor = tensor_util.make_tensor_proto([0],
                                                   dtype=dtypes.int32,
                                                   shape=[])
            node.attr["value"].tensor.CopyFrom(tensor)
            fdef.node_def.extend([node])

            # Check that calling the result as a compiled kernel doesn't crash.
            @function.Defun(compiled=True)
            def KernelWithNoOutputs():
                return constant_op.constant(100)

            # Hack to override the definition.  By accessing .definition, we
            # force the _DefinedFunction initialized internally. Then, we
            # replace it's internal FunctionDef proto. We do this hack here
            # because one typically can't construct KernelWithNoOutputs
            # function via Defun decorator directly.
            _ = KernelWithNoOutputs.definition
            foo = KernelWithNoOutputs
            foo._definition = fdef
            call = KernelWithNoOutputs()
            sess.run(call, {})
コード例 #9
0
def make_function_def(name, graph, operations, inputs, outputs):
    """Makes FunctionDef proto and defined function.

  Args:
    name: the function name
    graph: the graph from which to build the function
    operations: the operations in the function body
    inputs: tensors to be used as function arguments
    outputs: tensors to be returned from the function

  Returns:
   fdef: a FunctionDef protocol buffer for the function
   fn: a wrapped TF_Function for the function
  """
    with errors.raise_exception_on_not_ok_status() as status:
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""),
            status)
    # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature,
    # but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
        with errors.raise_exception_on_not_ok_status() as status:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
        proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    fdef = function_pb2.FunctionDef()
    fdef.ParseFromString(compat.as_bytes(proto_data))
    return fdef, fn
コード例 #10
0
    def test_create_nullary(self):
        fndef = text_format.Parse(
            """
            signature {
               name: 'NullaryFunction'
               output_arg { name: 'o' type: DT_INT32 }
             }
             node_def {
               name: 'retval'
               op: 'Const'
               attr {
                 key: 'dtype'
                 value { type: DT_INT32 }
               }
               attr {
                 key: 'value'
                 value {
                   tensor {
                     dtype: DT_INT32
                     tensor_shape {}
                     int_val: 1
                   }
                 }
               }
             }
             ret { key: 'o' value: 'retval:output' }
         """,
            function_pb2.FunctionDef(),
        )

        ctx = runtime_client.GlobalEagerContext()
        rt = runtime_client.Runtime(ctx)
        rt.CreateFunction(fndef)
コード例 #11
0
def function_def_from_tf_function(c_func):
  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
  with c_api_util.tf_buffer() as buf:
    c_api.TF_FunctionToFunctionDef(c_func, buf)
    data = c_api.TF_GetBuffer(buf)
  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(data))
  return fdef
コード例 #12
0
def _fix_fdef(orig_fdef, name_map):
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  fdef.signature.name = _clean_function_name(fdef.signature.name)
  for node_def in fdef.node_def:
    for _, attr_value in node_def.attr.items():
      if attr_value.func.name:
        attr_value.func.name = name_map[attr_value.func.name]
  return fdef
コード例 #13
0
ファイル: function.py プロジェクト: zuozi2810/tensorflow
 def definition(self):
     """Function definition proto."""
     self._create_definition_if_needed()
     if self._c_func:
         with c_api_util.tf_buffer() as buf:
             c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
             fdef = function_pb2.FunctionDef()
             proto_data = c_api.TF_GetBuffer(buf)
             fdef.ParseFromString(compat.as_bytes(proto_data))
         return fdef
     return self._definition
コード例 #14
0
    def to_function_graph_def(self, add_shapes=True):
        # type: (bool) -> function_pb2.FunctionDef
        """
    Args:
      add_shapes: If True, add the special "_output_shapes" attribute with
        output shape information from this Node's output metadata.

    Returns the `function_pb2.FunctionDef` serialization of this function's
    graph in its current form.
    """
        ret = function_pb2.FunctionDef()
        ret.CopyFrom(self._func_graph_def)
        # Leave signature as is, but replace all node_defs
        del ret.node_def[:]
        ret.signature.CopyFrom(self._func_graph_def.signature)

        input_args = [input_arg.name for input_arg in ret.signature.input_arg]

        for op in self.nodes:
            if op.op_type == _INPUT_DUMMY_OP_NAME:
                continue

            node_def = ret.node_def.add()
            op.to_node_def(node_def, add_shapes)
            unique_input_counter = Counter()

            for i in range(len(op.inputs)):
                (input_tensor_name,
                 global_input_index_str) = (op.inputs[i].name.split(":"))

                global_input_index = int(global_input_index_str)
                if input_tensor_name in input_args:
                    # don't add index for function args
                    node_def.input[i] = input_tensor_name
                else:
                    input_op_output_args, input_op_output_has_number_attr = (
                        self._get_op_def_denormalized_outputs(op.inputs[i].op))
                    if (len(input_op_output_args) == 1
                            and input_op_output_args[0].type_list_attr):
                        node_def.input[i] = (input_tensor_name + ":" +
                                             input_op_output_args[0].name +
                                             ":" + str(global_input_index))
                    else:
                        input_name = (
                            input_tensor_name + ":" +
                            input_op_output_args[global_input_index].name)
                        node_def.input[i] = (
                            input_name + ":" +
                            str(unique_input_counter[input_name]))
                        if input_op_output_has_number_attr:
                            # only uniquify input args with var length,
                            # otherwise it should be 0
                            unique_input_counter[input_name] += 1
        return ret
コード例 #15
0
ファイル: function.py プロジェクト: Odegaard11/publicBicycle
 def definition(self):
   """Function definition proto."""
   self._create_definition_if_needed()
   if self._c_func:
     with c_api_util.tf_buffer() as buf:
       with errors.raise_exception_on_not_ok_status() as status:
         c_api.TF_FunctionToFunctionDef(self._c_func, buf, status)
       fdef = function_pb2.FunctionDef()
       proto_data = c_api.TF_GetBuffer(buf)
       fdef.ParseFromString(compat.as_bytes(proto_data))
     return fdef
   return self._definition
コード例 #16
0
 def definition(self):
   """Function definition proto."""
   self._create_definition_if_needed()
   if self._c_func:
     with c_api_util.tf_buffer() as buf:
       c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
       fdef = function_pb2.FunctionDef()
       proto_data = c_api.TF_GetBuffer(buf)
       fdef.ParseFromString(compat.as_bytes(proto_data))
       with ops.init_scope():
         if context.executing_eagerly():
           context.add_function(self._c_func.func)
           self._function_deleter = _DefinedFunctionDeleter(
               fdef.signature.name)
     return fdef
   return self._definition
コード例 #17
0
ファイル: function.py プロジェクト: sgcm520/tensorflow2
    def __init__(self, name, graph, operations, inputs, outputs, attrs):
        """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""))

        for name, attr_value in attrs.items():
            serialized = attr_value.SerializeToString()
            # TODO(iga): this creates and deletes a new TF_Status for every attr.
            # It might be worth creating a convenient way to re-use status.
            pywrap_tensorflow.TF_FunctionSetAttrValueProto(
                fn, compat.as_str(name), serialized)

        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = c_api_util.ScopedTFFunction(fn)
        self._grad_func = None
コード例 #18
0
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
    fdef = function_pb2.FunctionDef()
    fdef.CopyFrom(orig_fdef)
    for node_def in fdef.node_def:
        if "_gradient_op_type" in node_def.attr:
            if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]:
                # TODO(andresp): This code assumes that the gradient registered for this
                # function call is the default gradient for the function and not a
                # custom one.
                fname = node_def.attr["f"].func.name
                node_def.attr["_gradient_op_type"].s = compat.as_bytes(
                    functions[fname]._gradient_name)  # pylint: disable=protected-access
            else:
                logging.warning(
                    "Importing a function (%s) with ops with custom "
                    "gradients. Will likely fail if a gradient is "
                    "requested.", fdef.signature.name)
        for _, attr_value in node_def.attr.items():
            if attr_value.func.name:
                attr_value.func.name = functions[attr_value.func.name].name

        # TODO(b/124205571): Avoid accidental sharing and destruction of restored
        # resources. For now uniquify "shared_name" when loading functions to avoid
        # sharing.
        if "shared_name" in node_def.attr:
            node_def.attr["shared_name"].s += compat.as_bytes(
                shared_name_suffix)

    fdef.signature.name = _clean_function_name(fdef.signature.name)
    return fdef
コード例 #19
0
ファイル: function.py プロジェクト: tejamukka/tensorflow-1
    def __init__(self, name, graph, operations, inputs, outputs):
        """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
        with errors.raise_exception_on_not_ok_status() as status:
            fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
                graph._c_graph,  # pylint: disable=protected-access
                compat.as_str(name),
                False,
                [o._c_op for o in operations],  # pylint: disable=protected-access
                [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
                [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
                [],
                None,
                compat.as_str(""),
                status)
        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            with errors.raise_exception_on_not_ok_status() as status:
                pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.in_eager_mode():
            _register(fn)
        self.definition = function_def
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = fn
        self._grad_func = None
コード例 #20
0
def _fix_fdef(orig_fdef, functions, shared_name_suffix, new_gradient_op_types):
    """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.
    new_gradient_op_types: map from old gradient op type to newly generated
      op type.

  Returns:
    A fixed copy of the original FunctionDef
  """
    fdef = function_pb2.FunctionDef()
    fdef.CopyFrom(orig_fdef)
    contains_unsaved_custom_gradients = False

    for node_def in fdef.node_def:
        fix_node_def(node_def, functions, shared_name_suffix)
        op_type = _get_gradient_op_type(node_def)
        if op_type is not None:
            if op_type in new_gradient_op_types:
                node_def.attr["_gradient_op_type"].s = compat.as_bytes(
                    new_gradient_op_types[op_type])
            else:
                contains_unsaved_custom_gradients = True
    if contains_unsaved_custom_gradients:
        logging.warning(
            "Importing a function (%s) with ops with unsaved custom gradients. Will"
            " likely fail if a gradient is requested.", fdef.signature.name)

    fdef.signature.name = _clean_function_name(fdef.signature.name)
    return fdef
コード例 #21
0
    def test_create_function_called_by_py_runtime(self):
        if not tf2.enabled():
            self.skipTest("TF2 test")

        fndef = text_format.Parse(
            """
            signature {
               name: 'NullaryFunction'
               output_arg { name: 'o' type: DT_INT32 }
             }
             node_def {
               name: 'retval'
               op: 'Const'
               attr {
                 key: 'dtype'
                 value { type: DT_INT32 }
               }
               attr {
                 key: 'value'
                 value {
                   tensor {
                     dtype: DT_INT32
                     tensor_shape {}
                     int_val: 1
                   }
                 }
               }
             }
             ret { key: 'o' value: 'retval:output' }
         """,
            function_pb2.FunctionDef(),
        )

        ctx = runtime_client.GlobalPythonEagerContext()
        rt = runtime_client.Runtime(ctx)
        rt.CreateFunction(fndef)

        ret, = execute.execute("NullaryFunction", 1, [], (), context.context())
        self.assertAllEqual(ret, 1)
コード例 #22
0
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
    """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`.  The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
     the operations in the graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if out_names is specified and the wrong length.
  """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"
    used_names = set()
    func.signature.input_arg.extend(
        [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
    # Initializes the input map with all placeholder input tensors.
    initial_dict = {}
    for o, m in zip(inputs, func.signature.input_arg):
        initial_dict[o.name] = m.name
    if out_names is None:
        used_names = set()
        func.signature.output_arg.extend(
            [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
    elif len(outputs) != len(out_names):
        raise errors_impl.InvalidArgumentError(
            None, None,
            "output names must be either empty or equal in size to outputs. "
            "output names size = %d outputs size = %d" %
            (len(out_names), len(outputs)))
    elif len(out_names) != len(set(out_names)):
        raise ValueError("Must not have duplicates in out_names: %s" %
                         ", ".join(out_names))
    else:
        func.signature.output_arg.extend(
            [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
    func_arg_placeholders = set(i.name for i in inputs)
    input_dict = _create_input_dict(graph,
                                    func_arg_placeholders,
                                    initial_value=initial_dict)

    for op in operations:
        if _is_in_placeholders(op, func_arg_placeholders):
            continue
        _add_op_node(op, func, input_dict)

    if out_names is None:
        for index, o in enumerate(outputs):
            k = func.signature.output_arg[index].name
            func.ret[k] = input_dict[o.name]
    else:
        for o, n in zip(outputs, out_names):
            func.ret[n] = input_dict[o.name]

    return func