def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not gradients_impl.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at
  # all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return gradients_impl.IsTrainable(element_type)

  return True
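# A hedged illustration of the TensorListPopBack special case above. The
# function below is invented for illustration only and is not called by
# library code; it assumes `list_ops` is importable in this module. The
# popped list handle is a variant tensor, which gradients_impl.IsTrainable
# reports as trainable, so _is_trainable must inspect the op's element_dtype
# to see that an int32 accumulator carries no gradient.
def _example_untrainable_accumulator():
  acc = list_ops.empty_tensor_list(element_dtype=dtypes.int32,
                                   element_shape=[])
  acc = list_ops.tensor_list_push_back(acc, 1)
  acc, _ = list_ops.tensor_list_pop_back(acc, element_dtype=dtypes.int32)
  assert gradients_impl.IsTrainable(acc)  # The variant handle looks trainable,
  assert not _is_trainable(acc)  # but its int32 elements are not.
  return acc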
def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not gradients_impl.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation
  # of func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external
  # tensors in _resolve_grad_inputs.
  result = gradients_impl._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph)

  # Functions can't return None; replace Nones with zero tensors.
  # TODO(b/80444525): don't return anything here and make _IfGrad return None
  # if both branches have zero gradient.
  for i in range(len(result)):
    if result[i] is None:
      if func_graph.inputs[i].dtype == dtypes.resource:
        result[i] = array_ops.zeros(
            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]))
      else:
        result[i] = array_ops.zeros_like(func_graph.inputs[i])

  return result
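# Example usage (a hedged sketch, assuming the v2 cond implementation is
# enabled and a standard `import tensorflow as tf`; not part of this module):
# an input used by only one branch still receives a zeros gradient rather
# than None, because of the None-replacement loop above.
#
#   x = tf.constant(2.0)
#   y = tf.constant(3.0)
#   z = tf.cond(tf.constant(True), lambda: x * x, lambda: y * y)
#   dx, dy = tf.gradients(z, [x, y])
#   # `dy` is a zeros tensor (it evaluates to 0.0), not None, even though
#   # the taken branch never uses `y`.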
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to
  # float32 ResourceVariables, which can lead to runtime shape errors when
  # used with a TensorArray. This is a workaround until TensorArrays are
  # reimplemented with TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors
  # in the gradient pass. The non-trainable tensors then receive the popped
  # zeros tensor from this zeros variant. The gradient for the loop vars
  # corresponding to these tensors is None or zeros (this happens only if the
  # loop var is accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient,
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
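# Illustration of the accumulation scheme above (a hedged sketch; the
# function below is invented for illustration and never called by library
# code): intermediates pushed onto a TensorList come back in reverse order
# when popped, which is the order the backward pass consumes them in.
def _example_tensor_list_accumulation():
  acc = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=[])
  acc = list_ops.tensor_list_push_back(acc, 1.0)  # iteration 0
  acc = list_ops.tensor_list_push_back(acc, 2.0)  # iteration 1
  acc, last = list_ops.tensor_list_pop_back(acc,
                                            element_dtype=dtypes.float32)
  return last  # Evaluates to 2.0: the most recent iteration's value.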