def testRefEnter(self):
  """Checks that an assignment made through an `Enter` of a variable sticks.

  A variable initialised to 7 and the constant 9 are both entered into the
  frame "foo_1"; the entered constant is assigned to the entered variable and
  the value read back through `exit` is expected to be 9.
  """
  with self.test_session():
    var = tf.Variable(7)
    # pylint: disable=protected-access
    entered_var = control_flow_ops._Enter(var, "foo_1", is_constant=True)
    # pylint: enable=protected-access
    entered_nine = control_flow_ops.enter(tf.constant(9), "foo_1")
    assign_op = tf.assign(entered_var, entered_nine)
    # Tie the value we read to the assignment so the assignment runs first.
    gated_var = control_flow_ops.with_dependencies([assign_op], entered_var)
    exited_var = control_flow_ops.exit(gated_var)

    tf.initialize_all_variables().run()
    self.assertEqual(9, exited_var.eval())
def _ExitGrad(op, grad):
  """Gradients for an exit op are calculated using an Enter op.

  Args:
    op: The forward exit op whose gradient is being computed. Only its
      control flow context is consulted.
    grad: The incoming gradient: an `ops.Tensor`, an
      `indexed_slices.IndexedSlices` or a `sparse_tensor.SparseTensor`.

  Returns:
    The result of entering `grad` into the gradient control flow context, or
    `None` when that context has `back_prop` disabled.

  Raises:
    TypeError: If the forward op's context already carries a `grad_state`
      (second-order gradients of while loops), or if `grad` is not one of the
      supported types.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if not grad_ctxt.back_prop:
    # The flag `back_prop` is set by users to suppress gradient
    # computation for this loop. If the attribute `back_prop` is false,
    # no gradient computation.
    return None

  if op_ctxt.grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  # Register the name of every tensor that makes up `grad` with the gradient
  # context before entering it.
  if isinstance(grad, ops.Tensor):
    grad_ctxt.AddName(grad.name)
  else:
    if not isinstance(
        grad, (indexed_slices.IndexedSlices, sparse_tensor.SparseTensor)):
      # The two literals are concatenated by the parser, so the first one must
      # end with a space to keep the words of the message apart.
      raise TypeError(
          f"Type {type(grad)} not supported, must be either "
          "`indexed_slices.IndexedSlices` or `SparseTensor`.")
    grad_ctxt.AddName(grad.values.name)
    grad_ctxt.AddName(grad.indices.name)
    dense_shape = grad.dense_shape
    if dense_shape is not None:
      grad_ctxt.AddName(dense_shape.name)

  grad_ctxt.Enter()
  # pylint: disable=protected-access
  result = control_flow_ops._Enter(
      grad,
      grad_ctxt.name,
      is_constant=False,
      parallel_iterations=grad_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  # Record the new Enter on the gradient context's list of loop enters.
  grad_ctxt.loop_enters.append(result)
  grad_ctxt.Exit()
  return result
def _ExitGrad(op, grad):
  """Builds the gradient of an exit op by entering `grad` into the grad loop.

  Args:
    op: The forward exit op; only its control flow context is inspected.
    grad: Incoming gradient, an `ops.Tensor`, `ops.IndexedSlices` or
      `sparse_tensor.SparseTensor`.

  Returns:
    The entered gradient, or `None` if the current gradient context has
    `back_prop` turned off.

  Raises:
    TypeError: If the forward context already has a `grad_state`, or `grad`
      has an unsupported type.
  """
  # pylint: disable=protected-access
  backprop_ctxt = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access

  # Users clear `back_prop` to suppress gradient computation for a loop; in
  # that case there is nothing to build.
  if not backprop_ctxt.back_prop:
    return None

  # pylint: disable=protected-access
  forward_ctxt = op._get_control_flow_context()
  # pylint: enable=protected-access
  if forward_ctxt.grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  # Make every tensor that composes the gradient known to the context by name.
  if isinstance(grad, ops.Tensor):
    backprop_ctxt.AddName(grad.name)
  elif isinstance(grad, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    backprop_ctxt.AddName(grad.values.name)
    backprop_ctxt.AddName(grad.indices.name)
    shape = grad.dense_shape
    if shape is not None:
      backprop_ctxt.AddName(shape.name)
  else:
    raise TypeError("Type %s not supported" % type(grad))

  backprop_ctxt.Enter()
  # pylint: disable=protected-access
  entered_grad = control_flow_ops._Enter(
      grad,
      backprop_ctxt.name,
      is_constant=False,
      parallel_iterations=backprop_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  backprop_ctxt.loop_enters.append(entered_grad)
  backprop_ctxt.Exit()
  return entered_grad