def fn_with_recompute(*args):
  cached_vs.append(variable_scope.get_variable_scope())
  # TODO(rsepassi): Rm conditional in TF 1.4
  if hasattr(contrib_framework_ops, "current_arg_scope"):
    cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
  else:
    cached_arg_scope.append({})
  return fn(*args)
def fn_with_recompute(*args): """Wrapper for fn.""" # Forward pass vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = False outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass def grad_fn(*output_grads, **kwargs): """Recompute outputs for gradient computation.""" variables = [] if original_vars: variables = kwargs["variables"] if set(variables) != original_vars: raise ValueError(_WRONG_VARS_ERR) del kwargs inputs = list(args) # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = True outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) if not (isinstance(outputs, list) or isinstance(outputs, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) if tupleize_grads: if use_data_dep_: grads = _tuple_with_data_dep(grads) else: grads = control_flow_ops.tuple(grads) grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars return outputs, grad_fn
def fn_with_recompute(*args): """Wrapper for fn.""" # Capture the variable and arg scopes so we can re-enter them when # recomputing. vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() # Track all variables touched in the function. with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = False outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) def _grad_fn(output_grads, variables=None): # Validate that custom_gradient passes the right variables into grad_fn. if original_vars: assert variables, ( "Fn created variables but the variables were not " "passed to the gradient fn.") if set(variables) != original_vars: raise ValueError(_WRONG_VARS_ERR) return _recomputing_grad_fn( compute_fn=fn, original_args=args, original_vars=original_vars, output_grads=output_grads, grad_fn_variables=variables, use_data_dep=use_data_dep_, tupleize_grads=tupleize_grads, arg_scope=arg_scope, var_scope=vs, has_is_recompute_kwarg=has_is_recompute_kwarg) # custom_gradient inspects the signature of the function to determine # whether the user expects variables passed in the grad_fn. If the function # created variables, the grad_fn should accept the "variables" kwarg. if original_vars: def grad_fn(*output_grads, **kwargs): return _grad_fn(output_grads, kwargs["variables"]) else: def grad_fn(*output_grads): return _grad_fn(output_grads) return outputs, grad_fn
def fn_with_recompute(*args): """Wrapper for fn.""" # Capture the variable and arg scopes so we can re-enter them when # recomputing. vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() # Track all variables touched in the function. with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = False outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) def _grad_fn(output_grads, variables=None): # Validate that custom_gradient passes the right variables into grad_fn. if original_vars: assert variables, ("Fn created variables but the variables were not " "passed to the gradient fn.") if set(variables) != original_vars: raise ValueError(_WRONG_VARS_ERR) return _recomputing_grad_fn( compute_fn=fn, original_args=args, original_vars=original_vars, output_grads=output_grads, grad_fn_variables=variables, use_data_dep=use_data_dep_, tupleize_grads=tupleize_grads, arg_scope=arg_scope, var_scope=vs, has_is_recompute_kwarg=has_is_recompute_kwarg) # custom_gradient inspects the signature of the function to determine # whether the user expects variables passed in the grad_fn. If the function # created variables, the grad_fn should accept the "variables" kwarg. if original_vars: def grad_fn(*output_grads, **kwargs): return _grad_fn(output_grads, kwargs["variables"]) else: def grad_fn(*output_grads): return _grad_fn(output_grads) return outputs, grad_fn
def fn_with_recompute(*args): """Wrapper for fn.""" # Forward pass vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = False outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass def _grad_fn(output_grads, variables=None): """Recompute outputs for gradient computation.""" variables = variables or [] if original_vars: assert variables, ("Fn created variables but the variables were not " "passed to the gradient fn.") if set(variables) != original_vars: raise ValueError(_WRONG_VARS_ERR) inputs = [array_ops.identity(x) for x in list(args)] # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = True outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) if not isinstance(outputs, (list, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) if tupleize_grads: if use_data_dep_: grads = _tuple_with_data_dep(grads) else: grads = control_flow_ops.tuple(grads) grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars # custom_gradient inspects the signature of the function to determine # whether the user expects variables passed in the grad_fn. If the function # created variables, the grad_fn should accept the "variables" kwarg. if original_vars: def grad_fn(*output_grads, **kwargs): return _grad_fn(output_grads, kwargs["variables"]) else: def grad_fn(*output_grads): return _grad_fn(output_grads) return outputs, grad_fn
def fn_with_recompute(*args): """Wrapper for fn.""" # Forward pass vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = False outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass def _grad_fn(output_grads, variables=None): """Recompute outputs for gradient computation.""" variables = variables or [] if original_vars: assert variables, ( "Fn created variables but the variables were not " "passed to the gradient fn.") if set(variables) != original_vars: raise ValueError(_WRONG_VARS_ERR) inputs = [array_ops.identity(x) for x in list(args)] # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: fn_kwargs = {} if has_is_recompute_kwarg: fn_kwargs["is_recomputing"] = True outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) if not isinstance(outputs, (list, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) if tupleize_grads: if use_data_dep_: grads = _tuple_with_data_dep(grads) else: grads = control_flow_ops.tuple(grads) grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars # custom_gradient inspects the signature of the function to determine # whether the user expects variables passed in the grad_fn. If the function # created variables, the grad_fn should accept the "variables" kwarg. if original_vars: def grad_fn(*output_grads, **kwargs): return _grad_fn(output_grads, kwargs["variables"]) else: def grad_fn(*output_grads): return _grad_fn(output_grads) return outputs, grad_fn
def fn_with_recompute(*args):
  cached_vs.append(variable_scope.get_variable_scope())
  cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
  return fn(*args)
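# For reference, a minimal self-contained TF 2.x sketch of the overall pattern
# these wrappers implement, corresponding to the no-variables grad_fn branch in
# the later variants above. All names here are illustrative: the forward output
# is returned as-is, while the gradient function re-runs fn and differentiates
# the recomputed graph, so forward activations need not be kept around.

import tensorflow as tf

def fn(x):
  return tf.nn.relu(x) * 2.0

@tf.custom_gradient
def fn_with_recompute_sketch(x):
  outputs = fn(x)

  def grad_fn(dy):
    # Recompute the forward pass instead of reusing its saved activations,
    # then differentiate the recomputed graph against the upstream gradient.
    with tf.GradientTape() as tape:
      tape.watch(x)
      recomputed = fn(x)
    return tape.gradient(recomputed, x, output_gradients=dy)

  return outputs, grad_fn

x = tf.constant([1.0, -2.0, 3.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = fn_with_recompute_sketch(x)
print(tape.gradient(y, x))  # [2., 0., 2.]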