Example No. 1
def _is_trainable(tensor):
    """Returns whether the given tensor is trainable."""
    if not gradients_impl.IsTrainable(tensor):
        return False

    # Special case: untrainable accumulator output. The gradients algorithm
    # doesn't know about tensor lists of untrainable elements. In theory the
    # tensor list gradient functions should return None as appropriate, but
    # because we can't return None from the gradient function, we filter out
    # untrainable accumulator output here to avoid computing the gradient at all.
    if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
        assert tensor.dtype == dtypes.variant
        element_type = tensor.op.get_attr("element_dtype")
        return gradients_impl.IsTrainable(element_type)

    return True
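
A short note on the trainability test itself: gradients_impl.IsTrainable is essentially a dtype check, treating only floating-point, complex, resource, and variant tensors as able to carry gradients. The snippet below is a minimal, self-contained sketch of that rule using public TensorFlow dtypes; the names is_trainable_dtype and _TRAINABLE_DTYPES are illustrative, not part of the TensorFlow API, and the exact dtype set can differ between versions.

import tensorflow as tf

# Illustrative stand-in for the internal dtype-based trainability check.
_TRAINABLE_DTYPES = frozenset([
    tf.dtypes.float16, tf.dtypes.bfloat16, tf.dtypes.float32,
    tf.dtypes.float64, tf.dtypes.complex64, tf.dtypes.complex128,
    tf.dtypes.resource, tf.dtypes.variant,
])


def is_trainable_dtype(dtype):
    """Returns True if tensors of this dtype can receive gradients."""
    return dtype.base_dtype in _TRAINABLE_DTYPES


print(is_trainable_dtype(tf.dtypes.float32))  # True
print(is_trainable_dtype(tf.dtypes.bool))     # False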
Example No. 2
def _grad_fn(func_graph, grads):
    """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
    # Filter out untrainable function outputs.
    # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
    # cause _GradientsHelper to raise an exception (e.g. the implementation
    # doesn't expect 'ys' to contain boolean tensors).
    assert len(func_graph.outputs) == len(grads)
    ys = []
    grad_ys = []
    for y, grad_y in zip(func_graph.outputs, grads):
        if not gradients_impl.IsTrainable(y):
            continue
        ys.append(y)
        grad_ys.append(grad_y)

    # Build the gradient graph. Note that this builds the gradient computation of
    # func_graph in the current graph, which requires capturing tensors from
    # func_graph. The captured func_graph tensors are resolved to external tensors
    # in _resolve_grad_inputs.
    result = gradients_impl._GradientsHelper(ys,
                                             func_graph.inputs,
                                             grad_ys=grad_ys,
                                             src_graph=func_graph)

    # Functions can't return None; replace Nones with zero tensors.
    # TODO(b/80444525): don't return anything here and make _IfGrad return None if
    # both branches have zero gradient.
    for i in range(len(result)):
        if result[i] is None:
            if func_graph.inputs[i].dtype == dtypes.resource:
                result[i] = array_ops.zeros(
                    gen_resource_variable_ops.variable_shape(
                        func_graph.inputs[i]))
            else:
                result[i] = array_ops.zeros_like(func_graph.inputs[i])

    return result
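
The None-to-zeros replacement at the end of _grad_fn exists because a graph function cannot return None for an output. Below is a small stand-alone sketch of that pattern using the public API; inputs and partial_grads are made-up placeholders, and the resource branch (which reads the variable's runtime shape instead) is only noted in a comment.

import tensorflow as tf

# One input received a gradient, the other did not.
inputs = [tf.constant([1.0, 2.0]), tf.constant([[3.0, 4.0]])]
partial_grads = [tf.constant([0.5, 0.5]), None]

# Replace missing gradients with zeros of the matching shape and dtype.
# (For resource inputs, _grad_fn builds zeros from the variable's runtime
# shape instead, since zeros_like of a resource handle is not meaningful.)
filled = [
    g if g is not None else tf.zeros_like(x)
    for g, x in zip(partial_grads, inputs)
]
print([t.numpy().tolist() for t in filled])  # [[0.5, 0.5], [[0.0, 0.0]]]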
Example No. 3
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    body_graph = _get_body_graph(op)

    # Set the incoming gradient of TensorArray handles to None. The gradient
    # implementation currently assumes all resource tensors correspond to float32
    # ResourceVariables, which can lead to runtime shape errors when used with a
    # TensorArray. This is a workaround until TensorArrays are reimplemented with
    # TensorLists instead of resources.
    # Also set the incoming gradient of non-trainable inputs to None. It is
    # possible that we receive non-None gradients for non-trainable types in
    # nested while loops because we accumulate outputs of the inner while as
    # variant tensors which are trainable and hence receive zeros_like tensors in
    # the gradient pass. The non-trainable tensors then receive the popped zeros
    # tensor from this zeros variant. The gradient for the loop vars corresponding
    # to these tensors is None or zeros (this happens only if the loop var is
    # accumulated as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if _is_tensor_array_handle(output)
        or not gradients_impl.IsTrainable(output) else grad
        for grad, output in zip(grads, op.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(g is not None or o.dtype == dtypes.resource
               or not gradients_impl.IsTrainable(o)
               for o, g in zip(op.outputs, grads)
               ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the Nones back into the
    # list of outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, body_graph,
        util.unique_grad_fn_name(body_graph.name), op)

    intermediate_tensors = _get_intermediates(body_grad_graph)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None
    for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)

        with body_grad_graph.as_default():
            tensor_list_ph = body_grad_graph.capture(tensor_list,
                                                     whitelisted=True)
            # Push the intermediate tensor to the tensor list.
            appended_tensor_list = list_ops.tensor_list_push_back(
                tensor_list_ph, intermediate_tensor)
            # Add this modified tensor list to the list of outputs.
            body_grad_graph.outputs.append(appended_tensor_list)

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    loop_vars = args + body_grad_graph.external_captures
    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
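
The final loop above re-aligns the computed gradients with the original loop-var positions: the backward While returns the loop counter and iteration count first, followed by one gradient per loop var that had a non-None incoming gradient, and None is re-inserted everywhere else. The following is a pure-Python sketch of that padding step; pad_with_nones, incoming_grads and while_outputs are hypothetical names used only for illustration.

def pad_with_nones(incoming_grads, while_outputs, num_prefix_outputs=2):
    """Re-inserts None at positions whose incoming gradient was None."""
    padded = []
    index = num_prefix_outputs  # skip the loop counter and iteration count
    for g in incoming_grads:
        if g is None:
            padded.append(None)
        else:
            padded.append(while_outputs[index])
            index += 1
    return padded


print(pad_with_nones([None, "g_x", None, "g_y"],
                     ["counter", "num_iters", "dx", "dy"]))
# Prints: [None, 'dx', None, 'dy']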
Example No. 4
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    body_graph = _get_body_graph(op)

    # Set the incoming gradient of TensorArray handle to None.
    # TODO(b/118164915): We need a way of distinguishing between TensorArray
    # resource handles and ResourceVariables and set the default gradient of
    # only the TensorArray handle to None.
    grads = [
        None if output.dtype == dtypes.resource else g
        for g, output in zip(grads, op.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(g is not None or o.dtype == dtypes.resource
               or not gradients_impl.IsTrainable(o)
               for o, g in zip(op.outputs, grads)
               ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the Nones back into the
    # list of outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, body_graph,
        util.unique_grad_fn_name(body_graph.name), op)

    intermediate_tensors = _get_intermediates(body_grad_graph)

    for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=_get_tensor_convertible_shape(
                intermediate_tensor.shape))
        with body_grad_graph.as_default():
            tensor_list_ph = body_grad_graph.capture(tensor_list,
                                                     whitelisted=True)
            # Push the intermediate tensor to the tensor list.
            appended_tensor_list = list_ops.tensor_list_push_back(
                tensor_list_ph, intermediate_tensor)
            # Add this modified tensor list to the list of outputs.
            body_grad_graph.outputs.append(appended_tensor_list)

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    loop_vars = args + body_grad_graph.external_captures
    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
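
For intuition about grad_cond: the gradient While loop runs exactly as many iterations as the forward loop did, so its condition only compares a counter against the recorded iteration count. Here is a self-contained sketch of that counter-based termination using the public tf.while_loop API; grad_body, forward_iterations and the accumulator are illustrative stand-ins, not the internal implementation.

import tensorflow as tf

forward_iterations = tf.constant(3)  # recorded by the forward pass


def grad_cond(counter, max_iters, acc):
    # Same shape as the grad_cond above: stop after max_iters steps.
    return counter < max_iters


def grad_body(counter, max_iters, acc):
    # Stand-in for "pop one saved intermediate and apply its gradient".
    return counter + 1, max_iters, acc + 1.0


counter, _, acc = tf.while_loop(
    grad_cond, grad_body,
    [tf.constant(0), forward_iterations, tf.constant(0.0)])
print(int(counter), float(acc))  # 3 3.0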