예제 #1
0
 def _compute_backprop(self):
   """Builds the forward and backward function objects used for backprop.

   Inside this function's graph, one placeholder is created for the incoming
   gradient of every non-None return value, and the gradients with respect to
   the input placeholders are built while a `_CapturingContext` is active.
   Two function defs are then generated and registered: a forward one whose
   outputs are the non-None returns followed by the captured tensors (kept in
   `self._forward_fdef`), and a backward one that is wrapped in a
   `_GraphModeFunction` and stored in `self._backward_function`.
   """
   self._has_backprop = True

   def _name_of(item):
     return item.name

   with self._graph.as_default(), context.graph_mode():
     capture_ctx = _CapturingContext()
     with capture_ctx:
       live_outputs = [out for out in self._returns if out is not None]
       grad_placeholders = []
       for out in live_outputs:
         grad_placeholders.append(graph_placeholder(out.dtype, out.shape))
       self._out_grad_placeholders = grad_placeholders
       in_gradients = gradients_impl.gradients(
           live_outputs,
           self._input_placeholders,
           grad_ys=grad_placeholders)
       # Only inputs that actually received a gradient become outputs of the
       # backward function.
       backward_outputs = [grad for grad in in_gradients if grad is not None]
       shapes = [grad.shape for grad in backward_outputs]

   # Captured tensors are ordered by name.
   captures = sorted(capture_ctx.captured_tensors, key=_name_of)

   forward_function_def = graph_to_function_def.graph_to_function_def(
       self._graph, self._ops, self._input_placeholders,
       live_outputs + captures)
   self._forward_fdef = _DefinedFunction(forward_function_def)
   _register_with_name(_forward_name(self._func_name), forward_function_def)

   all_inputs = self._out_grad_placeholders + captures
   # Backward body: the gradient placeholder ops first, then every op seen by
   # the capturing context, ordered by name.
   backward_ops = [ph.op for ph in self._out_grad_placeholders]
   backward_ops.extend(sorted(capture_ctx.known_ops, key=_name_of))
   backward_function_def = graph_to_function_def.graph_to_function_def(
       self._graph, backward_ops, all_inputs, backward_outputs)
   _register_with_name(_backward_name(self._func_name), backward_function_def)
   self._backward_function = _GraphModeFunction(
       all_inputs, [], backward_function_def, self._graph,
       capture_ctx.known_ops, in_gradients,
       _map_sequence_obj_to_idx(backward_outputs), shapes)
예제 #2
0
 def _compute_backprop(self):
     """Builds the forward and backward function objects used for backprop.

     Inside this function's graph, one placeholder is created for the
     incoming gradient of every non-None return value, and the gradients with
     respect to the input placeholders are built while a `_CapturingContext`
     is active. Two function defs are then generated and registered: a
     forward one whose outputs are the non-None returns followed by the
     captured tensors (kept in `self._forward_fdef`), and a backward one that
     is wrapped in a `_GraphModeFunction` and stored in
     `self._backward_function`.
     """
     self._has_backprop = True

     def _name_of(item):
         return item.name

     with self._graph.as_default(), context.graph_mode():
         capture_ctx = _CapturingContext()
         with capture_ctx:
             live_outputs = [out for out in self._returns if out is not None]
             grad_placeholders = []
             for out in live_outputs:
                 grad_placeholders.append(
                     graph_placeholder(out.dtype, out.shape))
             self._out_grad_placeholders = grad_placeholders
             in_gradients = gradients_impl.gradients(
                 live_outputs,
                 self._input_placeholders,
                 grad_ys=grad_placeholders)
             # Only inputs that actually received a gradient become outputs
             # of the backward function.
             backward_outputs = [
                 grad for grad in in_gradients if grad is not None
             ]
             shapes = [grad.shape for grad in backward_outputs]

     # Captured tensors are ordered by name.
     captures = sorted(capture_ctx.captured_tensors, key=_name_of)

     forward_function_def = graph_to_function_def.graph_to_function_def(
         self._graph, self._ops, self._input_placeholders,
         live_outputs + captures)
     self._forward_fdef = _DefinedFunction(forward_function_def)
     _register_with_name(_forward_name(self._func_name),
                         forward_function_def)

     all_inputs = self._out_grad_placeholders + captures
     # Backward body: the gradient placeholder ops first, then every op seen
     # by the capturing context, ordered by name.
     backward_ops = [ph.op for ph in self._out_grad_placeholders]
     backward_ops.extend(sorted(capture_ctx.known_ops, key=_name_of))
     backward_function_def = graph_to_function_def.graph_to_function_def(
         self._graph, backward_ops, all_inputs, backward_outputs)
     _register_with_name(_backward_name(self._func_name),
                         backward_function_def)
     self._backward_function = _GraphModeFunction(
         all_inputs, [], backward_function_def, self._graph,
         capture_ctx.known_ops, in_gradients,
         _map_sequence_obj_to_idx(backward_outputs), shapes)
