예제 #1
0
    def _set_c_attrs(self, attrs):
        """Copies every entry of `attrs` onto the wrapped C function.

        Precondition: self._c_func must already be set (not None); its `.func`
        handle is the target of every update.

        Args:
          attrs: dict whose keys are attribute names and whose values are the
            corresponding attribute protos (each must offer
            `SerializeToString()`).
        """
        # Lazily pair each attribute name with its serialized proto; each
        # pair is produced just before the C call that consumes it.
        serialized_attrs = (
            (attr_name, proto_value.SerializeToString())
            for attr_name, proto_value in attrs.items())
        for attr_name, proto_bytes in serialized_attrs:
            # TODO(skyewm): every call below allocates and frees its own
            # TF_Status; a reusable status object could avoid that churn.
            c_api.TF_FunctionSetAttrValueProto(
                self._c_func.func,
                compat.as_str(attr_name),
                proto_bytes,
            )
예제 #2
0
    def __init__(self, name, graph, operations, inputs, outputs, attrs):
        """Initializes an eager defined function.

        Builds a C function from `operations` of `graph`, sets `attrs` on it,
        and records the resulting FunctionDef and its signature on self. When
        executing eagerly, the function is also passed to `_register`.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function.
          operations: list of Operation; the subset of operations in the graph
            which will be in the function.
          inputs: the tensors in the graph to be used as inputs to the
            function.
          outputs: the tensors in the graph which will be outputs of the
            function.
          attrs: dict mapping names of attributes to their AttrValue values.
        """
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            # NOTE(review): presumably append_hash_to_fn_name, output_names,
            # opts and description of TF_GraphToFunction -- confirm.
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""))

        # Hand `fn` to its owning wrapper immediately, before any step below
        # that might raise. NOTE(review): this assumes ScopedTFFunction frees
        # the wrapped function when it is garbage-collected, so a failure in a
        # later step no longer leaks `fn` -- confirm in c_api_util.
        scoped_fn = c_api_util.ScopedTFFunction(fn)

        # Use a distinct loop variable: iterating with `name` would rebind the
        # function-name parameter to the last attribute key.
        for attr_name, attr_value in attrs.items():
            serialized = attr_value.SerializeToString()
            # TODO(iga): this creates and deletes a new TF_Status for every attr.
            # It might be worth creating a convenient way to re-use status.
            pywrap_tensorflow.TF_FunctionSetAttrValueProto(
                fn, compat.as_str(attr_name), serialized)

        # TODO(apassos): avoid creating a FunctionDef (especially just to grab
        # the signature); in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        # The public name comes from the parsed signature, not the `name` arg.
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = scoped_fn
        self._grad_func = None