Example #1
  def testMultipleScopes(self):
    var1 = variables.Variable(0.0)
    var2 = variables.Variable(1.0)
    with tape.VariableWatcher() as variable_watcher1:
      var1.assign_add(1.0)
      with tape.VariableWatcher() as variable_watcher2:
        var2.assign_add(2.0)

    # variable_watcher1 should see both vars and variable_watcher2 only sees
    # var2
    self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
    self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
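A minimal standalone sketch of the same nested-scope behavior, assuming the internal tensorflow.python.eager.tape and tensorflow.python.ops.variables modules are importable exactly as in the test above:

# Hedged sketch: nested VariableWatcher scopes outside a test class.
from tensorflow.python.eager import tape
from tensorflow.python.ops import variables

v1 = variables.Variable(0.0)
v2 = variables.Variable(1.0)
with tape.VariableWatcher() as outer:
    v1.assign_add(1.0)            # recorded only by the outer watcher
    with tape.VariableWatcher() as inner:
        v2.assign_add(2.0)        # recorded by both watchers

print(outer.watched_variables())  # (v1, v2)
print(inner.watched_variables())  # (v2,)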
Example #2
    def testCreateVariables(self):
        with tape.VariableWatcher() as variable_watcher:
            var1 = variables.Variable(0.0)
            var2 = variables.Variable(1.0)
            var1.assign_add(1.0)
            var2.assign_add(2.0)

        self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
Example #3
    def testNonTrainableVariables(self):
        var1 = variables.Variable(0.0)
        var2 = variables.Variable(1.0, trainable=False)
        with tape.VariableWatcher() as variable_watcher:
            var1.assign_add(1.0)
            var2.assign_add(2.0)

        self.assertAllEqual(variable_watcher.watched_variables(), (var1, ))
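For comparison, the public tf.GradientTape API exposes the same trainable-only behavior through its watched_variables() method; a hedged sketch, not taken from the examples above:

# Hedged sketch: the public GradientTape analogue of the test above.
import tensorflow as tf

v1 = tf.Variable(0.0)
v2 = tf.Variable(1.0, trainable=False)
with tf.GradientTape() as g:
    y = v1 * 2.0 + v2  # only trainable variables are watched automatically
print(g.watched_variables())  # (v1,)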
Example #4
def _eager_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for eager mode."""
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args, **kwargs)
    args = nest.flatten(args)
    all_inputs = list(args) + list(kwargs.values())
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables = [
        v.deref()  # pylint: disable=g-complex-comprehension
        for v in set(v.ref() for v in variable_watcher.watched_variables())
        if all(v.deref() is not i for i in all_inputs)
    ]
    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    if (variables and ("variables" not in grad_argspec.args)
            and ("variables" not in grad_argspec.kwonlyargs)
            and not grad_argspec.varkw):
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since function uses variables: {}".format(variables))
    flat_result = nest.flatten(result)
    # TODO(apassos) consider removing the identity below.
    flat_result = [gen_array_ops.identity(x) for x in flat_result]

    input_tensors = [
        ops.convert_to_tensor(x) for x in list(args) + list(variables)
    ]

    recorded_inputs = input_tensors
    arg_count = len(args)

    def actual_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []
        flat_grads = nest.flatten(input_grads)
        if len(flat_grads) != arg_count:
            raise ValueError("custom_gradient function expected to return",
                             arg_count, "gradients but returned",
                             len(flat_grads), "instead.")
        return flat_grads + variable_grads

    tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                              actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(result, flat_result)
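The eager-mode decorator above backs the public tf.custom_gradient API: the wrapped function returns (result, grad_fn), and record_operation wires grad_fn into the tape. A short usage sketch, adapted from the log1pexp example in the tf.custom_gradient documentation:

# Sketch of the public API that _eager_mode_decorator implements:
# the forward function returns its result together with grad_fn.
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
        # Numerically stable gradient of log(1 + exp(x)).
        return upstream * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = log1pexp(x)
print(g.gradient(y, x))  # 1.0, where the naive gradient would be NaN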
Example #5
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()

    default_graph = ops.get_default_graph()

    def convert_arg(x):
        x = ops.convert_to_tensor(x)
        # If graph building, be sure to capture all inputs
        if default_graph.building_function and x.graph != default_graph:
            x = default_graph.capture(x)
        return x

    args = nest.map_structure(convert_arg, args)

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)

    args = nest.flatten(args)
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    after_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")

    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables_in_tape = frozenset(
        [v.ref() for v in variable_watcher.watched_variables()])
    variables_in_subgraph = frozenset([
        v.ref() for v in _get_dependent_variables(input_ops=args,
                                                  output_ops=flat_result)
    ])
    variables = list(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or "variables" in grad_argspec.kwonlyargs
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since function uses variables: {}".format(variables))
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        logging.warn(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")

    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
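When the forward function reads a ResourceVariable, the checks above require grad_fn to accept the 'variables' keyword and to return one gradient per captured variable alongside the input gradients. A hedged sketch of that contract through the public API:

# Sketch of the 'variables' contract enforced by the decorator above:
# grad_fn receives the captured variables and returns (input_grads,
# variable_grads), with one variable gradient per captured variable.
import tensorflow as tf

w = tf.Variable(2.0)

@tf.custom_gradient
def scale(x):
    y = x * w
    def grad(upstream, variables=None):
        grad_x = upstream * w        # gradient for the input x
        grad_vars = [upstream * x]   # one entry per captured variable
        return grad_x, grad_vars
    return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = scale(x)
print(g.gradient(y, [x, w]))  # d y/d x = 2.0, d y/d w = 3.0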
Example #6
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)
  args = nest.flatten(args)
  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # It is possible for the caller to pass in an input that is from a different
  # graph. Even though this is not valid we filter these out if they are not
  # from the output graph to make it easier for some code to migrate to custom
  # gradients.
  inputs = nest.flatten(args)
  outputs = nest.flatten(result)
  graphs = {getattr(o, "graph", None) for o in outputs}
  # Not all results may be tensors. However, we want to ensure that all outputs
  # are from the same graph and use that to filter the inputs.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError("All graph outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_inputs = []
    for i in inputs:
      if i.graph != output_graph:
        logging.warn("%s does not belong to output graph %s", i, output_graph)
      else:
        filtered_inputs.append(i)

    inputs = filtered_inputs

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ]) - frozenset(v.ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=outputs)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  all_tensors = flat_result + inputs + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
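Both graph-mode variants rely on nest.flatten and nest.pack_sequence_as to move between the caller's nested result structure and the flat tensor list fed to identity_n. A small sketch of that round trip using the public tf.nest API:

# Sketch of the flatten/pack round trip used by the decorators above.
import tensorflow as tf

result = {"loss": tf.constant(1.0),
          "metrics": (tf.constant(2.0), tf.constant(3.0))}
flat = tf.nest.flatten(result)                    # [loss, metrics[0], metrics[1]]
rebuilt = tf.nest.pack_sequence_as(result, flat)  # same nested layout restored
print(rebuilt["metrics"][1].numpy())              # 3.0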