def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation
  # of func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external
  # tensors in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  # TODO(b/118712257): Handle the case when grad_outs has Nones, e.g. when
  # there is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
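# --- Illustration (not part of the original file) ---
# A minimal sketch, via the public API, of the behavior _grad_fn implements:
# the gradient of a tf.while_loop is itself built by differentiating the loop
# body, which is what the _GradientsHelper call above does for one body
# invocation. Assumes graph mode via tf.compat.v1; values are hypothetical.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.constant(2.0)
# Square `v` three times, so y = x ** 8.
_, y = tf.while_loop(lambda i, v: i < 3,
                     lambda i, v: (i + 1, v * v),
                     [tf.constant(0), x])
dy_dx = tf.gradients(y, x)[0]  # d(x**8)/dx = 8 * x**7

with tf.Session() as sess:
  print(sess.run(dy_dx))  # 1024.0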
def _grad_fn(func_graph, args):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    func_graph: function.FuncGraph. The corresponding forward-pass function.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2:] - Incoming gradients for `func_graph.outputs`.

  Returns:
    The output gradient Tensors.
  """
  xs = func_graph.inputs
  ys = func_graph.outputs
  grad_ys = args[2:]

  # Build the gradient graph. Note that this builds the gradient computation
  # of func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external
  # tensors in _resolve_grad_inputs.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_impl._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph)

  assert all(g is not None for g in grad_outs)
  counter = args[0]
  total_iters = args[1]
  return [counter + 1, total_iters] + grad_outs
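# --- Illustration (not part of the original file) ---
# A hedged, simplified sketch of how the [counter, total_iters, *grads]
# argument layout above is consumed: the backward pass is itself a while loop
# that runs until the counter reaches total_iters, transforming the incoming
# gradients on each step. `_backward_body` is a hypothetical stand-in for one
# _grad_fn invocation, not the real gradient body.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def _backward_body(counter, total_iters, grad):
  # Mirror _grad_fn's return layout: [counter + 1, total_iters] + grads.
  return counter + 1, total_iters, grad * 2.0

grad_y = tf.constant(1.0)  # incoming gradient for the loop output
_, _, grad_x = tf.while_loop(
    lambda c, n, g: c < n,
    _backward_body,
    [tf.constant(0), tf.constant(3), grad_y])

with tf.Session() as sess:
  print(sess.run(grad_x))  # 8.0 after three doubling steps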
def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not gradients_impl.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation
  # of func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external
  # tensors in _resolve_grad_inputs.
  result = gradients_impl._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None
  # if both branches have zero gradient.
  for i in range(len(result)):
    if result[i] is None:
      if func_graph.inputs[i].dtype == dtypes.resource:
        result[i] = array_ops.zeros(
            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]))
      else:
        result[i] = array_ops.zeros_like(func_graph.inputs[i])

  return result
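# --- Illustration (not part of the original file) ---
# A minimal public-API sketch of why the None-filling above matters: when a
# branch contributes no gradient (e.g. via tf.stop_gradient), the branch's
# gradient function must still return a tensor rather than None, so Nones are
# replaced with zeros. Assumes graph mode via tf.compat.v1; values are
# hypothetical.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.constant(3.0)
# The true branch depends on x; the false branch blocks the gradient.
y = tf.cond(x > 0, lambda: x * x, lambda: tf.stop_gradient(x))
dy_dx = tf.gradients(y, x)[0]

with tf.Session() as sess:
  print(sess.run(dy_dx))  # 6.0: only the taken branch contributes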