예제 #3
0
    def _create_definition_if_needed(self):
        """Builds `self._definition` the first time it is needed.

        The wrapped Python function is called once inside a fresh
        `_FuncGraph`, with one placeholder per declared argument. Its outputs
        are converted to tensors and the graph is turned into a function
        definition. The function name is taken from `self._func_name` when
        set; otherwise it is derived from the Python function's name and a
        hash of the definition.

        Raises:
            ValueError: If the wrapped function returns None for any output.
        """
        if self._definition is not None:
            # Already built.
            return

        func_graph = _FuncGraph()
        with func_graph.as_default():
            # One placeholder per declared (name, type) argument.
            placeholders = [
                array_ops.placeholder(argtype, name=argname)
                for (argname, argtype) in self._args
            ]
            # Call the user function; variable creation is routed through the
            # graph's custom getter.
            with vs.variable_scope("", custom_getter=func_graph.getvar):
                results = self._func(*placeholders)
            # A single return value is normalized to a one-element tuple.
            if not isinstance(results, (list, tuple)):
                results = (results,)
            if any(res is None for res in results):
                raise ValueError("Function can not return None.")
            # Make sure every output is a Tensor.
            results = [ops.convert_to_tensor(res) for res in results]

        self._extra_inputs = func_graph.extra_inputs
        placeholders.extend(func_graph.extra_args)
        self._sub_functions = func_graph._functions  # pylint: disable=protected-access

        # Build the FunctionDef and publish it on the instance right away.
        fdef = graph_to_function_def.graph_to_function_def(
            func_graph,
            func_graph.get_operations(),
            placeholders,
            results,
            out_names=self._out_names)
        self._definition = fdef

        # Extra kwargs are stored as attrs on the function def.
        attr_owner_name = self._func_name or _get_func_name(self._func)
        kwargs_attr = _parse_kwargs_as_attrs(attr_owner_name,
                                             **self._extra_kwargs)
        for attr_key in kwargs_attr:
            fdef.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

        # Hash the definition and its dependencies.
        self._hash_str = self._create_hash_str(
            fdef.signature.input_arg, fdef.signature.output_arg,
            fdef.node_def)

        # Without an explicit name, make one up that is deterministic and
        # almost certainly unique.
        if not self._func_name:
            self._func_name = "_".join(
                (_get_func_name(self._func), self._hash_str))
        fdef.signature.name = self._func_name

        docstring = self._func.__doc__
        if docstring:
            fdef.signature.description = docstring
예제 #4
0
  def _create_definition_if_needed(self):
    """Builds `self._definition` the first time it is needed.

    The wrapped Python function is called once inside a fresh `_FuncGraph`,
    with one placeholder per declared argument. Its outputs are converted to
    tensors and the graph is turned into a function definition. The function
    name is taken from `self._func_name` when set; otherwise it is derived
    from the Python function's name and a hash of the definition.

    Raises:
      ValueError: If the wrapped function returns None for any output.
    """
    if self._definition is not None:
      # Already built.
      return

    func_graph = _FuncGraph()
    with func_graph.as_default():
      # One placeholder per declared (name, type) argument.
      placeholders = [
          array_ops.placeholder(argtype, name=argname)
          for (argname, argtype) in self._args
      ]
      # Call the user function; variable creation is routed through the
      # graph's custom getter.
      with vs.variable_scope("", custom_getter=func_graph.getvar):
        results = self._func(*placeholders)
      # A single return value is normalized to a one-element tuple.
      if not isinstance(results, (list, tuple)):
        results = (results,)
      if any(res is None for res in results):
        raise ValueError("Function can not return None.")
      # Make sure every output is a Tensor.
      results = [ops.convert_to_tensor(res) for res in results]

    self._extra_inputs = func_graph.extra_inputs
    placeholders.extend(func_graph.extra_args)
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Build the FunctionDef and publish it on the instance right away.
    fdef = graph_to_function_def.graph_to_function_def(
        func_graph,
        func_graph.get_operations(),
        placeholders,
        results,
        out_names=self._out_names)
    self._definition = fdef

    # Extra kwargs are stored as attrs on the function def.
    attr_owner_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(attr_owner_name,
                                         **self._extra_kwargs)
    for attr_key in kwargs_attr:
      fdef.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        fdef.signature.input_arg, fdef.signature.output_arg, fdef.node_def)

    # Without an explicit name, make one up that is deterministic and almost
    # certainly unique.
    if not self._func_name:
      self._func_name = "_".join((_get_func_name(self._func), self._hash_str))
    fdef.signature.name = self._func_name

    docstring = self._func.__doc__
    if docstring:
      fdef.signature.description = docstring
예제 #5
0
    def _build_function_def(self):
        """Builds and sanity-checks a FunctionDef with two chained Foo1 ops.

        Returns:
            The FunctionDef produced by `graph_to_function_def` for the graph
            sketched in the body, after asserting its node layout.
        """
        with ops.Graph().as_default() as graph:
            # Inputs:    x    y    z
            #            |\   |   /
            #            | \  |  /
            #            |  foo_1     list_output
            #            |   / \       /       \
            #            | d_1 e_1  a:1        a:0
            #            |  \   |   /           |
            #            |   \  |  /            |
            #            |    foo_2             |
            #            |     / \              |
            # Outputs:   x   d_2 e_2           a:0

            in_x = array_ops.placeholder(dtypes.float32, name="x")
            in_y = array_ops.placeholder(dtypes.int32, name="y")
            in_z = array_ops.placeholder(dtypes.int32, name="z")

            foo_1_d, foo_1_e = test_ops._op_def_lib.apply_op(
                "Foo1", name="foo_1", a=in_x, b=in_y, c=in_z)

            list_out_0, list_out_1 = test_ops.list_output(
                T=[dtypes.int32, dtypes.int32], name="list_output")

            foo_2_d, foo_2_e = test_ops.foo1(
                a=foo_1_d, b=foo_1_e, c=list_out_1, name="foo_2")

        fn_inputs = [in_x, in_y, in_z]
        fn_outputs = [in_x, foo_2_d, foo_2_e, list_out_0]
        fdef = graph_to_function_def.graph_to_function_def(
            graph, graph.get_operations(), fn_inputs, fn_outputs)

        # Verify the layout: two Foo1 nodes with one ListOutput node between
        # them.
        nodes = fdef.node_def
        assert len(nodes) == 3
        assert nodes[0].op == "Foo1"
        assert nodes[0].input == ["x", "y", "z"]
        assert nodes[1].op == "ListOutput"
        assert not nodes[1].input
        assert nodes[2].op == "Foo1"
        assert nodes[2].input == [
            "foo_1:d:0", "foo_1:e:0", "list_output:a:1"
        ]
        return fdef
