Example #1
from tensorflow.python.eager import backprop
from tensorflow.python.eager import execute
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients

def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
    """Computes a Jacobian-vector product for an op.

    Note that this function would be wasteful if executed eagerly. It runs the
    backward gradient function and throws away the result just to record its
    operations on a GradientTape. These unused ops are pruned away when this
    function is traced.

    Args:
      op_name: A string, the type of operation being executed.
      attr_tuple: Attributes of the operation.
      inputs: A flat list of input Tensors to the operation.
      outputs: A flat list of output Tensors from the operation.
      tangents: A flat list of Tensors, same shape as `inputs`.

    Returns:
      A flat list of tangents corresponding to `outputs`.
    """
    if not outputs:
        # tape.gradients([], inputs) doesn't make much sense
        return []
    trainable_inputs = []
    trainable_indices = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
        if gradients_util.IsTrainable(tensor):
            trainable_inputs.append(tensor)
            trainable_indices.append(input_index)
            nontrivial_tangents.append(tangents[input_index])

    with backprop.GradientTape() as transpose_tape:
        with backprop.GradientTape() as backfunc_tape:
            backfunc_tape.watch(trainable_inputs)
            execute.record_gradient(op_name, inputs, attr_tuple, outputs,
                                    "forward_op_replay")

        forwardprop_aids = []
        trainable_outputs = []
        nontrivial_output_indices = []
        for output_index, output in enumerate(outputs):
            if gradients_util.IsTrainable(output):
                forwardprop_aids.append(
                    array_ops.ones_like(output, name="unused_forwardprop_aid"))
                trainable_outputs.append(output)
                nontrivial_output_indices.append(output_index)

        transpose_tape.watch(forwardprop_aids)
        grads = backfunc_tape.gradient(
            trainable_outputs,
            trainable_inputs,
            forwardprop_aids,
            unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
        output_tangents[index] = tangent
    return output_tangents
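
The same forward-over-reverse trick can be reproduced with the public tf.GradientTape API: an inner tape records the backward pass, an outer tape differentiates the resulting gradient with respect to a dummy cotangent, and that second derivative is the Jacobian-vector product. Below is a minimal, hypothetical sketch (the helper name `jvp` and the sample function are not from the source; newer TensorFlow releases also expose `tf.autodiff.ForwardAccumulator` for the same purpose):

import tensorflow as tf

def jvp(f, primals, tangents):
    # Forward-over-reverse: differentiate the VJP with respect to a dummy
    # cotangent ("forwardprop aid"), then contract with the input tangents.
    with tf.GradientTape() as transpose_tape:
        with tf.GradientTape() as back_tape:
            back_tape.watch(primals)
            outputs = f(primals)
        aid = tf.ones_like(outputs)  # dummy cotangent
        transpose_tape.watch(aid)
        grads = back_tape.gradient(outputs, primals, output_gradients=aid)
    # grads is linear in `aid`, so d(grads)/d(aid) contracted with `tangents`
    # is exactly the Jacobian-vector product J @ tangents.
    return transpose_tape.gradient(grads, aid, output_gradients=tangents)

x = tf.constant([1.0, 2.0, 3.0])
v = tf.constant([1.0, 0.0, 0.0])
print(jvp(tf.sin, x, v))  # ~[cos(1.0), 0.0, 0.0]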
Example #2
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_util

def _is_trainable(tensor):
    """Returns whether the given tensor is trainable."""
    if not gradients_util.IsTrainable(tensor):
        return False

    # Special case: untrainable accumulator output. The gradients algorithm
    # doesn't know about tensor lists of untrainable elements. In theory the
    # tensor list gradient functions should return None as appropriate, but
    # because we can't return None from the gradient function we filter out
    # untrainable accumulator output here to avoid computing the gradient at all.
    if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
        assert tensor.dtype == dtypes.variant
        element_type = tensor.op.get_attr("element_dtype")
        return gradients_util.IsTrainable(element_type)

    return True
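
For context, `gradients_util.IsTrainable` accepts either a tensor or a dtype (which is why the element dtype can be passed directly above) and boils down to a dtype check. A rough, non-authoritative sketch of that test; the exact dtype set varies between TensorFlow releases:

from tensorflow.python.framework import dtypes

_TRAINABLE_DTYPES = frozenset([
    dtypes.float16, dtypes.float32, dtypes.float64,
    dtypes.complex64, dtypes.complex128,
    dtypes.resource, dtypes.variant, dtypes.bfloat16,
])

def _is_trainable_dtype(tensor_or_dtype):
    # Accept either a DType or anything carrying a .dtype, like IsTrainable.
    if isinstance(tensor_or_dtype, dtypes.DType):
        dtype = tensor_or_dtype
    else:
        dtype = tensor_or_dtype.dtype
    return dtypes.as_dtype(dtype).base_dtype in _TRAINABLE_DTYPES

print(_is_trainable_dtype(dtypes.float32))  # True
print(_is_trainable_dtype(dtypes.int32))    # False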
Example #3
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_util
# is_mpc_compare_tensor is defined elsewhere in the enclosing MPC codebase.

def MpcIsBackpropagatable(tensor):
    if is_mpc_compare_tensor(tensor):
        return False

    if gradients_util.IsTrainable(tensor):
        return True

    dtype = dtypes.as_dtype(tensor.dtype)
    return dtype.base_dtype == dtypes.bfloat16
Example #4
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_util

def RttIsBackpropagatable(tensor):
    # NOTE(George): We make 'string' legal for backpropagation.
    if tensor.dtype == dtypes.string:
        return True

    if gradients_util.IsTrainable(tensor):
        return True

    dtype = dtypes.as_dtype(tensor.dtype)
    return dtype.base_dtype == dtypes.bfloat16
Example #5
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util

def _grad_fn(func_graph, grads):
    """The gradient function for each conditional branch.

    This function builds the gradient graph of the corresponding forward-pass
    conditional branch in `func_graph`. This is done by differentiating
    func_graph's outputs w.r.t. its inputs.

    Args:
      func_graph: FuncGraph. The corresponding forward-pass function.
      grads: The list of input gradient Tensors.

    Returns:
      The output gradient Tensors.
    """
    # Filter out untrainable function outputs.
    # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
    # cause _GradientsHelper to raise an exception (e.g. the implementation
    # doesn't expect 'ys' to contain boolean tensors).
    assert len(func_graph.outputs) == len(grads)
    ys = []
    grad_ys = []
    for y, grad_y in zip(func_graph.outputs, grads):
        if not gradients_util.IsTrainable(y):
            continue
        ys.append(y)
        grad_ys.append(grad_y)

    # Build the gradient graph. Note that this builds the gradient computation of
    # func_graph in the current graph, which requires capturing tensors from
    # func_graph. The captured func_graph tensors are resolved to external tensors
    # in _resolve_grad_inputs.
    result = gradients_util._GradientsHelper(ys,
                                             func_graph.inputs,
                                             grad_ys=grad_ys,
                                             src_graph=func_graph)

    # Functions can't return None; replace Nones with zero tensors.
    # TODO(b/80444525): don't return anything here and make _IfGrad return None if
    # both branches have zero gradient.
    for i in range(len(result)):
        if result[i] is None:
            if func_graph.inputs[i].dtype == dtypes.resource:
                result[i] = array_ops.zeros(
                    gen_resource_variable_ops.variable_shape(
                        func_graph.inputs[i]),
                    dtype=default_gradient.get_zeros_dtype(
                        func_graph.inputs[i]))
            else:
                result[i] = array_ops.zeros_like(func_graph.inputs[i])

    return result
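
The None-to-zeros substitution at the end mirrors what the public API exposes as tf.UnconnectedGradients.ZERO: when a source does not influence the target, you can request a zero tensor instead of None, which is what a function body must return. A small illustrative sketch, not taken from the source:

import tensorflow as tf

x = tf.constant(3.0)
y = tf.constant(4.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, y])
    z = x * x  # z does not depend on y

print(tape.gradient(z, y))  # None: y is unconnected to z
print(tape.gradient(
    z, y, unconnected_gradients=tf.UnconnectedGradients.ZERO))  # tf.Tensor(0.0, ...)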