def AddWhileContext(self, op, between_op_list, between_ops):
  """Register a grad state for the while loop containing `op`, if needed.

  `op` is expected to be an Exit op, and the caller must be in the control
  flow context where gradients() is called.

  When the loop has no grad state yet, one is created (linked to the grad
  state of the enclosing while loop, if that one is already registered) and
  every exit of the loop is added to the between-op collections.

  Note that this method mutates both `between_op_list` and `between_ops`.

  Args:
    op: The op whose while context is looked up.
    between_op_list: List that receives newly discovered loop exit ops.
    between_ops: Set used for membership tests; also receives the new ops.
  """
  while_ctxt = util.GetWhileContext(op)
  if self._map.get(while_ctxt) is not None:
    # This loop already has a grad state; nothing to add.
    return

  # First time this loop is seen: find the enclosing while loop (if any) so
  # the new grad state can be chained to its grad state.
  enclosing_ctxt = while_ctxt.outer_context
  if enclosing_ctxt:
    enclosing_ctxt = enclosing_ctxt.GetWhileContext()
  enclosing_state = self._map.get(enclosing_ctxt) if enclosing_ctxt else None

  loop_state = _GradLoopState(while_ctxt, enclosing_state)
  self._map[while_ctxt] = loop_state

  # Backprop needs every exit of the loop, so add the ones not yet tracked.
  for exit_tensor in loop_state.forward_loop_exits:
    exit_op = exit_tensor.op
    if exit_op in between_ops:
      continue
    between_ops.add(exit_op)
    between_op_list.append(exit_op)
def GetGradState(self, op, before):
  """Look up the grad state of the forward while loop associated with `op`.

  Args:
    op: The op to look up.
    before: If true and `op` is a loop exit, the lookup uses the while
      context enclosing the op's own control flow context instead of the
      op's own while context.

  Returns:
    The grad state registered for the selected while context, or None when
    there is no such context or no grad state is registered for it.
  """
  use_enclosing_loop = before and util.IsLoopExit(op)
  if not use_enclosing_loop:
    while_ctxt = util.GetWhileContext(op)
  else:
    # Step out of the exit's own context and find the while loop around it.
    own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
    enclosing = own_ctxt.outer_context
    while_ctxt = enclosing.GetWhileContext() if enclosing else enclosing
  return self._map.get(while_ctxt) if while_ctxt else None
def ZerosLike(self, op, index):
  """Build a zero tensor matching one output of `op`.

  If `op` is in a while loop that is part of gradients(), this method must
  be called in that loop's grad loop context.

  Args:
    op: A tensorflow operation.
    index: The index of the output of `op` to build zeros for.

  Returns:
    A zero tensor with the same shape as op.outputs[index], or None when
    `op` is a loop switch.
  """
  if util.IsLoopSwitch(op):
    return None
  if op.graph._building_function:  # pylint: disable=protected-access
    # The loop-aware optimization below is tricky to apply to functions, so
    # fall back to a plain zeros_like.
    return array_ops.zeros_like(op.outputs[index])

  is_cond_switch = util.IsSwitch(op)
  loop_state = self._map.get(util.GetWhileContext(op))
  if loop_state is None:
    # op is not in a while loop that is part of gradients().
    return ZerosLikeOutsideLoop(op, index)

  ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  output = ops.convert_to_tensor(op.outputs[index], name="tensor")
  static_shape = output.get_shape()

  if static_shape.is_fully_defined():
    # The shape is known statically, so a constant zero tensor created in
    # the grad loop context is enough.
    zeros = constant_op.constant(
        0, shape=static_shape.dims, dtype=output.dtype)
    if not is_cond_switch:
      return zeros
    # op is a cond switch: guard the zero tensor with a switch driven by the
    # predicate entry recorded in the grad state's history map.
    guard_pred = loop_state.history_map.get(ctxt.pred.name)
    guard_branch = ctxt.branch
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(
        zeros, guard_pred)[1 - guard_branch]
    # pylint: enable=protected-access

  # The shape is not known statically, so compute it at runtime and keep a
  # history of it.
  if is_cond_switch:
    # A special switch is needed to guard the value. It is built in the
    # outer context and then re-tagged with the op's own context.
    guard_pred = ctxt.pred
    guard_branch = ctxt.branch
    ctxt.outer_context.Enter()
    # pylint: disable=protected-access
    output = control_flow_ops._SwitchRefOrTensor(
        op.inputs[0], guard_pred)[1 - guard_branch]
    # pylint: enable=protected-access
    runtime_shape = array_ops.shape_internal(output, optimize=False)
    ctxt.outer_context.Exit()
    for built in (output, runtime_shape):
      built.op._set_control_flow_context(ctxt)  # pylint: disable=protected-access
  else:
    ctxt.Enter()
    runtime_shape = array_ops.shape_internal(output, optimize=False)
    ctxt.Exit()

  # Step out of the grad context while adding the forward accumulator for
  # the shape, then step back in.
  loop_state.grad_context.Exit()
  shape_history = loop_state.AddForwardAccumulator(
      runtime_shape, dead_branch=is_cond_switch)
  loop_state.grad_context.Enter()

  # Retrieve the accumulated shape and create the zero tensor from it.
  accumulated_shape = loop_state.AddBackpropAccumulatedValue(
      shape_history, runtime_shape, is_cond_switch)
  return array_ops.zeros(accumulated_shape, output.dtype)