예제 #6
0
def _defun_internal(name, func, args, kwds):
    """Builds a graph-mode function object equivalent to calling `func`.

    `func` is invoked once inside a fresh graph with inputs produced by
    `_get_defun_inputs(args)`. Tensors recorded by `capture_tensors` during
    that call become additional inputs of the generated function def, which
    is registered under `_inference_name(name)`.

    Args:
      name: Base name used to derive the registered inference function name.
      func: The Python callable to convert.
      args: Arguments from which the function's graph inputs are derived.
      kwds: Keyword arguments forwarded unchanged to `func`.

    Returns:
      A `_GraphModeFunction` wrapping the generated function def.
    """
    with context.graph_mode():
        tmp_graph = ops.Graph()
        # Mirror the collections of the surrounding graph so summaries and
        # other things keep working. The function can access (but not mutate)
        # the outer collections, such as the global step and the summary
        # writer collections.
        outer_graph = ops.get_default_graph()
        for collection in outer_graph.collections:
            outer_items = outer_graph.get_collection(collection)
            tmp_graph.get_collection_ref(collection)[:] = outer_items
        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            captures = {}
            with capture_tensors(captures):
                func_outputs = func(*func_inputs, **kwds)
            # Order the captured pairs by sorted key.
            captured_pairs = [
                captures[capture_id] for capture_id in sorted(captures)
            ]
            if captured_pairs:
                extra_inputs, extra_placeholders = zip(*captured_pairs)
            else:
                extra_inputs, extra_placeholders = [], []
            outputs_list = nest.flatten(func_outputs)
            output_shapes = [
                out.shape for out in outputs_list if out is not None
            ]

    flat_inputs = [
        inp for inp in nest.flatten(func_inputs)
        if isinstance(inp, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)

    func_def_outputs = [
        ag_core.getval(out) for out in outputs_list if out is not None
    ]
    inference_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
    # Register any other functions defined in the graph.
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    for sub_function in tmp_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        _register_with_name(sub_function.name, sub_function.definition)
    _register_with_name(_inference_name(name), inference_function_def)

    return _GraphModeFunction(
        all_inputs, extra_inputs, inference_function_def, tmp_graph,
        tmp_graph.get_operations(), func_outputs,
        _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
예제 #7
0
 def testTwoInputsSameOp(self):
   """Checks that inputs coming from one op still yield separate input args.

   All three function inputs are outputs of the same SVD op; the resulting
   function signature must list three input arguments.
   """
   graph = ops.Graph()
   with graph.as_default():
     matrix = array_ops.placeholder(dtypes.float32)
     svd_s, svd_u, svd_v = linalg_ops.svd(matrix)
     sum_s = math_ops.reduce_sum(svd_s)
     sum_u = math_ops.reduce_sum(svd_u)
     sum_v = math_ops.reduce_sum(svd_v)
     total = sum_s + sum_u + sum_v
   # Leave the first operation (the placeholder) out of the function body.
   body_ops = graph.get_operations()[1:]
   fdef = graph_to_function_def.graph_to_function_def(
       graph, body_ops, [svd_s, svd_u, svd_v], [total])
   self.assertEqual(len(fdef.signature.input_arg), 3)
예제 #8
0
 def testTwoInputsSameOp(self):
     """Checks that inputs coming from one op still yield separate input args.

     All three function inputs are outputs of the same SVD op; the resulting
     function signature must list three input arguments.
     """
     graph = ops.Graph()
     with graph.as_default():
         matrix = array_ops.placeholder(dtypes.float32)
         svd_s, svd_u, svd_v = linalg_ops.svd(matrix)
         sum_s = math_ops.reduce_sum(svd_s)
         sum_u = math_ops.reduce_sum(svd_u)
         sum_v = math_ops.reduce_sum(svd_v)
         total = sum_s + sum_u + sum_v
     # Leave the first operation (the placeholder) out of the function body.
     body_ops = graph.get_operations()[1:]
     fdef = graph_to_function_def.graph_to_function_def(
         graph, body_ops, [svd_s, svd_u, svd_v], [total])
     self.assertEqual(len(fdef.signature.input_arg), 3)
예제 #9
0
def _defun_internal(name, func, args, kwds):
  """Builds a graph-mode function object equivalent to calling `func`.

  `func` is invoked once inside a fresh graph with inputs produced by
  `_get_defun_inputs(args)`. Tensors recorded by `capture_tensors` during that
  call become additional inputs of the generated function def, which is
  registered under `_inference_name(name)`.

  Args:
    name: Base name used to derive the registered inference function name.
    func: The Python callable to convert.
    args: Arguments from which the function's graph inputs are derived.
    kwds: Keyword arguments forwarded unchanged to `func`.

  Returns:
    A `_GraphModeFunction` wrapping the generated function def.
  """
  with context.graph_mode():
    tmp_graph = ops.Graph()
    # Mirror the collections of the surrounding graph so summaries and other
    # things keep working. The function can access (but not mutate) the outer
    # collections, such as the global step and the summary writer collections.
    outer_graph = ops.get_default_graph()
    for collection in outer_graph.collections:
      outer_items = outer_graph.get_collection(collection)
      tmp_graph.get_collection_ref(collection)[:] = outer_items
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      # Order the captured pairs by sorted key.
      captured_pairs = [
          captures[capture_id] for capture_id in sorted(captures)
      ]
      if captured_pairs:
        extra_inputs, extra_placeholders = zip(*captured_pairs)
      else:
        extra_inputs, extra_placeholders = [], []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [out.shape for out in outputs_list if out is not None]

  flat_inputs = [
      inp for inp in nest.flatten(func_inputs) if isinstance(inp, ops.Tensor)
  ]
  all_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [out for out in outputs_list if out is not None]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
  # Register any other functions defined in the graph.
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  for sub_function in tmp_graph._functions.values():  # pylint: disable=protected-access
    # TODO(ashankar): What about the gradient registry?
    _register_with_name(sub_function.name, sub_function.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(
      all_inputs, extra_inputs, inference_function_def, tmp_graph,
      tmp_graph.get_operations(), func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
  def _build_function_def(self):
    """Builds a two-input FunctionDef computing a sum of squares and of cubes.

    Returns:
      A FunctionDef built from placeholders `x` and `y` with outputs
      `sum_squares` and `sum_cubes`, and with its signature name set to
      "_whats_in_a_name".
    """
    with ops.Graph().as_default() as graph:
      # Function inputs.
      in_x = array_ops.placeholder(dtypes.float32, name="x")
      in_y = array_ops.placeholder(dtypes.float32, name="y")

      # Function outputs.
      x_squared = math_ops.pow(in_x, 2)
      y_squared = math_ops.pow(in_y, 2)
      sum_squares = math_ops.add_n([x_squared, y_squared], name="sum_squares")
      x_cubed = math_ops.pow(in_x, 3)
      y_cubed = math_ops.pow(in_y, 3)
      sum_cubes = math_ops.add_n([x_cubed, y_cubed], name="sum_cubes")

    fn_inputs = [in_x, in_y]
    fn_outputs = [sum_squares, sum_cubes]
    fdef = graph_to_function_def.graph_to_function_def(
        graph, graph.get_operations(), fn_inputs, fn_outputs)
    fdef.signature.name = "_whats_in_a_name"
    return fdef
예제 #11
0
    def _build_function_def(self):
        """Builds a two-input FunctionDef computing sums of squares and cubes.

        Returns:
            A FunctionDef built from placeholders `x` and `y` with outputs
            `sum_squares` and `sum_cubes`, and with its signature name set to
            "_whats_in_a_name".
        """
        with ops.Graph().as_default() as graph:
            # Function inputs.
            in_x = array_ops.placeholder(dtypes.float32, name="x")
            in_y = array_ops.placeholder(dtypes.float32, name="y")

            # Function outputs.
            x_squared = math_ops.pow(in_x, 2)
            y_squared = math_ops.pow(in_y, 2)
            sum_squares = math_ops.add_n(
                [x_squared, y_squared], name="sum_squares")
            x_cubed = math_ops.pow(in_x, 3)
            y_cubed = math_ops.pow(in_y, 3)
            sum_cubes = math_ops.add_n([x_cubed, y_cubed], name="sum_cubes")

        fn_inputs = [in_x, in_y]
        fn_outputs = [sum_squares, sum_cubes]
        fdef = graph_to_function_def.graph_to_function_def(
            graph, graph.get_operations(), fn_inputs, fn_outputs)
        fdef.signature.name = "_whats_in_a_name"
        return fdef
  def _build_function_def(self):
    """Builds and sanity-checks a FunctionDef with two chained Foo1 ops.

    Returns:
      The FunctionDef produced by `graph_to_function_def` for the graph
      sketched in the body, after asserting its node layout.
    """
    with ops.Graph().as_default() as graph:
      # Inputs:    x    y    z
      #            |\   |   /
      #            | \  |  /
      #            |  foo_1     list_output
      #            |   / \       /       \
      #            | d_1 e_1  a:1        a:0
      #            |  \   |   /           |
      #            |   \  |  /            |
      #            |    foo_2             |
      #            |     / \              |
      # Outputs:   x   d_2 e_2           a:0

      in_x = array_ops.placeholder(dtypes.float32, name="x")
      in_y = array_ops.placeholder(dtypes.int32, name="y")
      in_z = array_ops.placeholder(dtypes.int32, name="z")

      foo_1_d, foo_1_e = test_ops._op_def_lib.apply_op(
          "Foo1", name="foo_1", a=in_x, b=in_y, c=in_z)

      list_out_0, list_out_1 = test_ops.list_output(
          T=[dtypes.int32, dtypes.int32], name="list_output")

      foo_2_d, foo_2_e = test_ops.foo1(
          a=foo_1_d, b=foo_1_e, c=list_out_1, name="foo_2")

    fn_inputs = [in_x, in_y, in_z]
    fn_outputs = [in_x, foo_2_d, foo_2_e, list_out_0]
    fdef = graph_to_function_def.graph_to_function_def(
        graph, graph.get_operations(), fn_inputs, fn_outputs)

    # Verify the layout: two Foo1 nodes with one ListOutput node between them.
    nodes = fdef.node_def
    assert len(nodes) == 3
    assert nodes[0].op == "Foo1"
    assert nodes[0].input == ["x", "y", "z"]
    assert nodes[1].op == "ListOutput"
    assert not nodes[1].input
    assert nodes[2].op == "Foo1"
    assert nodes[2].input == [
        "foo_1:d:0", "foo_1:e:0", "list_output:a:1"
    ]
    return fdef
예제 #13
0
def make_function_def(graph, operations, inputs, outputs):
  """Makes function def where accesses to resources are serialized.

  Walks `operations` in order and, for every input tensor whose dtype is
  `dtypes.resource`, adds a control dependency on the previous operation that
  consumed the same tensor (matched by tensor name). This mutates the
  operations in `operations` in place before the function def is built.

  Args:
    graph: Graph forwarded to `graph_to_function_def`.
    operations: Operations to serialize and include in the function def; the
      iteration order determines the serialization order.
    inputs: Input tensors forwarded to `graph_to_function_def`.
    outputs: Output tensors forwarded to `graph_to_function_def`.

  Returns:
    The result of `graph_to_function_def.graph_to_function_def`.
  """
  last_op_using_resource_tensor = {}

  # TODO(apassos) probably control flow has to be handled delicately here as in
  # if a resource is accessed inside a control flow context we need the control
  # dependency to point to something outside the context which is guaranteed to
  # happen after the access.
  #
  # TODO(apassos) this should do some form of alias analysis as ops which
  # forward the resources such as Identity and Switch can cause serialization to
  # fail.
  for op in operations:
    for t in op.inputs:
      if t.dtype == dtypes.resource:
        previous_op = last_op_using_resource_tensor.get(t.name)
        # An op may consume the same resource tensor through more than one
        # input. On the second visit the "previous" user is the op itself;
        # adding that edge would make the op a control dependency of itself.
        if previous_op is not None and previous_op is not op:
          op._add_control_input(previous_op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[t.name] = op
  return graph_to_function_def.graph_to_function_def(
      graph, operations, inputs, outputs)
예제 #14
0
def make_function_def(graph, operations, inputs, outputs):
  """Makes function def where accesses to resources are serialized.

  Walks `operations` in order and, for every input tensor whose dtype is
  `dtypes.resource`, adds a control dependency on the previous operation that
  consumed the same tensor (matched by tensor name). This mutates the
  operations in `operations` in place before the function def is built.

  Args:
    graph: Graph forwarded to `graph_to_function_def`.
    operations: Operations to serialize and include in the function def; the
      iteration order determines the serialization order.
    inputs: Input tensors forwarded to `graph_to_function_def`.
    outputs: Output tensors forwarded to `graph_to_function_def`.

  Returns:
    The result of `graph_to_function_def.graph_to_function_def`.
  """
  last_op_using_resource_tensor = {}

  # TODO(apassos) probably control flow has to be handled delicately here as in
  # if a resource is accessed inside a control flow context we need the control
  # dependency to point to something outside the context which is guaranteed to
  # happen after the access.
  #
  # TODO(apassos) this should do some form of alias analysis as ops which
  # forward the resources such as Identity and Switch can cause serialization to
  # fail.
  for op in operations:
    for t in op.inputs:
      if t.dtype == dtypes.resource:
        previous_op = last_op_using_resource_tensor.get(t.name)
        # An op may consume the same resource tensor through more than one
        # input. On the second visit the "previous" user is the op itself;
        # adding that edge would make the op a control dependency of itself.
        if previous_op is not None and previous_op is not op:
          op._add_control_input(previous_op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[t.name] = op
  return graph_to_function_def.graph_to_function_def(
      graph, operations, inputs, outputs)
예제 #15
0
  def _create_definition_if_needed_impl(self):
    """Internal builder; callers should use _create_definition_if_needed.

    Does nothing when `self._definition` or `self._c_func` is already set.
    Otherwise a temporary graph is built from `self._func` via
    `func_graph_from_py_func`, and one of two paths is taken:

    * If the temporary graph has no `_c_graph`, a FunctionDef is built in
      Python with `graph_to_function_def` and stored in `self._definition`;
      `self._hash_str` is computed on this path only.
    * Otherwise the function is created with
      `c_api.TF_GraphToFunction_wrapper` and stored in `self._c_func`.

    On both paths `self._extra_inputs`, `self._sub_functions`, `self._op_def`,
    `self._func_name` (when it was not already set) and `self._stateful_ops`
    are populated.
    """
    # Nothing to do once either form of the definition exists.
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    # Build the temporary graph that holds the body of the function.
    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # Without an explicit name, the base name is the Python function's name,
    # suffixed with the gradient function's name when one is set.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      # Copy the parsed kwargs onto the definition as attrs.
      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # NOTE(review): the hash flag below tests `self._func_name is None`,
      # while the Python path above tests falsiness (`not self._func_name`);
      # an empty-string name would be handled differently - confirm intended.
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [], # control_outputs
          [], # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # This reads the `definition` property, after `_c_func` has been set.
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) for every stateful op in the temporary graph.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access
예제 #16
0
    def _create_definition_if_needed_impl(self):
        """Internal builder; callers should use _create_definition_if_needed.

        Does nothing when `self._definition` or `self._c_func` is already set.
        Otherwise a temporary graph is built from `self._func` via
        `func_graph_from_py_func`, and one of two paths is taken:

        * If the temporary graph has no `_c_graph`, a FunctionDef is built in
          Python with `graph_to_function_def` and stored in
          `self._definition`; `self._hash_str` is computed on this path only.
        * Otherwise the function is created with
          `c_api.TF_GraphToFunction_wrapper` and stored in `self._c_func`.

        On both paths `self._extra_inputs`, `self._sub_functions`,
        `self._op_def`, `self._func_name` (when it was not already set) and
        `self._stateful_ops` are populated.
        """
        # Nothing to do once either form of the definition exists.
        if self._definition is not None or self._c_func is not None:
            return

        # Build the temporary graph that holds the body of the function.
        temp_graph = func_graph_from_py_func(
            self._func,
            self._arg_names,
            self._arg_types,
            self._func_name,
            self._capture_by_value,
            self._caller_device,
            whitelisted_stateful_ops=self._whitelisted_stateful_ops)

        self._extra_inputs = temp_graph.extra_inputs
        # pylint: disable=protected-access
        self._sub_functions = temp_graph._functions
        # pylint: enable=protected-access

        # Extra kwargs are treated as attrs on the function def.
        # Without an explicit name, the base name is the Python function's
        # name, suffixed with the gradient function's name when one is set.
        if self._func_name:
            base_func_name = self._func_name
        else:
            base_func_name = function_utils.get_func_name(self._func)
            if self._grad_func:
                base_func_name += ("_%s" % self._grad_func.name)
        kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                             **self._extra_kwargs)

        if not temp_graph._c_graph:  # pylint: disable=protected-access
            # Build the FunctionDef
            self._definition = graph_to_function_def.graph_to_function_def(
                temp_graph,
                temp_graph.get_operations(),
                temp_graph.inputs,
                temp_graph.outputs,
                out_names=self._out_names)

            # Copy the parsed kwargs onto the definition as attrs.
            for k in kwargs_attr:
                self._definition.attr[k].CopyFrom(kwargs_attr[k])

            # Hash the definition and its dependencies.
            self._hash_str = self._create_hash_str(
                self._definition.signature.input_arg,
                self._definition.signature.output_arg,
                self._definition.node_def)

            # Finally, we decide the function name to use.  If not specified,
            # make up something which is almost certainly unique (but deterministic).
            if not self._func_name:
                self._func_name = "_".join([base_func_name, self._hash_str])
            self._definition.signature.name = self._func_name
            if self._func.__doc__:
                self._definition.signature.description = self._func.__doc__

            self._op_def = self._definition.signature
        else:  # C API is enabled
            output_names = ([compat.as_bytes(x) for x in self._out_names]
                            if self._out_names else [])
            description = self._func.__doc__ or None
            # NOTE(review): the hash flag below tests
            # `self._func_name is None`, while the Python path above tests
            # falsiness (`not self._func_name`); an empty-string name would be
            # handled differently - confirm intended.
            # pylint: disable=protected-access
            c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in temp_graph.inputs],
                [t._as_tf_output() for t in temp_graph.outputs],
                output_names,
                None,  # opts
                description)
            self._c_func = c_api_util.ScopedTFFunction(c_func)
            # pylint: enable=protected-access
            self._set_c_attrs(kwargs_attr)

            # Set cached fields: _op_def and _func_name (if not already set)
            # This reads the `definition` property, after `_c_func` has been
            # set.
            self._op_def = self.definition.signature
            if self._func_name:
                assert self._func_name == self._op_def.name
            else:
                self._func_name = compat.as_str(self._op_def.name)

        # Record (name, type) for every stateful op in the temporary graph.
        self._stateful_ops = [(op.name, op.type)
                              for op in temp_graph.get_operations()
                              if op.op_def.is_stateful]
예제 #17
0
def _graph_callable_internal(func, shape_and_dtypes):
    """Defines and returns a template version of func.

    Under the hood we make two function objects, each wrapping a different
    version of the graph-mode code. One version immediately runs variable
    initialization before making the variable's Tensors available for use,
    while the other version replaces the Variables with placeholders which
    become function arguments and get the current variable's value.

    Some limitations exist because this does not implement a graph-mode
    Variable class which has a convert_to_tensor(as_ref=True) method and an
    initialized_value method. This is fixable.

    Args:
      func: The tfe Python function to compile.
      shape_and_dtypes: A list of type ShapeAndDtype.

    Raises:
      TypeError: If the number of arguments accepted by func differs from the
        number of inputs built from shape_and_dtypes.
      ValueError: If any one of func's outputs is not a Tensor.

    Returns:
      An _InitializingFunctionObject wrapping the captured (placeholder-based)
      function object and the initializer function object.
    """
    with context.graph_mode():
        # This graph will store both the initialization and the call version of the
        # wrapped function. It will later be used by the backprop code to build the
        # backprop graph, if necessary.
        tmp_graph = tf_ops.Graph()
        with tmp_graph.as_default():
            # Placeholders for the non-variable inputs.
            func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
            func_num_args = len(tf_inspect.getargspec(func).args)
            if len(func_inputs) != func_num_args:
                raise TypeError(
                    "The number of arguments accepted by the decorated "
                    "function `%s` (%d) must match the number of "
                    "ShapeAndDtype objects passed to the graph_callable() "
                    "decorator (%d)." %
                    (func.__name__, func_num_args, len(func_inputs)))

            # First call the function to generate a graph which can initialize all
            # variables. As a side-effect this will populate the variable capturing
            # scope's view of which variables exist.
            variable_captures = _VariableCapturingScope()
            captures = {}
            with variable_captures.initializing_scope(
            ), function.capture_tensors(captures):
                func_outputs = func(*func_inputs)
            outputs_list = nest.flatten(func_outputs)
            # Validate before touching `.shape`: a non-Tensor output (e.g. a
            # Python scalar) has no `shape` attribute and would otherwise
            # surface as an AttributeError instead of the documented
            # ValueError.
            if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
                raise ValueError("Found non-tensor output in %s" %
                                 str(outputs_list))
            # Every element is a Tensor at this point, so no None filtering
            # is needed.
            output_shapes = [x.shape for x in outputs_list]
            # Snapshot of the ops created so far; everything added to the
            # graph after this point belongs to the capturing version.
            initializing_operations = tmp_graph.get_operations()

            # Call the function again, now replacing usages of variables with
            # placeholders. This assumes the variable capturing scope created above
            # knows about all variables.
            with variable_captures.capturing_scope(), function.capture_tensors(
                    captures):
                captured_outputs = func(*func_inputs)
            captured_outlist = nest.flatten(captured_outputs)
            capturing_operations = tmp_graph.get_operations(
            )[len(initializing_operations):]

    # Sort by name so the order of the variable placeholders is deterministic.
    sorted_variables = sorted(variable_captures.variables.values(),
                              key=lambda x: x.name)
    variable_placeholders = [x.placeholder for x in sorted_variables]
    # Sort the capture keys so the extra inputs have a deterministic order.
    ids = sorted(captures)
    if ids:
        extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
    else:
        extra_inputs = []
        extra_placeholders = []

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, tf_ops.Tensor)
    ]
    placeholder_inputs = flat_inputs + list(extra_placeholders)
    all_inputs = variable_placeholders + placeholder_inputs

    func_def_outputs = [
        x for x in outputs_list if isinstance(x, tf_ops.Tensor)
    ]
    initializer_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, initializing_operations, placeholder_inputs,
        func_def_outputs)
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    # Also, what about the gradient registry of these functions? Those need to be
    # addressed as well.
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
        function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
    function._register_with_name(
        function._inference_name(func.__name__),  # pylint: disable=protected-access
        initializer_function_def)
    initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
        placeholder_inputs,
        extra_inputs,
        initializer_function_def,
        tmp_graph,
        initializing_operations,
        func_outputs,
        function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
        output_shapes)

    capture_func_def_outputs = [
        x for x in captured_outlist if isinstance(x, tf_ops.Tensor)
    ]
    captured_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, capturing_operations, all_inputs, capture_func_def_outputs)
    function._register_with_name(
        function._inference_name(func.__name__),  # pylint: disable=protected-access
        captured_function_def)
    # NOTE(review): output_shapes was computed from the initializing call and
    # is reused for the capturing call; this assumes both calls produce outputs
    # with the same shapes - confirm.
    captured_function = _FunctionObject(
        sorted_variables,
        all_inputs,
        extra_inputs,
        captured_function_def,
        tmp_graph,
        capturing_operations,
        captured_outputs,
        function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
        output_shapes)

    return _InitializingFunctionObject(captured_function, initializer_function)
