def ZerosLikeOutsideLoop(op, index):
  """Create zeros_like for the specified output of an op."""
  val = op.outputs[index]
  if not util.IsSwitch(op):
    if val.dtype == dtypes.resource:
      return array_ops.zeros(
          gen_resource_variable_ops.variable_shape(val),
          dtype=default_gradient.get_zeros_dtype(val))
    return array_ops.zeros_like(val, optimize=False)
  else:
    op_ctxt = op._get_control_flow_context()
    if op_ctxt:
      # We are in a cond context. Use a switch to create zeros only when
      # needed.
      pred = op_ctxt.pred
      branch = op_ctxt.branch
      switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch]
      # An op is created along the branch taken, as control dependencies are
      # on the whole op and not on the tensor output.
      pivot = array_ops.identity(switch_val)
      if val.dtype == dtypes.resource:
        with ops.control_dependencies([pivot]):
          return array_ops.zeros(
              gen_resource_variable_ops.variable_shape(switch_val),
              dtype=default_gradient.get_zeros_dtype(val))
      zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
      # Ensure ops created within array_ops.zeros are dominated by the switch
      # in the cond context.
      with ops.control_dependencies([pivot]):
        return array_ops.zeros(zeros_shape, dtype=val.dtype)
    else:
      return array_ops.zeros_like(val, optimize=False)
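# A minimal usage sketch (not part of the library source) of the behavior the
# switch-guarded zeros above enable: differentiating through a `tf.cond`
# yields zeros for an input that only the untaken branch consumes. The
# function name below is illustrative, not TensorFlow API.
import tensorflow as tf


def _example_cond_dead_branch_grad():
  with tf.Graph().as_default():
    tf1 = tf.compat.v1
    x = tf1.placeholder(tf.float32, shape=[3])
    pred = tf1.placeholder(tf.bool, shape=[])
    # Only the true branch uses `x`; the false branch ignores it.
    y = tf.cond(pred, lambda: 2.0 * x, lambda: tf.ones([3]))
    (dx,) = tf1.gradients(y, x)
    with tf1.Session() as sess:
      # With the false (dead) branch taken, dx is materialized as zeros.
      return sess.run(dx, {x: [1.0, 2.0, 3.0], pred: False})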
def _ZerosLike(t):
  """Like array_ops.zeros_like(), but handles resource variable handles."""
  t_dtype = default_gradient.get_zeros_dtype(t)
  if t.dtype == dtypes.resource:
    return array_ops.zeros(
        resource_variable_ops.variable_shape(t), dtype=t_dtype)
  else:
    return array_ops.zeros_like(t, dtype=t_dtype)
def _zeros_like(op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output),
        dtype=default_gradient.get_zeros_dtype(op_output))
  return array_ops.zeros_like(op_output)
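# Both helpers above share one pattern: a resource handle carries no usable
# shape or element dtype itself, so zeros must be built from the variable's
# runtime shape. A hedged public-API sketch, using `tf.raw_ops.VariableShape`
# in place of the internal `gen_resource_variable_ops.variable_shape` and
# assuming a float32 variable for illustration:
import tensorflow as tf


def _example_zeros_for_handle(t):
  """Zeros matching `t`, which may be a resource variable handle."""
  if t.dtype == tf.dtypes.resource:
    # `tf.zeros_like(handle)` would fail; read the shape through the handle.
    return tf.zeros(tf.raw_ops.VariableShape(input=t), dtype=tf.float32)
  return tf.zeros_like(t)


# e.g. _example_zeros_for_handle(tf.Variable([1., 2., 3.]).handle) -> [0. 0. 0.]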
def _SymGrad(op, out_grads):
  """Backprop through a function call node op given its outputs' gradients."""
  f_in = [x for x in op.inputs] + out_grads
  f_types = [default_gradient.get_zeros_dtype(x) for x in op.inputs]
  f = attr_value_pb2.NameAttrList()
  if _IsPartitionedCall(op):
    f.name = op.get_attr("f").name
  else:
    f.name = op.type
  for k in op.node_def.attr:
    f.attr[k].CopyFrom(op.node_def.attr[k])
  in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
  return in_grads
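# What `_SymGrad` arranges, seen from the public API (a hedged illustration,
# not the internal path itself): the gradient of a function-call op is
# computed as one unit, receiving the call's inputs plus the output
# gradients, rather than by differentiating the call node op-by-op inline.
import tensorflow as tf


@tf.function
def _example_fn(x):
  return x * x + 3.0 * x


def _example_call_node_grad():
  x = tf.constant(2.0)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = _example_fn(x)  # a single call op in the traced graph
  return tape.gradient(y, x)  # 2*x + 3 = 7.0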
def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not backprop_util.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation
  # of func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external
  # tensors in _resolve_grad_inputs.
  result = gradients_util._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None
  # if both branches have zero gradient.
  for i in range(len(result)):
    if result[i] is None:
      if func_graph.inputs[i].dtype == dtypes.resource:
        result[i] = array_ops.zeros(
            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]),
            dtype=default_gradient.get_zeros_dtype(func_graph.inputs[i]))
      else:
        result[i] = array_ops.zeros_like(func_graph.inputs[i])

  return result
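# An illustrative sketch of why the None -> zeros replacement above matters:
# each branch gradient is itself a function, and a function cannot return
# None, so an input the branch never uses still produces a zero gradient.
# The helper name is hypothetical, and the expected values assume control
# flow v2 builds an If op for the cond.
import tensorflow as tf


@tf.function
def _example_branch_grads(pred, a, b):
  with tf.GradientTape() as tape:
    tape.watch([a, b])
    # The true branch uses only `a`; the false branch uses only `b`.
    out = tf.cond(pred, lambda: a * a, lambda: b * b)
  # Expect (2*a, 0) when pred is True: `b` gets zeros, not None.
  return tape.gradient(out, [a, b])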
def _GetGrad(grads, t, unconnected_gradients):
  """Gets gradient for tensor "t"."""
  op = t.op
  op_grads = grads.get(op)
  if not op_grads:
    if unconnected_gradients == UnconnectedGradients.ZERO:
      t_dtype = default_gradient.get_zeros_dtype(t)
      return array_ops.zeros_like(t, dtype=t_dtype)
    elif unconnected_gradients == UnconnectedGradients.NONE:
      return None
    else:
      raise ValueError(
          "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
  t_grad = op_grads[t.value_index]
  assert not isinstance(t_grad, list), (
      "gradients list should have been aggregated by now.")
  return t_grad
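# The `unconnected_gradients` handling above backs the public argument of the
# same name. A minimal sketch with `tf.GradientTape` (one fresh tape per
# query, since a non-persistent tape can only produce gradients once):
import tensorflow as tf


def _example_unconnected():
  x = tf.constant(1.0)
  y = tf.constant(2.0)  # `z` below will not depend on `x`
  with tf.GradientTape() as tape:
    tape.watch(x)
    z = y * y
  assert tape.gradient(z, x) is None  # NONE is the default

  with tf.GradientTape() as tape:
    tape.watch(x)
    z = y * y
  return tape.gradient(
      z, x, unconnected_gradients=tf.UnconnectedGradients.ZERO)  # 0.0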
def _ZerosLikeV2(op, index):
  """Branch of ZerosLike for TF2."""
  val = op.outputs[index]
  if val.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(val),
        dtype=default_gradient.get_zeros_dtype(val))
  if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph)
      and val.dtype != dtypes.variant):
    # In while_v2 we do not want to add a `ZerosLike` op because that will
    # trigger accumulation of `val`. Normally `ZerosLike` is preferred
    # because it helps avoid creating extra nodes (possibly Consts) for the
    # shape. For variants, we must use `ZerosLike`.
    if val.shape.is_fully_defined():
      return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype)
    else:
      # Note: Even though we add `Shape` in the default graph, while_v2 is
      # smart enough to place it in the forward graph, i.e. `val.graph`.
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      return array_ops.zeros(zeros_shape, val.dtype)
  else:
    return array_ops.zeros_like(val, optimize=False)
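# A hedged sketch of the shape-based choice `_ZerosLikeV2` makes: when the
# static shape is complete, the zeros can be a constant with no data
# dependency on `val` (so while_v2 has nothing to accumulate); otherwise the
# shape must be read from `val` at runtime. Public-API equivalents stand in
# for the internal ops here.
import tensorflow as tf


def _example_zeros_choice(val):
  if val.shape.is_fully_defined():
    # No edge back to `val`: a plain constant of the right shape and dtype.
    return tf.constant(0, shape=val.shape.as_list(), dtype=val.dtype)
  # Fallback: a runtime `Shape` read, which does depend on `val`.
  return tf.zeros(tf.shape(val), dtype=val.dtype)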
def ZerosLikeV1WhileLoop(self, op, index):
  """Create zeros_like for the specified output of an op.

  If op is in a while loop that is part of gradients(), this method
  must be called in its grad loop context.

  Args:
    op: A tensorflow operation.
    index: the index for a specific output of the op.

  Returns:
    A zero tensor of the same shape as op.outputs[index].
  """
  if util.IsLoopSwitch(op):
    return None
  if op.graph.building_function:
    # The optimization here is tricky to apply to functions.
    return array_ops.zeros_like(op.outputs[index])
  dead_branch = util.IsSwitch(op)
  forward_ctxt = util.GetWhileContext(op)
  grad_state = self._map.get(forward_ctxt)
  if grad_state is None:
    # op is not in a while loop that is part of gradients().
    return ZerosLike(op, index)
  op_ctxt = op._get_control_flow_context()
  val = ops.convert_to_tensor(op.outputs[index], name="tensor")
  shape = val.get_shape()
  if shape.is_fully_defined():
    # If the shape is known statically, just create a zero tensor with
    # the right shape in the grad loop context.
    if val.dtype == dtypes.resource:
      result = array_ops.zeros(
          resource_variable_ops.variable_shape(val),
          dtype=default_gradient.get_zeros_dtype(val))
    else:
      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
    if dead_branch:
      # op is a cond switch. Guard the zero tensor with a switch.
      pred = grad_state.history_map.get(op_ctxt.pred.name)
      branch = op_ctxt.branch
      result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch]
  else:
    # Unknown shape, so keep a history of the shape at runtime.
    if dead_branch:
      # Need to add a special switch to guard the value.
      pred = op_ctxt.pred
      branch = op_ctxt.branch
      op_ctxt.outer_context.Enter()
      val = control_flow_ops._SwitchRefOrTensor(op.inputs[0],
                                                pred)[1 - branch]
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      op_ctxt.outer_context.Exit()
      val.op._set_control_flow_context(op_ctxt)
      zeros_shape.op._set_control_flow_context(op_ctxt)
    else:
      op_ctxt.Enter()
      zeros_shape = array_ops.shape_internal(val, optimize=False)
      op_ctxt.Exit()

    # Add a forward accumulator for the shape.
    grad_state.grad_context.Exit()
    history_zeros_shape = grad_state.AddForwardAccumulator(
        zeros_shape, dead_branch=dead_branch)
    grad_state.grad_context.Enter()

    # Create a zero tensor with the right shape.
    shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                   zeros_shape, dead_branch)
    result = array_ops.zeros(shape, val.dtype)
  return result
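# An end-to-end sketch (illustrative only) of the situation this method
# serves: differentiating through a while loop, where any zeros needed in the
# gradient loop must be created in that loop's own context, with shapes
# accumulated from the forward loop when they are only known at runtime. The
# public API shows the same contract whichever implementation runs (v2
# control flow is the default in recent TensorFlow).
import tensorflow as tf


def _example_while_grad():
  with tf.Graph().as_default():
    tf1 = tf.compat.v1
    x = tf1.placeholder(tf.float32, shape=[])
    # y = x * 2**3 after three iterations.
    _, y = tf.while_loop(
        cond=lambda i, a: i < 3,
        body=lambda i, a: (i + 1, a * 2.0),
        loop_vars=(tf.constant(0), x))
    (dx,) = tf1.gradients(y, x)
    with tf1.Session() as sess:
      return sess.run(dx, {x: 5.0})  # 8.0 for any x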