Ejemplo n.º 1
0
    def testCondNested(self):
        """Reads a variable after a nested `cond` for each branch combination.

        The variable starts at 1.0 and every evaluation mutates it, so the
        expected values build on one another: 1 -> 3 -> 6 -> 7 -> 8.
        """
        with context.graph_mode(), self.test_session():
            var = resource_variable_ops.ResourceVariable(1.0)
            variables.global_variables_initializer().run()
            outer_pred = array_ops.placeholder(dtype=dtypes.bool)
            inner_pred = array_ops.placeholder(dtype=dtypes.bool)
            with function.AutomaticControlDependencies() as acd:

                def true_fn():
                    # Outer predicate is true: increment the variable.
                    var.assign(var + 1, name='true')
                    return 1.0

                def false_fn():
                    # Outer predicate is false: the inner predicate picks
                    # between doubling and tripling the variable.
                    def inner_true_fn():
                        var.assign(var * 2, name='false_true')
                        return 2.0

                    def inner_false_fn():
                        var.assign(var * 3, name='false_false')
                        return 3.0

                    control_flow_ops.cond(
                        inner_pred, inner_true_fn, inner_false_fn)
                    return 1.0

                control_flow_ops.cond(outer_pred, true_fn, false_fn)
                with ops.name_scope('final'):
                    result = var.read_value()
                result = acd.mark_as_return(result)

            # (outer predicate, inner predicate, value read after the run).
            # Order matters: each run starts from the previous run's value.
            cases = (
                (False, False, 3.0),
                (False, True, 6.0),
                (True, True, 7.0),
                (True, False, 8.0),
            )
            for outer, inner, expected in cases:
                feed = {outer_pred: outer, inner_pred: inner}
                self.assertAllEqual(result.eval(feed_dict=feed), expected)
Ejemplo n.º 2
0
 def testBasic(self):
     """Checks that two sequential assignments are visible to a later read.

     Starting from 1.0, adding one and then doubling must yield 4.0.
     """
     with context.graph_mode(), self.test_session():
         var = resource_variable_ops.ResourceVariable(1.0)
         variables.global_variables_initializer().run()
         with function.AutomaticControlDependencies() as acd:
             var.assign(var + 1)
             var.assign(2 * var)
             # Read the variable and mark the read as the block's result.
             result = acd.mark_as_return(var.read_value())
         self.assertAllEqual(result.eval(), 4.0)
Ejemplo n.º 3
0
    def testCondOneBranch(self):
        """Reads a variable after a `cond` where only one branch assigns to it.

        The false branch adds 4 to the initial 1.0; the true branch leaves the
        variable alone, so the second evaluation still observes 5.0.
        """
        with context.graph_mode(), self.test_session():
            var = resource_variable_ops.ResourceVariable(1.0)
            variables.global_variables_initializer().run()
            pred = array_ops.placeholder(dtype=dtypes.bool)
            with function.AutomaticControlDependencies() as acd:

                def true_fn():
                    # Deliberately touches no variable.
                    return 0.0

                def false_fn():
                    var.assign(var + 4)
                    return 1.0

                control_flow_ops.cond(pred, true_fn, false_fn)
                result = acd.mark_as_return(var.read_value())

            # False first (performs the update), then True (no update).
            for pred_value in (False, True):
                observed = result.eval(feed_dict={pred: pred_value})
                self.assertAllEqual(observed, 5.0)
  def testCondMustRunSeparateRead(self):
    """Checks that a `cond` with assignments runs even if its result is unused.

    The returned tensor is an unrelated constant; evaluating it must still
    execute the selected branch, which is verified by reading the variable
    in a separate step afterwards.
    """
    with context.graph_mode(), self.test_session():
      v = resource_variable_ops.ResourceVariable(1.0)
      variables.global_variables_initializer().run()
      p = array_ops.placeholder(dtype=dtypes.bool)
      with function.AutomaticControlDependencies() as c:

        def true_fn():
          v.assign(v + 1)
          return 0.0

        def false_fn():
          v.assign(v + 4)
          return 1.0

        control_flow_ops.cond(p, true_fn, false_fn)
        one = constant_op.constant(1.0)
        # Rebind to the tensor returned by mark_as_return, as every other use
        # in this file does: that returned tensor is the one meant to be
        # evaluated, not the original constant.
        one = c.mark_as_return(one)
      # False branch: 1.0 + 4 == 5.0.
      one.eval(feed_dict={p: False})
      self.assertAllEqual(v.read_value().eval(), 5.0)
      # True branch on top of the previous update: 5.0 + 1 == 6.0.
      one.eval(feed_dict={p: True})
      self.assertAllEqual(v.read_value().eval(), 6.0)
Ejemplo n.º 5
0
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Limitations in (2) and (4) are because this does not implement a graph-mode
  Variable class which has a convert_to_tensor(as_ref=True) method and a
  initialized_value method. This is fixable.
  NOTE(review): the numbered list that "(2) and (4)" refers to is not part of
  this docstring; it presumably lives on the public decorator.

  Args:
    func: The tfe Python function to compile.
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.

  Raises:
    TypeError: If the number of arguments accepted by func does not match the
      number of inputs built from shape_and_dtypes.
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    Callable graph object.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  graph_key = tf_ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the graph key from the original graph to ensure optimizers don't
    # misbehave.
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        func_outputs = func(*func_inputs)
        outputs_list = nest.flatten(func_outputs)
        for i, x in enumerate(outputs_list):
          if x is not None:
            outputs_list[i] = a.mark_as_return(x)
      # A function returning a bare None has no outputs at all.
      if len(outputs_list) == 1 and outputs_list[0] is None:
        outputs_list = []
      # Validate before touching `.shape` so that a non-Tensor output (e.g. a
      # None among several outputs) raises the documented ValueError rather
      # than an AttributeError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      output_shapes = [x.shape for x in outputs_list]
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      tmp_graph.clear_resource_control_flow_state()
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        captured_outputs = func(*func_inputs)
        # Outputs must be marked while the AutomaticControlDependencies
        # context is still active, exactly as in the initializing pass above.
        captured_outlist = nest.flatten(captured_outputs)
        for i, x in enumerate(captured_outlist):
          if x is not None:
            captured_outlist[i] = a.mark_as_return(x)
      # Everything added to the graph after the initializing pass belongs to
      # the capturing (call) version of the function.
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  # Iterate captures in sorted key order so the extra inputs and their
  # placeholders line up deterministically.
  ids = sorted(captures)
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initialization_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register(f._c_func.func)  # pylint: disable=protected-access
  initializer_function = function.GraphModeFunction(
      initialization_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      initializing_operations,
      func_def_outputs,
      func_outputs,
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  captured_function = function.GraphModeFunction(
      captured_function_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      capturing_operations,
      capture_func_def_outputs,
      captured_outputs,
      output_shapes,
      variables=[x.variable for x in sorted_variables])

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)