from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.ops import default_gradient


def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Creates zeros for None out grads if at least one branch has non-None grad.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if (any(t is None for t in branch_outs) and
        any(t is not None for t in branch_outs)):
      # At least one branch has a real gradient for this output while another
      # has None: fill each None slot with zeros shaped like the matching
      # forward input so all branch signatures stay in sync.
      for branch_index, t in enumerate(branch_outs):
        if t is None:
          with grad_graphs[branch_index].as_default():
            zeros = default_gradient.zeros_like(
                forward_graphs[branch_index].inputs[output_idx])
            grad_graphs[branch_index].structured_outputs[output_idx] = zeros

  # Rebuild the flat outputs list, dropping outputs that are still None in
  # every branch.
  for grad_graph in grad_graphs:
    grad_graph.outputs = [
        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
        if t is not None
    ]
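# Illustration (not from the original source): a minimal standalone sketch of
# the zero-padding idea above, assuming plain TensorFlow. `tf.zeros_like`
# stands in for `default_gradient.zeros_like`, and the helper name plus the
# input/grad values are hypothetical.
def _sketch_pad_none_grads():
  import tensorflow as tf
  forward_inputs = [tf.constant([1.0, 2.0]), tf.constant(3.0)]
  branch_grads = [tf.constant([0.5, 0.5]), None]  # second grad is unconnected
  # Fill each None slot with zeros shaped like the matching forward input.
  padded = [g if g is not None else tf.zeros_like(x)
            for g, x in zip(branch_grads, forward_inputs)]
  return padded  # every slot is now a real tensor: [[0.5, 0.5], 0.0]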
def gradient_func(unused_op, *result_grads):
  # Replace all `None` arguments, because the traced custom gradient function
  # expects tensors. Replacing with zeros is correct since the `None` values
  # occur when the gradient is unconnected, and thus the gradient is
  # "statically proven to be zero." See `tf.UnconnectedGradients` for details.
  result_grads = [
      x if x is not None else default_gradient.zeros_like(t)
      for (x, t) in zip(result_grads, func.graph.inputs)
  ]
  return func(*result_grads)
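# Illustration (not from the original source): a short sketch of how `None`
# gradients arise in the first place, assuming plain TensorFlow. The helper
# name and the constant values are hypothetical.
def _sketch_unconnected_gradients():
  import tensorflow as tf
  a = tf.constant(2.0)
  b = tf.constant(3.0)
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([a, b])
    y = a * a  # y does not depend on b
  # With the default tf.UnconnectedGradients.NONE, the unconnected gradient
  # comes back as None, which is exactly what `gradient_func` above fills in.
  none_grads = tape.gradient(y, [a, b])  # [4.0, None]
  # Requesting ZERO mirrors the manual zeros_like substitution.
  zero_grads = tape.gradient(
      y, [a, b], unconnected_gradients=tf.UnconnectedGradients.ZERO)
  return none_grads, zero_grads  # ([4.0, None], [4.0, 0.0])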