Example #1
 def fn_with_recompute(*args):
   cached_vs.append(variable_scope.get_variable_scope())
   # TODO(rsepassi): Rm conditional in TF 1.4
   if hasattr(contrib_framework_ops, "current_arg_scope"):
     cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
   else:
     cached_arg_scope.append({})
   return fn(*args)
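These cached lists are only the forward half of the pattern: the enclosing helper (not shown in this snippet) presumably reads them back so the backward pass can recompute fn under the original scopes. A hedged sketch of such a consumer, reusing the fn, cached_vs and cached_arg_scope names from the snippet above; the function name is hypothetical:

  def recompute_fn(*args):
    # Hypothetical backward-pass counterpart: re-enter the cached arg scope
    # and variable scope (with reuse=True) before calling fn again, so the
    # recomputation reuses the variables created in the forward pass.
    with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
      with variable_scope.variable_scope(cached_vs[0], reuse=True):
        return fn(*args)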
Example #2
 def fn_with_recompute(*args):
     cached_vs.append(variable_scope.get_variable_scope())
     # TODO(rsepassi): Rm conditional in TF 1.4
     if hasattr(contrib_framework_ops, "current_arg_scope"):
         cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
     else:
         cached_arg_scope.append({})
     return fn(*args)
Example #3
  def fn_with_recompute(*args):
    """Wrapper for fn."""
    # Forward pass
    vs = variable_scope.get_variable_scope()
    arg_scope = contrib_framework_ops.current_arg_scope()
    with backprop.GradientTape() as tape:
      fn_kwargs = {}
      if has_is_recompute_kwarg:
        fn_kwargs["is_recomputing"] = False
      outputs = fn(*args, **fn_kwargs)
    original_vars = set(tape.watched_variables())

    # Backward pass
    def grad_fn(*output_grads, **kwargs):
      """Recompute outputs for gradient computation."""
      variables = []
      if original_vars:
        variables = kwargs["variables"]
      if set(variables) != original_vars:
        raise ValueError(_WRONG_VARS_ERR)
      del kwargs
      inputs = list(args)
      # Recompute outputs
      with framework_ops.control_dependencies(output_grads):
        if use_data_dep_:
          inputs = _force_data_dependency(output_grads, inputs)
        with contrib_framework_ops.arg_scope(arg_scope):
          with variable_scope.variable_scope(vs, reuse=True):
            with backprop.GradientTape() as tape:
              fn_kwargs = {}
              if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = True
              outputs = fn(*inputs, **fn_kwargs)
            recompute_vars = set(tape.watched_variables())
            if original_vars != recompute_vars:
              raise ValueError(_WRONG_VARS_ERR)

      if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
      outputs = list(outputs)
      grads = gradients_impl.gradients(outputs, inputs + variables,
                                       output_grads)

      if tupleize_grads:
        if use_data_dep_:
          grads = _tuple_with_data_dep(grads)
        else:
          grads = control_flow_ops.tuple(grads)

      grad_inputs = grads[:len(inputs)]
      grad_vars = grads[len(inputs):]
      return grad_inputs, grad_vars

    return outputs, grad_fn
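In the enclosing helper this wrapper is decorated with tf.custom_gradient, and the feature reaches users as tf.contrib.layers.recompute_grad in TF 1.x. A hedged usage sketch, assuming a simple dense layer as the recomputed fn:

  import tensorflow as tf

  def dense_block(x):
    # Toy layer; the variables it creates are the ones the wrapper recomputes.
    return tf.layers.dense(x, 10, activation=tf.nn.relu)

  block = tf.contrib.layers.recompute_grad(dense_block)
  x = tf.placeholder(tf.float32, [None, 10])
  y = block(x)  # forward pass; activations are recomputed during backprop
  loss = tf.reduce_sum(y)
  grads = tf.gradients(loss, tf.trainable_variables())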
Example #4
    def fn_with_recompute(*args):
        """Wrapper for fn."""
        # Forward pass
        vs = variable_scope.get_variable_scope()
        arg_scope = contrib_framework_ops.current_arg_scope()
        with backprop.GradientTape() as tape:
            fn_kwargs = {}
            if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = False
            outputs = fn(*args, **fn_kwargs)
        original_vars = set(tape.watched_variables())

        # Backward pass
        def grad_fn(*output_grads, **kwargs):
            """Recompute outputs for gradient computation."""
            variables = []
            if original_vars:
                variables = kwargs["variables"]
            if set(variables) != original_vars:
                raise ValueError(_WRONG_VARS_ERR)
            del kwargs
            inputs = list(args)
            # Recompute outputs
            with framework_ops.control_dependencies(output_grads):
                if use_data_dep_:
                    inputs = _force_data_dependency(output_grads, inputs)
                with contrib_framework_ops.arg_scope(arg_scope):
                    with variable_scope.variable_scope(vs, reuse=True):
                        with backprop.GradientTape() as tape:
                            fn_kwargs = {}
                            if has_is_recompute_kwarg:
                                fn_kwargs["is_recomputing"] = True
                            outputs = fn(*inputs, **fn_kwargs)
                        recompute_vars = set(tape.watched_variables())
                        if original_vars != recompute_vars:
                            raise ValueError(_WRONG_VARS_ERR)

            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            outputs = list(outputs)
            grads = gradients_impl.gradients(outputs, inputs + variables,
                                             output_grads)

            if tupleize_grads:
                if use_data_dep_:
                    grads = _tuple_with_data_dep(grads)
                else:
                    grads = control_flow_ops.tuple(grads)

            grad_inputs = grads[:len(inputs)]
            grad_vars = grads[len(inputs):]
            return grad_inputs, grad_vars

        return outputs, grad_fn
Example #5
    def fn_with_recompute(*args):
        """Wrapper for fn."""
        # Capture the variable and arg scopes so we can re-enter them when
        # recomputing.
        vs = variable_scope.get_variable_scope()
        arg_scope = contrib_framework_ops.current_arg_scope()
        # Track all variables touched in the function.
        with backprop.GradientTape() as tape:
            fn_kwargs = {}
            if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = False
            outputs = fn(*args, **fn_kwargs)
        original_vars = set(tape.watched_variables())

        def _grad_fn(output_grads, variables=None):
            # Validate that custom_gradient passes the right variables into grad_fn.
            if original_vars:
                assert variables, (
                    "Fn created variables but the variables were not "
                    "passed to the gradient fn.")
                if set(variables) != original_vars:
                    raise ValueError(_WRONG_VARS_ERR)

            return _recomputing_grad_fn(
                compute_fn=fn,
                original_args=args,
                original_vars=original_vars,
                output_grads=output_grads,
                grad_fn_variables=variables,
                use_data_dep=use_data_dep_,
                tupleize_grads=tupleize_grads,
                arg_scope=arg_scope,
                var_scope=vs,
                has_is_recompute_kwarg=has_is_recompute_kwarg)

        # custom_gradient inspects the signature of the function to determine
        # whether the user expects variables passed in the grad_fn. If the function
        # created variables, the grad_fn should accept the "variables" kwarg.
        if original_vars:

            def grad_fn(*output_grads, **kwargs):
                return _grad_fn(output_grads, kwargs["variables"])
        else:

            def grad_fn(*output_grads):
                return _grad_fn(output_grads)

        return outputs, grad_fn
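The two grad_fn variants exist because, as the comment above notes, custom_gradient inspects the gradient function's signature to decide whether to pass the "variables" keyword argument. A simplified stand-in for that check, using a hypothetical expects_variables helper (the real inspection lives inside custom_gradient):

  import inspect

  def expects_variables(grad_fn):
    # A grad_fn is handed the created variables only if its signature can
    # accept "variables", either as a named argument or through **kwargs.
    spec = inspect.getfullargspec(grad_fn)
    return "variables" in spec.args or spec.varkw is not None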
Example #6
  def fn_with_recompute(*args):
    """Wrapper for fn."""
    # Capture the variable and arg scopes so we can re-enter them when
    # recomputing.
    vs = variable_scope.get_variable_scope()
    arg_scope = contrib_framework_ops.current_arg_scope()
    # Track all variables touched in the function.
    with backprop.GradientTape() as tape:
      fn_kwargs = {}
      if has_is_recompute_kwarg:
        fn_kwargs["is_recomputing"] = False
      outputs = fn(*args, **fn_kwargs)
    original_vars = set(tape.watched_variables())

    def _grad_fn(output_grads, variables=None):
      # Validate that custom_gradient passes the right variables into grad_fn.
      if original_vars:
        assert variables, ("Fn created variables but the variables were not "
                           "passed to the gradient fn.")
        if set(variables) != original_vars:
          raise ValueError(_WRONG_VARS_ERR)

      return _recomputing_grad_fn(
          compute_fn=fn,
          original_args=args,
          original_vars=original_vars,
          output_grads=output_grads,
          grad_fn_variables=variables,
          use_data_dep=use_data_dep_,
          tupleize_grads=tupleize_grads,
          arg_scope=arg_scope,
          var_scope=vs,
          has_is_recompute_kwarg=has_is_recompute_kwarg)

    # custom_gradient inspects the signature of the function to determine
    # whether the user expects variables passed in the grad_fn. If the function
    # created variables, the grad_fn should accept the "variables" kwarg.
    if original_vars:
      def grad_fn(*output_grads, **kwargs):
        return _grad_fn(output_grads, kwargs["variables"])
    else:
      def grad_fn(*output_grads):
        return _grad_fn(output_grads)

    return outputs, grad_fn
Example #7
  def fn_with_recompute(*args):
    """Wrapper for fn."""
    # Forward pass
    vs = variable_scope.get_variable_scope()
    arg_scope = contrib_framework_ops.current_arg_scope()
    with backprop.GradientTape() as tape:
      fn_kwargs = {}
      if has_is_recompute_kwarg:
        fn_kwargs["is_recomputing"] = False
      outputs = fn(*args, **fn_kwargs)
    original_vars = set(tape.watched_variables())

    # Backward pass
    def _grad_fn(output_grads, variables=None):
      """Recompute outputs for gradient computation."""
      variables = variables or []
      if original_vars:
        assert variables, ("Fn created variables but the variables were not "
                           "passed to the gradient fn.")
        if set(variables) != original_vars:
          raise ValueError(_WRONG_VARS_ERR)
      inputs = [array_ops.identity(x) for x in list(args)]
      # Recompute outputs
      with framework_ops.control_dependencies(output_grads):
        if use_data_dep_:
          inputs = _force_data_dependency(output_grads, inputs)
        with contrib_framework_ops.arg_scope(arg_scope):
          with variable_scope.variable_scope(vs, reuse=True):
            with backprop.GradientTape() as tape:
              fn_kwargs = {}
              if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = True
              outputs = fn(*inputs, **fn_kwargs)
            recompute_vars = set(tape.watched_variables())
            if original_vars != recompute_vars:
              raise ValueError(_WRONG_VARS_ERR)

      if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
      outputs = list(outputs)
      grads = gradients_impl.gradients(outputs, inputs + variables,
                                       output_grads)

      if tupleize_grads:
        if use_data_dep_:
          grads = _tuple_with_data_dep(grads)
        else:
          grads = control_flow_ops.tuple(grads)

      grad_inputs = grads[:len(inputs)]
      grad_vars = grads[len(inputs):]
      return grad_inputs, grad_vars

    # custom_gradient inspects the signature of the function to determine
    # whether the user expects variables passed in the grad_fn. If the function
    # created variables, the grad_fn should accept the "variables" kwarg.
    if original_vars:
      def grad_fn(*output_grads, **kwargs):
        return _grad_fn(output_grads, kwargs["variables"])
    else:
      def grad_fn(*output_grads):
        return _grad_fn(output_grads)

    return outputs, grad_fn
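The core pattern of these gradient functions, recomputing the forward graph under a control dependency on the incoming gradients and then differentiating the recomputed graph, can be shown with a much smaller custom gradient. A toy graph-mode sketch under TF 1.x assumptions, not the library's implementation:

  import tensorflow as tf

  @tf.custom_gradient
  def checkpointed_layer(x, w):
    y = tf.nn.relu(tf.matmul(x, w))

    def grad(dy):
      # Recompute the activation instead of keeping it alive, then take the
      # gradient of the recomputed graph with respect to the layer's inputs.
      with tf.control_dependencies([dy]):
        y_recomputed = tf.nn.relu(tf.matmul(x, w))
      return tf.gradients(y_recomputed, [x, w], grad_ys=dy)

    return y, grad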
Example #8
    def fn_with_recompute(*args):
        """Wrapper for fn."""
        # Forward pass
        vs = variable_scope.get_variable_scope()
        arg_scope = contrib_framework_ops.current_arg_scope()
        with backprop.GradientTape() as tape:
            fn_kwargs = {}
            if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = False
            outputs = fn(*args, **fn_kwargs)
        original_vars = set(tape.watched_variables())

        # Backward pass
        def _grad_fn(output_grads, variables=None):
            """Recompute outputs for gradient computation."""
            variables = variables or []
            if original_vars:
                assert variables, (
                    "Fn created variables but the variables were not "
                    "passed to the gradient fn.")
                if set(variables) != original_vars:
                    raise ValueError(_WRONG_VARS_ERR)
            inputs = [array_ops.identity(x) for x in list(args)]
            # Recompute outputs
            with framework_ops.control_dependencies(output_grads):
                if use_data_dep_:
                    inputs = _force_data_dependency(output_grads, inputs)
                with contrib_framework_ops.arg_scope(arg_scope):
                    with variable_scope.variable_scope(vs, reuse=True):
                        with backprop.GradientTape() as tape:
                            fn_kwargs = {}
                            if has_is_recompute_kwarg:
                                fn_kwargs["is_recomputing"] = True
                            outputs = fn(*inputs, **fn_kwargs)
                        recompute_vars = set(tape.watched_variables())
                        if original_vars != recompute_vars:
                            raise ValueError(_WRONG_VARS_ERR)

            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            outputs = list(outputs)
            grads = gradients_impl.gradients(outputs, inputs + variables,
                                             output_grads)

            if tupleize_grads:
                if use_data_dep_:
                    grads = _tuple_with_data_dep(grads)
                else:
                    grads = control_flow_ops.tuple(grads)

            grad_inputs = grads[:len(inputs)]
            grad_vars = grads[len(inputs):]
            return grad_inputs, grad_vars

        # custom_gradient inspects the signature of the function to determine
        # whether the user expects variables passed in the grad_fn. If the function
        # created variables, the grad_fn should accept the "variables" kwarg.
        if original_vars:

            def grad_fn(*output_grads, **kwargs):
                return _grad_fn(output_grads, kwargs["variables"])
        else:

            def grad_fn(*output_grads):
                return _grad_fn(output_grads)

        return outputs, grad_fn
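When tupleize_grads is set, the gradients go through control_flow_ops.tuple (or a data-dependency variant), which returns its inputs wrapped with control dependencies so that none of them becomes available before all of them have been computed. The public form of that op is tf.tuple in TF 1.x; a small sketch of the behaviour with illustrative values:

  import tensorflow as tf

  a = tf.constant(1.0) * 2.0
  b = tf.constant(3.0) * 4.0
  # Each returned tensor carries a control dependency on the whole group, so
  # neither value can be consumed until both have been computed.
  a_grouped, b_grouped = tf.tuple([a, b])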
Example #9
 def fn_with_recompute(*args):
   # Cache the current variable scope and arg scope so the backward pass can
   # re-enter them when it recomputes fn.
   cached_vs.append(variable_scope.get_variable_scope())
   cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
   return fn(*args)