def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op.

  Args:
    op: The Merge op being differentiated. It may be a ControlFlowOpWrapper,
      in which case its `grad_state` is consulted.
    grad: The incoming gradient that is routed through the switch.
    _: Unused.

  Returns:
    When the first input of the merge lives in a WhileContext or a
    CondContext, whatever `_SwitchRefOrTensor` returns for `grad`. Otherwise a
    list with one entry per input of the merge, each being output 1 of a
    switch on `grad`.
  """
  real_op = GetRealOp(op)
  input_op = real_op.inputs[0].op
  # pylint: disable=protected-access
  ctxt = input_op._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(ctxt, WhileContext):
    grad_ctxt = op.grad_state.grad_context
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(ctxt, CondContext):
    pred = ctxt.pred
    if isinstance(op, ControlFlowOpWrapper):
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = op.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      # Compare against None explicitly: a cache hit is a Tensor, and taking
      # the truth value of a Tensor is not a valid "is it missing?" test.
      if real_pred is None:
        # Remember the value of pred for every iteration.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        # NOTE(review): other code in this file spells this method
        # `AddBackpropAccumulatedValue` -- confirm which name the grad state
        # used here actually exposes.
        real_pred = grad_state.AddBackPropAccumulatedValue(
            history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    # Neither a while nor a cond context: compare the merge's second output
    # against each input index and switch `grad` on every comparison.
    num_inputs = len(real_op.inputs)
    cond = [
        math_ops.equal(real_op.outputs[1], i) for i in xrange(num_inputs)
    ]
    # pylint: disable=protected-access
    return [
        control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
        for i in xrange(num_inputs)
    ]
def testRefSwitch(self):
  """Checks that assigning via output 1 of a switched variable evaluates to 9.

  A variable initialised to 7 is passed through `_SwitchRefOrTensor` with a
  constant True predicate; output 1 of that switch is then assigned 9.
  """
  with self.test_session():
    variable = tf.Variable(7)
    predicate = tf.constant(True)
    # pylint: disable=protected-access
    switch_outputs = control_flow_ops._SwitchRefOrTensor(variable, predicate)
    # pylint: enable=protected-access
    assigned = tf.assign(switch_outputs[1], 9)
    tf.initialize_all_variables().run()
    self.assertEqual(9, assigned.eval())
def testRefSwitch(self):
  """Checks that assigning via output 1 of a switched variable evaluates to 9.

  A variable initialised to 7 is passed through `_SwitchRefOrTensor` with a
  constant True predicate; output 1 of that switch is then assigned 9.
  """
  with self.test_session():
    variable = tf.Variable(7)
    predicate = tf.constant(True)
    # pylint: disable=protected-access
    switch_outputs = control_flow_ops._SwitchRefOrTensor(variable, predicate)
    # pylint: enable=protected-access
    assigned = tf.assign(switch_outputs[1], 9)
    tf.initialize_all_variables().run()
    self.assertEqual(9, assigned.eval())
def _MergeGrad(op, grad, _):
  """Computes the gradient of a Merge op by routing `grad` through a Switch.

  Args:
    op: The Merge op being differentiated.
    grad: The incoming gradient that is routed through the switch.
    _: Unused.

  Returns:
    When the first input of `op` comes from a WhileContext or a CondContext,
    whatever `_SwitchRefOrTensor` returns for `grad`. Otherwise a list with
    one entry per input of `op`, each being output 1 of a switch on `grad`.
  """
  first_input_op = op.inputs[0].op
  default_graph = ops.get_default_graph()
  # pylint: disable=protected-access
  forward_ctxt = control_flow_util.GetOutputContext(first_input_op)
  backprop_ctxt = default_graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(forward_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, backprop_ctxt.pivot)
    # pylint: enable=protected-access

  if isinstance(forward_ctxt, CondContext):
    switch_pred = forward_ctxt.pred
    if backprop_ctxt and backprop_ctxt.grad_state:
      # The Merge belongs to a cond nested in a loop. Backprop needs the
      # predicate of every iteration, so the forward pass accumulates its
      # values and the accumulated value drives the backprop switch.
      loop_state = backprop_ctxt.grad_state
      accumulated_pred = loop_state.history_map.get(switch_pred.name)
      if accumulated_pred is None:
        # Record the predicate for each iteration; the accumulator is added
        # with the grad context exited.
        loop_grad_ctxt = loop_state.grad_context
        loop_grad_ctxt.Exit()
        pred_history = loop_state.AddForwardAccumulator(switch_pred)
        loop_grad_ctxt.Enter()
        # Add the stack pop op. If pred.op is in a (outer) CondContext,
        # the stack pop will be guarded with a switch.
        accumulated_pred = loop_state.AddBackpropAccumulatedValue(
            pred_history, switch_pred)
        loop_state.history_map[switch_pred.name] = accumulated_pred
      switch_pred = accumulated_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(
        grad, switch_pred, name="cond_grad")
    # pylint: enable=protected-access

  # Neither a while nor a cond context: build one equality predicate per
  # input index against the merge's second output, then switch on each.
  input_count = len(op.inputs)
  index_matches = [
      math_ops.equal(op.outputs[1], input_index)
      for input_index in xrange(input_count)
  ]
  # pylint: disable=protected-access
  return [
      control_flow_ops._SwitchRefOrTensor(grad, index_match)[1]
      for index_match in index_matches
  ]
def AddBackpropAccumulatedValue(self, history_value, value, dead_branch=False):
  """Adds the op that reads an accumulated value inside the grad context.

  The read is a stack pop placed in the backprop loop. When `value` is
  controlled by a cond, the pop is guarded by a switch on that cond's
  predicate.

  Args:
    history_value: The history (a stack) of a value.
    value: The value that is pushed onto the stack.
    dead_branch: True iff the tensor is on a dead branch of a cond.

  Returns:
    The current value (the top of the stack).
  """
  history_ctxt = history_value.op._get_control_flow_context()

  # Walk outward from value's context, stopping at history_value's context,
  # to find the first enclosing cond context (if there is one).
  guarding_cond = None
  ctxt = value.op._get_control_flow_context()
  while ctxt and ctxt != history_ctxt:
    if isinstance(ctxt, control_flow_ops.CondContext):
      guarding_cond = ctxt
      break
    ctxt = ctxt.outer_context

  with ops.control_dependencies(None):
    self.grad_context.Enter()
    stack_handle = history_value
    if guarding_cond:
      # Guard the pop with a switch. Prefer an accumulated predicate recorded
      # on this grad state or any outer one; otherwise use the cond's own.
      state = self
      guard_pred = None
      while guard_pred is None and state:
        guard_pred = state.history_map.get(guarding_cond.pred.name)
        state = state.outer_grad_state
      if guard_pred is None:
        guard_pred = guarding_cond.pred
      if dead_branch:
        branch_index = 1 - guarding_cond.branch
      else:
        branch_index = guarding_cond.branch
      switched = control_flow_ops._SwitchRefOrTensor(stack_handle, guard_pred)
      stack_handle = switched[branch_index]
    popped = gen_data_flow_ops.stack_pop_v2(stack_handle,
                                            value.dtype.base_dtype)
    popped.set_shape(value.get_shape())
    self.grad_context.Exit()

  if self.grad_context.parallel_iterations > 1:
    # All pops are ordered after pivot_for_body and before grad_sync.
    self.grad_sync._add_control_input(popped.op)
  return popped
def ZerosLike(self, op, index):
  """Builds a zero tensor matching the requested output of an op.

  If op is in a while loop that is part of gradients(), this method must be
  called in its grad loop context.

  Args:
    op: A tensorflow operation.
    index: The index of the output of `op` to mirror.

  Returns:
    A zero tensor with the same shape as op.outputs[index], or None when
    `op` is a loop switch.
  """
  if util.IsLoopSwitch(op):
    return None
  if op.graph._building_function:  # pylint: disable=protected-access
    # The optimization here is tricky to apply to functions
    return array_ops.zeros_like(op.outputs[index])

  on_dead_branch = util.IsSwitch(op)
  while_ctxt = util.GetWhileContext(op)
  loop_state = self._map.get(while_ctxt)
  if loop_state is None:
    # op is not in a while loop that is part of gradients().
    return ZerosLikeOutsideLoop(op, index)

  op_ctxt = op._get_control_flow_context()
  output = ops.convert_to_tensor(op.outputs[index], name="tensor")
  static_shape = output.get_shape()

  if static_shape.is_fully_defined():
    # The shape is known statically, so a constant zero tensor of that shape
    # is created directly in the grad loop context.
    zeros = constant_op.constant(
        0, shape=static_shape.dims, dtype=output.dtype)
    if on_dead_branch:
      # op is a cond switch. Guard the zero tensor with a switch.
      guard_pred = loop_state.history_map.get(op_ctxt.pred.name)
      branch = op_ctxt.branch
      zeros = control_flow_ops._SwitchRefOrTensor(zeros, guard_pred)[1 - branch]
    return zeros

  # Unknown shape so keep a history of the shape at runtime.
  if on_dead_branch:
    # Need to add a special switch to guard the value. It is built in the
    # outer context and then re-assigned to op's own context.
    pred = op_ctxt.pred
    branch = op_ctxt.branch
    op_ctxt.outer_context.Enter()
    output = control_flow_ops._SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
    runtime_shape = array_ops.shape_internal(output, optimize=False)
    op_ctxt.outer_context.Exit()
    output.op._set_control_flow_context(op_ctxt)
    runtime_shape.op._set_control_flow_context(op_ctxt)
  else:
    op_ctxt.Enter()
    runtime_shape = array_ops.shape_internal(output, optimize=False)
    op_ctxt.Exit()

  # Add forward accumulator for shape; done with the grad context exited.
  loop_state.grad_context.Exit()
  shape_history = loop_state.AddForwardAccumulator(
      runtime_shape, dead_branch=on_dead_branch)
  loop_state.grad_context.Enter()

  # Create a zero tensor with the right shape.
  accumulated_shape = loop_state.AddBackpropAccumulatedValue(
      shape_history, runtime_shape, on_dead_branch)
  return array_ops.zeros(accumulated_shape, output.dtype)