예제 #18
0
  def _create_definition_if_needed_impl(self):
    """Builds and caches this function's definition on first use.

    Not meant to be called directly; see _create_definition_if_needed.
    Does nothing when either a FunctionDef (`_definition`) or a C function
    (`_c_func`) has already been created.
    """
    already_built = self._definition is not None or self._c_func is not None
    if already_built:
      return

    func_graph = func_graph_from_py_func(
        self._func, self._arg_names, self._arg_types, self._func_name,
        self._capture_by_value, self._caller_device)

    self._extra_inputs = func_graph.extra_inputs
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Pick the base name: the explicit one if given, otherwise one derived
    # from the Python function (suffixed with the gradient function's name).
    name_stem = self._func_name
    if not name_stem:
      name_stem = _get_func_name(self._func)
      if self._grad_func:
        name_stem += ("_%s" % self._grad_func.name)
    # Extra kwargs are treated as attrs on the function def.
    attr_protos = _parse_kwargs_as_attrs(name_stem, **self._extra_kwargs)

    if func_graph._c_graph:  # pylint: disable=protected-access
      # C API is enabled
      if self._out_names:
        output_names = [compat.as_bytes(name) for name in self._out_names]
      else:
        output_names = []
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      graph_inputs = [t._as_tf_output() for t in func_graph.inputs]
      graph_outputs = [t._as_tf_output() for t in func_graph.outputs]
      c_func = c_api.TF_GraphToFunction_wrapper(
          func_graph._c_graph,
          name_stem,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          graph_inputs,
          graph_outputs,
          output_names,
          None,  # opts
          description)
      # pylint: enable=protected-access
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      self._set_c_attrs(attr_protos)

      # Cache _op_def, and _func_name when it was not supplied up front.
      self._op_def = self.definition.signature
      if not self._func_name:
        self._func_name = compat.as_str(self._op_def.name)
      else:
        assert self._func_name == self._op_def.name
    else:
      # No C graph: build the FunctionDef directly.
      fdef = graph_to_function_def.graph_to_function_def(
          func_graph,
          func_graph.get_operations(),
          func_graph.inputs,
          func_graph.outputs,
          out_names=self._out_names)
      self._definition = fdef

      for attr_name in attr_protos:
        fdef.attr[attr_name].CopyFrom(attr_protos[attr_name])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          fdef.signature.input_arg,
          fdef.signature.output_arg,
          fdef.node_def)

      # Settle on the final name. Without an explicit one, use the base name
      # plus the hash: almost certainly unique, yet deterministic.
      if not self._func_name:
        self._func_name = "_".join([name_stem, self._hash_str])
      fdef.signature.name = self._func_name
      doc = self._func.__doc__
      if doc:
        fdef.signature.description = doc

      self._op_def = fdef.signature

    # Record (name, type) for every op in the graph marked stateful.
    stateful_ops = []
    for op in func_graph.get_operations():
      if op.op_def.is_stateful:
        stateful_ops.append((op.name, op.type))
    self._stateful_ops = stateful_ops
