def testRefEnter(self):
    """An assign through a constant Enter must be observable at Exit."""
    with self.test_session():
      # Ref-typed variable routed into frame "foo_1" as a loop constant.
      seven_var = tf.Variable(7)

      entered_var = control_flow_ops._Enter(seven_var, "foo_1", is_constant=True)
      entered_nine = control_flow_ops.enter(tf.constant(9), "foo_1")
      assign_op = tf.assign(entered_var, entered_nine)
      # Force the assignment to run before the variable is read back out.
      gated = control_flow_ops.with_dependencies([assign_op], entered_var)
      exited = control_flow_ops.exit(gated)
      tf.initialize_all_variables().run()
      self.assertEqual(9, exited.eval())
  def testRefEnter(self):
    """Assigning to a ref tensor entered as a constant is visible on Exit."""
    with self.test_session():
      initial_var = tf.Variable(7)

      # Enter the variable as a frame constant, and a plain 9 as a frame value.
      const_enter = control_flow_ops._Enter(initial_var, "foo_1", is_constant=True)
      nine_enter = control_flow_ops.enter(tf.constant(9), "foo_1")
      update = tf.assign(const_enter, nine_enter)
      # Read the variable only after the update has executed.
      guarded_read = control_flow_ops.with_dependencies([update], const_enter)
      result = control_flow_ops.exit(guarded_read)
      tf.initialize_all_variables().run()
      self.assertEqual(9, result.eval())
def _ExitGrad(op, grad):
    """Gradients for an exit op are calculated using an Enter op.

    Args:
      op: The forward `Exit` op whose gradient is being computed.
      grad: The incoming gradient — an `ops.Tensor`,
        `indexed_slices.IndexedSlices`, or `sparse_tensor.SparseTensor`.

    Returns:
      `grad` wrapped in an `Enter` op for the gradient loop context, or
      `None` when back-propagation is disabled for this loop.

    Raises:
      TypeError: If a second-order while-loop gradient is requested, or if
        `grad` is of an unsupported type.
    """
    graph = ops.get_default_graph()
    # pylint: disable=protected-access
    op_ctxt = op._get_control_flow_context()
    grad_ctxt = graph._get_control_flow_context()
    # pylint: enable=protected-access
    if not grad_ctxt.back_prop:
        # The flag `back_prop` is set by users to suppress gradient
        # computation for this loop. If the attribute `back_prop` is false,
        # no gradient computation.
        return None

    if op_ctxt.grad_state:
        raise TypeError("Second-order gradient for while loops not supported.")

    if isinstance(grad, ops.Tensor):
        grad_ctxt.AddName(grad.name)
    else:
        if not isinstance(
                grad,
            (indexed_slices.IndexedSlices, sparse_tensor.SparseTensor)):
            # BUG FIX: the adjacent string literals previously concatenated
            # without a separator, yielding "either`indexed_slices...";
            # a trailing space restores a readable message.
            raise TypeError(
                f"Type {type(grad)} not supported, must be either "
                "`indexed_slices.IndexedSlices` or `SparseTensor`.")
        grad_ctxt.AddName(grad.values.name)
        grad_ctxt.AddName(grad.indices.name)
        dense_shape = grad.dense_shape
        if dense_shape is not None:
            grad_ctxt.AddName(dense_shape.name)
    grad_ctxt.Enter()
    # pylint: disable=protected-access
    result = control_flow_ops._Enter(
        grad,
        grad_ctxt.name,
        is_constant=False,
        parallel_iterations=grad_ctxt.parallel_iterations,
        name="b_exit")
    # pylint: enable=protected-access
    grad_ctxt.loop_enters.append(result)
    grad_ctxt.Exit()
    return result
# Example no. 4 (scrape artifact: original marker "Exemplo n.º 4", vote count 0)
def _ExitGrad(op, grad):
  """Gradients for an exit op are calculated using an Enter op.

  The incoming gradient is registered with the backward loop context and
  re-entered into it so the gradient loop treats it as a loop input.
  """
  # pylint: disable=protected-access
  grad_ctxt = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  if not grad_ctxt.back_prop:
    # The user set back_prop=False on this loop to suppress gradient
    # computation, so there is nothing to do here.
    return None

  # pylint: disable=protected-access
  forward_grad_state = op._get_control_flow_context().grad_state
  # pylint: enable=protected-access
  if forward_grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  if isinstance(grad, ops.Tensor):
    grad_ctxt.AddName(grad.name)
  elif isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    # Composite gradients: register each component tensor individually.
    grad_ctxt.AddName(grad.values.name)
    grad_ctxt.AddName(grad.indices.name)
    if grad.dense_shape is not None:
      grad_ctxt.AddName(grad.dense_shape.name)
  else:
    raise TypeError("Type %s not supported" % type(grad))
  grad_ctxt.Enter()
  # pylint: disable=protected-access
  result = control_flow_ops._Enter(
      grad, grad_ctxt.name, is_constant=False,
      parallel_iterations=grad_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  grad_ctxt.loop_enters.append(result)
  grad_ctxt.Exit()
  return result