Ejemplo n.º 1
0
def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Fills in zeros for gradients that are None in only some branches.

  For each output position, if at least one branch has a None gradient and at
  least one other branch has a non-None gradient, every None entry at that
  position is replaced by `default_gradient.zeros_like` of the matching
  forward graph's input at the same index. Positions where all branches agree
  (all None, or all non-None) are left alone.

  Both `structured_outputs` and `outputs` of the grad graphs are modified in
  place; nothing is returned.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs, one per forward graph and in the
      same order.
  """
  assert len(forward_graphs) == len(grad_graphs)
  per_branch_grads = [graph.structured_outputs for graph in grad_graphs]
  output_counts = [len(grads) for grads in per_branch_grads]
  # Every branch must expose the same number of gradient outputs.
  assert len(set(output_counts)) == 1, output_counts

  for position, grads_at_position in enumerate(zip(*per_branch_grads)):
    missing_branches = [
        branch for branch, grad in enumerate(grads_at_position) if grad is None
    ]
    # Only a mix of None and non-None gradients needs patching.
    if not missing_branches or len(missing_branches) == len(grads_at_position):
      continue
    for branch in missing_branches:
      branch_grad_graph = grad_graphs[branch]
      # Build the zeros tensor while the branch's own grad graph is the default.
      with branch_grad_graph.as_default():
        reference_input = forward_graphs[branch].inputs[position]
        branch_grad_graph.structured_outputs[position] = (
            default_gradient.zeros_like(reference_input))

  # Recompute the flat outputs from the (possibly patched) structured outputs,
  # dropping the entries that are still None.
  for graph in grad_graphs:
    flat_grads = func_graph_module.flatten(graph.structured_outputs)
    graph.outputs = [grad for grad in flat_grads if grad is not None]
Ejemplo n.º 2
0
  def gradient_func(unused_op, *result_grads):
    """Calls `func` on the incoming gradients, with `None` replaced by zeros.

    Args:
      unused_op: Ignored.
      *result_grads: Incoming gradients; entries may be `None`.

    Returns:
      Whatever `func` returns for the zero-filled gradients.
    """
    # The traced custom gradient function expects tensors, so every `None` is
    # swapped for zeros shaped like the corresponding entry of
    # `func.graph.inputs`. Zeros are the right substitute: a `None` gradient
    # means the gradient is unconnected, i.e. "statically proven to be zero."
    # See `tf.UnconnectedGradients` for details.
    filled_grads = []
    # Gradients are paired with graph inputs by position; `zip` stops at the
    # shorter of the two sequences.
    for grad, graph_input in zip(result_grads, func.graph.inputs):
      if grad is None:
        grad = default_gradient.zeros_like(graph_input)
      filled_grads.append(grad)

    return func(*filled_grads)