예제 #19
0
  def _create_definition_if_needed_impl(self):
    """Builds and caches this function's definition on first use.

    Not meant to be called directly; see _create_definition_if_needed.
    Does nothing when either a FunctionDef (`_definition`) or a C function
    (`_c_func`) has already been created.
    """
    already_built = self._definition is not None or self._c_func is not None
    if already_built:
      return

    # Share the variable collections with the parent graph by reference, so
    # that name based variable sharing (e.g. via tf.make_template) works
    # between the func graph and the parent graph.
    variable_keys = list(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_collections_ref = ops.get_default_graph()._collections  # pylint: disable=protected-access
    collections_ref = {}
    for key in variable_keys:
      if key in parent_collections_ref:
        collections_ref[key] = parent_collections_ref[key]
      else:
        # Missing in the parent: install one new list in both mappings so
        # they keep pointing at the same object.
        shared_collection = []
        parent_collections_ref[key] = shared_collection
        collections_ref[key] = shared_collection

    func_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        whitelisted_stateful_ops=self._whitelisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = func_graph.extra_inputs
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Pick the base name: the explicit one if given, otherwise one derived
    # from the Python function (suffixed with the gradient function's name).
    name_stem = self._func_name
    if not name_stem:
      name_stem = function_utils.get_func_name(self._func)
      if self._grad_func:
        name_stem += ("_%s" % self._grad_func.name)
    # Extra kwargs are treated as attrs on the function def.
    attr_protos = _parse_kwargs_as_attrs(name_stem, **self._extra_kwargs)

    if func_graph._c_graph:  # pylint: disable=protected-access
      # C API is enabled
      if self._out_names:
        output_names = [compat.as_bytes(name) for name in self._out_names]
      else:
        output_names = []
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      graph_inputs = [t._as_tf_output() for t in func_graph.inputs]
      graph_outputs = [t._as_tf_output() for t in func_graph.outputs]
      c_func = c_api.TF_GraphToFunction_wrapper(
          func_graph._c_graph,
          name_stem,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          graph_inputs,
          graph_outputs,
          output_names,
          [],  # control_outputs
          [],  # control_output_names
          None,  # opts
          description)
      # pylint: enable=protected-access
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      self._set_c_attrs(attr_protos)

      # Cache _op_def, and _func_name when it was not supplied up front.
      self._op_def = self.definition.signature
      if not self._func_name:
        self._func_name = compat.as_str(self._op_def.name)
      else:
        assert self._func_name == self._op_def.name
    else:
      # No C graph: build the FunctionDef directly.
      fdef = graph_to_function_def.graph_to_function_def(
          func_graph,
          func_graph.get_operations(),
          func_graph.inputs,
          func_graph.outputs,
          out_names=self._out_names)
      self._definition = fdef

      for attr_name in attr_protos:
        fdef.attr[attr_name].CopyFrom(attr_protos[attr_name])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          fdef.signature.input_arg,
          fdef.signature.output_arg,
          fdef.node_def)

      # Settle on the final name. Without an explicit one, use the base name
      # plus the hash: almost certainly unique, yet deterministic.
      if not self._func_name:
        self._func_name = "_".join([name_stem, self._hash_str])
      fdef.signature.name = self._func_name
      doc = self._func.__doc__
      if doc:
        fdef.signature.description = doc

      self._op_def = fdef.signature

    # Record (name, type) for every op in the graph marked stateful.
    stateful_ops = []
    for op in func_graph.get_operations():
      if op.op_def.is_stateful:
        stateful_ops.append((op.name, op.type))
    self._stateful_ops = stateful_ops
예제 #20
0
  def _create_definition_if_needed_impl(self):
    """Builds this function's definition unless one already exists.

    This is not what you want, see _create_definition_if_needed.

    Traces `self._func` inside a fresh `_FuncGraph`, feeding it one placeholder
    per `(name, type)` pair in `self._args`, then turns the traced graph into
    either a FunctionDef (`self._definition`) or a C function
    (`self._c_func`), depending on whether the graph has a `_c_graph`. Also
    sets `_extra_inputs`, `_sub_functions`, `_op_def` and, when it was not
    supplied, `_func_name`.

    Raises:
      ValueError: If `self._func` returns a list or tuple containing None.
    """
    # Nothing to do when a definition (in either form) was already created.
    if self._definition is not None or self._c_func is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors.
      # Variable lookups made while tracing are routed through the graph's
      # `getvar` custom getter.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)

      # There is no way of distinguishing between a function not returning
      # anything and a function returning None in Python.
      # We need to allow the former and ideally want to forbid the latter as
      # it is most likely user error.
      # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
      # allow users to explicitly mark the function as not returning anything.
      # For now, we allow a single None return and interpret it as a function
      # with no output.
      if outputs is None:
        outputs = []
      else:
        # If func only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
          outputs = (outputs,)
        # A None among several returned values is rejected outright.
        if any([_ is None for _ in outputs]):
          raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(_) for _ in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    # The graph's extra args are appended after the declared arguments, so
    # they become trailing inputs of the function definition.
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                         **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          inputs,
          outputs,
          out_names=self._out_names)

      # Copy each parsed kwarg attr onto the definition.
      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      # The Python docstring, when present, becomes the signature description.
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      # Output names are passed to the C wrapper as bytes; an empty list is
      # used when no names were given.
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      # An empty or missing docstring is passed as None.
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      # NOTE(review): judging by its name, the context manager presumably
      # raises on exit if `status` is not OK - confirm against `errors`.
      with errors.raise_exception_on_not_ok_status() as status:
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            temp_graph._c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            output_names,
            None,  # opts
            description,
            status)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        # An explicitly supplied name must match the name on the signature.
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)