def _real_mirrored_creator(devices, *args, **kwargs):
    """Create one component variable per device and return them keyed by device.

    The first device's variable is created from the caller's `kwargs` as-is.
    Every later replica is renamed `<var0>/replica_<i>/` and initialized from
    the first replica so that all copies start from the same value.

    NOTE(review): `next_creator` is a free variable here; presumably this
    function is nested inside a scope that defines it -- confirm.

    Args:
      devices: sequence of device strings to create replicas on.
      *args: forwarded to `next_creator`.
      **kwargs: forwarded to `next_creator`; "name" and "initial_value" are
        overwritten for replicas after the first.

    Returns:
      A dict mapping each device to the variable created on it.
    """
    created = {}
    for replica_id, device_name in enumerate(devices):
        with ops.device(device_name):
            if replica_id > 0:
                # Give replicas meaningful distinct names. The trailing "/"
                # makes the name absolute, so the enclosing name scope is
                # ignored.
                base_name = created[devices[0]].name.split(":")[0]
                kwargs["name"] = "%s/replica_%d/" % (base_name, replica_id)

                # Initialize replicas with the same value. The device is bound
                # as a default argument so each closure keeps its own device.
                def initial_value_fn(device=device_name):
                    if not context.executing_eagerly():
                        with ops.device(device):
                            return array_ops.identity(
                                created[devices[0]].initial_value)
                    return array_ops.identity(created[devices[0]].value())

                kwargs["initial_value"] = initial_value_fn

            silent_placement = context.context().device_policy(
                context.DEVICE_PLACEMENT_SILENT)
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with silent_placement, tape.stop_recording():
                new_var = next_creator(*args, **kwargs)
            assert not isinstance(new_var, values.DistributedVariable)
            created[device_name] = new_var
    return created
def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`.

    One component variable is created per device via `next_creator`; the
    components are wrapped in a `MirroredVariable` (or a `TowerLocalVariable`
    when "tower_local_reduce_method" is given), and in graph mode the wrapper
    replaces the components in the graph collections.

    Args:
      next_creator: callable that creates a single component variable.
      *args: forwarded to `next_creator`.
      **kwargs: variable creation arguments. "collections", "colocate_with"
        and "tower_local_reduce_method" are consumed here.

    Returns:
      The wrapping `MirroredVariable` or `TowerLocalVariable`.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Work on a copy: TRAINABLE_VARIABLES may be appended below and the
        # caller's own list must not be modified as a side effect.
        collections = list(collections)
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    tower_local = kwargs.pop("tower_local_reduce_method", None)
    if tower_local is not None:
        kwargs["trainable"] = False

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        index = {}
        for i, d in enumerate(devices):
            with ops.device(d):
                if i > 0:
                    # Give replicas meaningful distinct names:
                    var0name = index[devices[0]].name.split(":")[0]
                    kwargs["name"] = "%s/replica_%d" % (var0name, i)
                    # Initialize replicas with the same value:
                    if context.executing_eagerly():
                        initial_value = index[devices[0]].value()
                    else:
                        initial_value = index[devices[0]].initial_value
                    kwargs["initial_value"] = array_ops.identity(initial_value)
                with context.context().device_policy(
                        context.DEVICE_PLACEMENT_SILENT):
                    v = next_creator(*args, **kwargs)
                assert not isinstance(v, values.DistributedVariable)
                index[d] = v

        if tower_local is None:
            result = values.MirroredVariable(index, index[devices[0]])
        else:
            result = values.TowerLocalVariable(
                index, index[devices[0]], tower_local)

    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in index.values():
                # Remove by identity and tolerate absence: a creator that did
                # not register the component must not crash us with the
                # ValueError that list.remove() raises.
                for pos, trainable_variable in enumerate(l):
                    if v is trainable_variable:
                        del l[pos]
                        break
        g.add_to_collections(collections, result)
    return result
def _create_tpu_mirrored_variable(
        strategy, device_map, logical_device, real_mirrored_creator,
        *args, **kwargs):
    """Create a `TPUMirroredVariable` from per-device component variables.

    Args:
      strategy: passed through to `values.TPUMirroredVariable`.
      device_map: object providing `logical_to_actual_devices`.
      logical_device: logical device whose actual devices host the replicas.
      real_mirrored_creator: callable `(devices, *args, **kwargs)` returning
        the list of component variables.
      *args: forwarded to `real_mirrored_creator`.
      **kwargs: variable creation arguments. "collections", "aggregation" and
        "caching_device" are consumed here.

    Returns:
      The created `values.TPUMirroredVariable`.

    Raises:
      ValueError: if "aggregation" is not a supported `VariableAggregation`.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the TPUMirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Copy so that appending TRAINABLE_VARIABLES below does not modify
        # the list object owned by the caller.
        var_collections = list(var_collections)
    kwargs["collections"] = []

    # TODO(jhseu): Should we have different behavior for different
    # synchronization settings?

    # Get aggregation value
    # TODO(jhseu): Support aggregation in a replica context.
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
    if aggregation not in [
            vs.VariableAggregation.NONE,
            vs.VariableAggregation.SUM,
            vs.VariableAggregation.MEAN,
            vs.VariableAggregation.ONLY_FIRST_REPLICA,
    ]:
        # Use .get() so a missing "name" cannot turn this into a KeyError.
        raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                         .format(aggregation, kwargs.get("name")))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)
        result = values.TPUMirroredVariable(
            strategy, device_map, value_list, aggregation,
            logical_device=logical_device)

    if not (context.executing_eagerly() or ops.inside_function()):
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                # Identity-based removal that tolerates a component missing
                # from the collection (list.remove() would raise ValueError).
                for pos, trainable_variable in enumerate(l):
                    if v is trainable_variable:
                        del l[pos]
                        break
        g.add_to_collections(var_collections, result)
    return result
def compute_gradients(model, images, labels, num_replicas=1):
    """Run a forward pass and return gradients of the loss w.r.t. model variables.

    Args:
      model: callable invoked as `model(images, training=True)`; must expose
        a `variables` attribute.
      images: input batch passed to the model.
      labels: one-hot labels for the softmax cross-entropy loss.
      num_replicas: when not 1, the loss is divided by this value.

    Returns:
      The result of `GradientTape.gradient(loss, model.variables)`.
    """
    with tf.GradientTape() as recorder:
        predictions = model(images, training=True)
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=labels, logits=predictions)
        tf.contrib.summary.scalar(tensor=loss, name='loss')
        if num_replicas != 1:
            loss /= num_replicas

    # TODO(b/110991947): We can mistakenly trace the gradient call in
    # multi-threaded environment. Explicitly disable recording until
    # this is fixed.
    with tape.stop_recording():
        return recorder.gradient(loss, model.variables)
def decorated(*args, **kwargs):
    """Decorated function with custom gradient.

    In graph mode the outputs and inputs of `f` are routed through an
    `IdentityN` op whose gradient is overridden with the user's `grad_fn`.
    Otherwise `f` runs with tape recording stopped and a single operation
    carrying `grad_fn` is recorded on the tape.

    Args:
      *args: positional inputs to the wrapped function `f`.
      **kwargs: keyword inputs to `f`; rejected in graph mode.

    Returns:
      The result of `f` (re-packed from the `IdentityN` outputs in graph mode).

    Raises:
      ValueError: if keyword arguments are passed in graph mode.
    """
    if context.in_graph_mode():
        if kwargs:
            raise ValueError(
                "custom_gradient in graph mode doesn't support keyword arguments.")
        name = "CustomGradient-%s" % tf_ops.uid()
        args = [tf_ops.convert_to_tensor(x) for x in args]
        result, grad_fn = f(*args)
        flat_result = nest.flatten(result)
        all_tensors = flat_result + args

        @tf_ops.RegisterGradient(name)
        def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
            gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
            # Need to return one value per input to the IdentityN, so pad the
            # gradients of the inputs of the custom_gradient function with the
            # gradients of the outputs as well.
            return ([None] * len(flat_result)) + gradients

        with tf_ops.get_default_graph().gradient_override_map(
                {"IdentityN": name}):
            all_tensors = array_ops.identity_n(all_tensors)
        return nest.pack_sequence_as(
            structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [x for x in args if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
        result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
        return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    # The former `flat_result = list(flat_result)` here was a dead store
    # (never read afterwards) and has been removed.
    return result
def decorated(*args, **kwargs):
    """Decorated function with custom gradient.

    Runs the wrapped function `f` with tape recording stopped, then records a
    single operation on the tape whose backward function is the `grad_fn`
    returned by `f`.

    Args:
      *args: positional inputs to `f`; the `Tensor` ones are recorded as the
        operation's inputs.
      **kwargs: keyword inputs forwarded to `f`.

    Returns:
      The result returned by `f`.
    """
    input_tensors = [x for x in args if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
        result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
        return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    # The former `flat_result = list(flat_result)` here was a dead store
    # (never read afterwards) and has been removed.
    return result
def __call__(self, *args, **kwds):
    """Calls the graph function.

    Dispatch order, as implemented below:
      1. If variables were already created, call `self._stateless_fn`.
      2. Else if `self._stateful_fn` exists, call it and fail if that call
         created variables.
      3. Otherwise this is the first call: initialize, then either run the
         lifted initializers followed by `self._stateless_fn`, reuse the
         existing concrete trace when no variables were created, or build a
         `cond` that picks between the stateless and stateful functions
         depending on whether all created variables are initialized.

    Args:
      *args: positional arguments for the wrapped function.
      **kwds: keyword arguments for the wrapped function.

    Returns:
      Whatever the selected underlying function returns.

    Raises:
      ValueError: if variables are created on a non-first call, or if a
        variable created inside the function has been garbage-collected.
    """
    if self._created_variables:
        # In this case we have created variables on the first call, so we run the
        # defunned version which is guaranteed to never create variables.
        return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
        # In this case we have not created variables on the first call. So we can
        # run the first trace but we should fail if variables are created.
        results = self._stateful_fn(*args, **kwds)
        if self._created_variables:
            raise ValueError("Creating variables on a non-first call to a function"
                             " decorated with tf.function.")
        return results

    # This is the first call of __call__, so we have to initialize.
    self._initialize(args, kwds)

    if self._lifted_all_initializers and self._lifted_placeholders:
        with ops.init_scope():
            # Each lifted entry is a (handle, placeholder) pair; split them
            # into two parallel tuples.
            handles, placeholders = zip(*self._lifted_placeholders)
            if context.executing_eagerly():
                lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
                    "initializer" + str(ops.uid()),
                    self._lifted_initializer_graph,
                    placeholders, [], {})
                # Run the initializers without recording them on any tape.
                with tape.stop_recording():
                    lifted_fn.call(context.context(), list(handles))
        return self._stateless_fn(*args, **kwds)
    canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

    if not self._created_variables:
        # If we did not create any variables the trace we have is good enough.
        return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

    def fn_with_cond(*inner_args, **inner_kwds):
        """Conditionally runs initialization if it's needed."""
        condition = True
        # `self._created_variables` holds callables (presumably weak
        # references -- confirm) that return None once the variable is gone.
        for wr in self._created_variables:
            variable = wr()
            if variable is None:
                raise ValueError(
                    "A tf.Variable created inside your tf.function has been"
                    " garbage-collected. Your code needs to keep Python references"
                    " to variables created inside `tf.function`s.\n"
                    "\n"
                    "A common way to raise this error is to create and return a"
                    " variable only referenced inside your function:\n"
                    "\n"
                    "@tf.function\n"
                    "def f():\n"
                    " v = tf.Variable(1.0)\n"
                    " return v\n"
                    "\n"
                    "v = f() # Crashes with this error message!\n"
                    "\n"
                    "The reason this crashes is that @tf.function annotated"
                    " function returns a **`tf.Tensor`** with the **value** of the"
                    " variable when the function is called rather than the"
                    " variable instance itself. As such there is no code holding a"
                    " reference to the `v` created inside the function and Python"
                    " garbage collects it.\n"
                    "\n"
                    "The simplest way to fix this issue is to create variables"
                    " outside the function and capture them:\n"
                    "\n"
                    "v = tf.Variable(1.0)\n"
                    "\n"
                    "@tf.function\n"
                    "def f():\n"
                    " return v\n"
                    "\n"
                    "f() # <tf.Tensor: ... numpy=1.>\n"
                    "v.assign_add(1.)\n"
                    "f() # <tf.Tensor: ... numpy=2.>")
            # Accumulate "every created variable is initialized".
            condition = math_ops.logical_and(
                condition, resource_variable_ops.var_is_initialized_op(
                    variable.handle))
        # We want to call stateless_fn if possible because it avoids recomputing
        # potentially expensive initializers.
        return control_flow_ops.cond(
            condition,
            lambda: self._stateless_fn(*inner_args, **inner_kwds),
            functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                              inner_args, inner_kwds))

    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
def _real_mirrored_creator(devices, *args, **kwargs):
    """Creates one MirroredVariable on the current worker.

    The variable on the first device takes its initial value from the user's
    "initial_value"; with more than one worker that value is broadcast from
    the chief and received by the other workers. Variables on the remaining
    devices are initialized from the first device's variable.

    NOTE(review): `self` and `next_creator` are free variables here;
    presumably this function is nested inside a method that defines them --
    confirm.

    Args:
      devices: sequence of device strings to create component variables on.
      *args: forwarded to `next_creator`.
      **kwargs: variable creation arguments; must contain "name" and
        "initial_value". Both are overwritten per replica.

    Returns:
      A list with one created variable per device, in device order.

    Raises:
      ValueError: if "initial_value" is missing from `kwargs`.
    """
    unique_var_name = ops.get_default_graph().unique_name(
        kwargs["name"], mark_as_used=False).rstrip("/")
    # pylint: disable=protected-access
    collective_instance_key = self._collective_keys.get_instance_key(
        key_id=unique_var_name)
    # Only the first device participates in the broadcast of initial values.
    group_key = self._collective_keys.get_group_key([devices[0]])
    group_size = self._num_workers
    if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
    initial_value = kwargs["initial_value"]
    # Normalize the initial value to a zero-argument callable.
    if callable(initial_value):
        initial_value_fn = initial_value
    else:
        initial_value_fn = lambda: initial_value

    value_list = []
    for i, d in enumerate(devices):
        with ops.init_scope(), ops.device(d):
            if i == 0:
                # The initial value fn makes sure variables all initialized to
                # same values. The first device of the chief worker will send their
                # variable values to other workers.
                # `device` and `index` are bound as defaults so the closure
                # keeps this iteration's values.
                def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
                    with ops.device(device):
                        initial_value = initial_value_fn()
                        assert not callable(initial_value)
                        initial_value = ops.convert_to_tensor(initial_value)

                        assert index == 0, index
                        if self._num_workers > 1:
                            if self._is_chief:
                                bcast_send = collective_ops.broadcast_send(
                                    initial_value, initial_value.shape, initial_value.dtype,
                                    group_size, group_key, collective_instance_key)
                                # Tie the returned value to the send op via a
                                # control dependency.
                                with ops.control_dependencies([bcast_send]):
                                    return array_ops.identity(initial_value)
                            else:
                                return collective_ops.broadcast_recv(
                                    initial_value.shape, initial_value.dtype, group_size,
                                    group_key, collective_instance_key)
                        return initial_value
            else:
                # Give replicas meaningful distinct names:
                var0name = value_list[0].name.split(":")[0]
                # We append a / to variable names created on replicas with id > 0 to
                # ensure that we ignore the name scope and instead use the given
                # name as the absolute name of the variable.
                kwargs["name"] = "%s/replica_%d/" % (var0name, i)

                # Variables on non-first replica get initial values from the
                # variables created on the first device of each worker.
                def _overridden_initial_value_fn(device=d, index=i):
                    assert index > 0
                    with ops.device(device):
                        if context.executing_eagerly():
                            return array_ops.identity(value_list[0].value())
                        else:
                            return array_ops.identity(value_list[0].initial_value)

            kwargs["initial_value"] = _overridden_initial_value_fn
            with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
                # Don't record operations (e.g. other variable reads) during
                # variable creation.
                with tape.stop_recording():
                    v = next_creator(*args, **kwargs)

            if i == 0:
                # The collective instance key was derived from the predicted
                # name, so the name actually assigned must match it.
                actual_var_name = v.name.split(":")[0]
                assert unique_var_name == actual_var_name, "%r vs %r" % (
                    unique_var_name, actual_var_name)
            assert not isinstance(v, values.DistributedVariable)
            value_list.append(v)
    return value_list
def inner(*args, **kwargs):
    """Inner function closure for calculating gradients.

    Runs `f` once with tape recording stopped, and returns its result together
    with a gradient function that re-runs `f` under a fresh `GradientTape`
    when gradients are requested.

    Args:
      *args: positional inputs to `f`.
      **kwargs: keyword inputs to `f`.

    Returns:
      A `(result, grad_wrapper)` pair.
    """
    current_var_scope = variable_scope.get_variable_scope()
    with tape_lib.stop_recording():
        result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, variables=None):
        """Wrapper function to accommodate lack of kwargs in graph mode custom_gradient."""

        @custom_gradient
        def inner_recompute_grad(*dresult):
            """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
            # Gradient calculation for reverse mode autodiff.
            with backprop.GradientTape() as t:
                id_args = nest.map_structure(gen_array_ops.identity, args)
                # Tuple `dresult` should contain at least one tensor.
                assert len(dresult) >= 1

                if not context.executing_eagerly():
                    # XLA doesn't respect `tf.control_dependencies`. The code block
                    # below manually adds a data dependency to `dresult` to ensure
                    # recomputation of `f(*args, **kwargs)` happens after `dresult`.

                    # This works even if `dresult[0]` is a size 0 tensor as reduce_max
                    # of a size 0 tensor returns -inf. Use reshape here to avoid reading
                    # the entire `dresult[0]`.
                    elem = math_ops.reduce_max(
                        array_ops.reshape(dresult[0], [-1])[:1])
                    # Cast elem to bool in case elem is NaN.
                    elem_bool = math_ops.cast(elem, dtypes.bool)
                    dresult_dep = array_ops.where_v2(
                        elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
                    id_args = nest.map_structure(
                        lambda x: x + math_ops.cast(dresult_dep, x.dtype), id_args)

                t.watch(id_args)
                if variables is not None:
                    t.watch(variables)
                with variable_scope.variable_scope(current_var_scope):
                    recomputed_result = f(*id_args, **kwargs)
            kw_vars = []
            if variables is not None:
                kw_vars = list(variables)
            grads = t.gradient(
                recomputed_result,
                list(id_args) + kw_vars,
                output_gradients=dresult,
                unconnected_gradients=UnconnectedGradients.ZERO)

            def transpose(*t_args, **t_kwargs):
                """Gradient function calculation for forward mode autodiff."""
                # Just throw an error since gradients / activations are not stored on
                # tape for recompute.
                # The trailing space after "mode" was missing, which rendered
                # the message as "forward modeautodiff".
                raise NotImplementedError(
                    "recompute_grad tried to transpose grad of {}. "
                    "Consider not using recompute_grad in forward mode "
                    "autodiff".format(f.__name__))

            # First the gradients for the positional inputs, then those for
            # the watched variables.
            return (grads[:len(id_args)], grads[len(id_args):]), transpose

        return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper
def _create_mirrored_variable(
        strategy, device_map, logical_device, real_mirrored_creator,
        *args, **kwargs):
    """Create a `MirroredVariable` or `ReplicaLocalVariable` wrapper.

    Args:
      strategy: passed through to the wrapper class.
      device_map: object providing `logical_to_actual_devices`.
      logical_device: logical device whose actual devices host the replicas.
      real_mirrored_creator: callable `(devices, *args, **kwargs)` returning
        the list of component variables.
      *args: forwarded to `real_mirrored_creator`.
      **kwargs: variable creation arguments. "collections", "aggregation" and
        "caching_device" are consumed here; "synchronization" is read.

    Returns:
      A `values.ReplicaLocalVariable` when synchronization is `ON_READ`,
      otherwise a `values.MirroredVariable`.

    Raises:
      ValueError: if the synchronization or aggregation mode is unsupported.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Copy so that appending TRAINABLE_VARIABLES below does not modify
        # the list object owned by the caller.
        collections = list(collections)
    kwargs["collections"] = []

    # Error messages use %-formatting: concatenating a str with an enum member
    # (or a None name) would raise TypeError instead of the intended
    # ValueError.

    # Get synchronization value
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: %s"
            % (kwargs.get("name"),))
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local.
        is_replica_local = True
        kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
          synchronization == variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_replica_local = False
    else:
        raise ValueError(
            "Invalid variable synchronization mode: %s for variable: %s"
            % (synchronization, kwargs.get("name")))

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s"
            % (aggregation, kwargs.get("name")))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)

        if is_replica_local:
            result = values.ReplicaLocalVariable(
                strategy, device_map, value_list, aggregation,
                logical_device=logical_device)
        else:
            result = values.MirroredVariable(
                strategy, device_map, value_list, aggregation,
                logical_device=logical_device)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                if v in l:
                    l.remove(v)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`.

    One component variable is created per device via `next_creator`; replicas
    after the first are renamed and initialized from the first. The components
    are wrapped in a `MirroredVariable` (or `TowerLocalVariable` when
    "tower_local_reduce_method" is given), and in graph mode the wrapper
    replaces the components in the graph collections.

    Args:
      next_creator: callable that creates a single component variable.
      *args: forwarded to `next_creator`.
      **kwargs: variable creation arguments. "collections", "colocate_with",
        "tower_local_reduce_method" and "caching_device" are consumed here.

    Returns:
      The wrapping `MirroredVariable` or `TowerLocalVariable`.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Copy so that appending TRAINABLE_VARIABLES below does not modify
        # the list object owned by the caller.
        collections = list(collections)
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    tower_local = kwargs.pop("tower_local_reduce_method", None)
    if tower_local is not None:
        kwargs["trainable"] = False

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        index = {}
        for i, d in enumerate(devices):
            with ops.device(d):
                if i > 0:
                    # Give replicas meaningful distinct names:
                    var0name = index[devices[0]].name.split(":")[0]
                    # We append a / to variable names created on towers with id > 0 to
                    # ensure that we ignore the name scope and instead use the given
                    # name as the absolute name of the variable.
                    kwargs["name"] = "%s/replica_%d/" % (var0name, i)
                    # Initialize replicas with the same value:
                    if context.executing_eagerly():
                        kwargs["initial_value"] = array_ops.identity(
                            index[devices[0]].value())
                    else:
                        # The device is bound as a default argument so the
                        # closure keeps this iteration's device.
                        def initial_value_fn(device=d):
                            with ops.device(device):
                                return array_ops.identity(
                                    index[devices[0]].initial_value)
                        kwargs["initial_value"] = initial_value_fn
                with context.context().device_policy(
                        context.DEVICE_PLACEMENT_SILENT):
                    v = next_creator(*args, **kwargs)
                assert not isinstance(v, values.DistributedVariable)
                index[d] = v

        if tower_local is None:
            result = values.MirroredVariable(index, index[devices[0]])
        else:
            result = values.TowerLocalVariable(
                index, index[devices[0]], tower_local)

    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in index.values():
                # Identity-based removal that tolerates a component missing
                # from the collection (list.remove() would raise ValueError).
                for pos, trainable_variable in enumerate(l):
                    if v is trainable_variable:
                        del l[pos]
                        break
        g.add_to_collections(collections, result)
    return result
def create_mirrored_variable(
        strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
        **kwargs):
    """Create a distributed variable wrapping per-replica component variables.

    Args:
      strategy: passed through to the wrapper class.
      real_mirrored_creator: callable `(**kwargs)` returning the list of
        component variables.
      mirrored_cls: wrapper class used for `ON_WRITE` / `AUTO` synchronization.
      sync_on_read_cls: wrapper class used for `ON_READ` synchronization.
      **kwargs: variable creation arguments. "collections", "aggregation" and
        "caching_device" are consumed here; "synchronization" is read.

    Returns:
      An instance of `sync_on_read_cls` or `mirrored_cls`.

    Raises:
      ValueError: if the synchronization or aggregation mode is unsupported.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Copy so that appending TRAINABLE_VARIABLES below does not modify
        # the list object owned by the caller.
        var_collections = list(var_collections)
    kwargs["collections"] = []

    synchronization = kwargs.get("synchronization",
                                 vs.VariableSynchronization.ON_WRITE)

    if synchronization == vs.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not supported with `Mirrored` "
            "distribution strategy. Please change the `synchronization` for "
            "variable: " + str(kwargs["name"]))
    elif synchronization == vs.VariableSynchronization.ON_READ:
        is_sync_on_read = True
    elif synchronization in (vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
        # `AUTO` synchronization defaults to `ON_WRITE`.
        is_sync_on_read = False
    else:
        raise ValueError(
            "Invalid variable synchronization mode: %s for variable: %s" %
            (synchronization, kwargs["name"]))

    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    if aggregation not in (vs.VariableAggregation.NONE,
                           vs.VariableAggregation.SUM,
                           vs.VariableAggregation.MEAN,
                           vs.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        value_list = real_mirrored_creator(**kwargs)
        var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
        result = var_cls(strategy, value_list, aggregation)
    # Install the created DistributedVariable as _distributed_container property
    # of the underlying variables, to make it easy to map back to the container.
    for v in result.values:
        # Hold a strong reference to avoid the container from being GC-ed. After
        # v = v.assign(), the user code may no longer holds references to the
        # original container, since v.assign() returns a new DistributedVariable.
        v._distributed_container = result  # pylint: disable=protected-access

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for value in value_list:
                # Compare by identity; stop right after deleting so the list
                # is never iterated past a mutation.
                for i, trainable_variable in enumerate(l):
                    if value is trainable_variable:
                        del l[i]
                        break

        g.add_to_collections(var_collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in var_collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
def _real_mirrored_creator(devices, *args, **kwargs):
    """Build one component variable per device with synchronized initial values.

    The first device's variable is initialized from the user-provided
    "initial_value"; with more than one worker the chief broadcasts it and the
    other workers receive it. Variables on later devices copy the first
    device's variable.

    NOTE(review): `self` and `next_creator` are free variables here;
    presumably this function is nested inside a method that defines them --
    confirm.

    Args:
      devices: sequence of device strings to create component variables on.
      *args: forwarded to `next_creator`.
      **kwargs: variable creation arguments; must contain "name" and
        "initial_value". Both are overwritten per replica.

    Returns:
      A list with one created variable per device, in device order.

    Raises:
      ValueError: if "initial_value" is missing from `kwargs`.
    """
    graph = ops.get_default_graph()
    unique_var_name = graph.unique_name(
        kwargs["name"], mark_as_used=False).rstrip("/")
    # pylint: disable=protected-access
    instance_key = self._collective_keys.get_instance_key(
        key_id=unique_var_name)
    # Only the first device takes part in the broadcast of initial values.
    group_key = self._collective_keys.get_group_key([devices[0]])
    group_size = self._num_workers

    if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
    user_initial_value = kwargs["initial_value"]
    if callable(user_initial_value):
        initial_value_fn = user_initial_value
    else:
        def initial_value_fn():
            return user_initial_value

    replicas = []
    for replica_id, device_name in enumerate(devices):
        with ops.init_scope():
            with ops.device(device_name):
                if replica_id != 0:
                    # Give replicas meaningful distinct names. The trailing
                    # "/" makes the name absolute, so the enclosing name scope
                    # is ignored.
                    base_name = replicas[0].name.split(":")[0]
                    kwargs["name"] = "%s/replica_%d/" % (base_name, replica_id)

                    # Non-first replicas copy their value from the variable
                    # created on the first device of this worker.
                    def _overridden_initial_value_fn(device=device_name,
                                                     index=replica_id):
                        assert index > 0
                        with ops.device(device):
                            if context.executing_eagerly():
                                return array_ops.identity(replicas[0].value())
                            return array_ops.identity(
                                replicas[0].initial_value)
                else:
                    # The first device of the chief worker sends its value to
                    # the other workers so all variables start out identical.
                    def _overridden_initial_value_fn(device=device_name,
                                                     index=replica_id):
                        with ops.device(device):
                            tensor = initial_value_fn()
                            assert not callable(tensor)
                            tensor = ops.convert_to_tensor(tensor)

                            assert index == 0, index
                            if self._num_workers <= 1:
                                return tensor
                            if not self._is_chief:
                                return collective_ops.broadcast_recv(
                                    tensor.shape, tensor.dtype, group_size,
                                    group_key, instance_key)
                            bcast_send = collective_ops.broadcast_send(
                                tensor, tensor.shape, tensor.dtype,
                                group_size, group_key, instance_key)
                            with ops.control_dependencies([bcast_send]):
                                return array_ops.identity(tensor)

                kwargs["initial_value"] = _overridden_initial_value_fn
                silent_placement = context.context().device_policy(
                    context.DEVICE_PLACEMENT_SILENT)
                # Don't record operations (e.g. other variable reads) during
                # variable creation.
                with silent_placement, tape.stop_recording():
                    new_var = next_creator(*args, **kwargs)

                if replica_id == 0:
                    actual_var_name = new_var.name.split(":")[0]
                    assert unique_var_name == actual_var_name, "%r vs %r" % (
                        unique_var_name, actual_var_name)
                assert not isinstance(new_var, values.DistributedVariable)
                replicas.append(new_var)
    return replicas
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
    """Create distributed variables with given synchronization and aggregation.

    Args:
      strategy: passed through to the wrapper class; `strategy.extended` is
        probed for a `_use_var_policy` attribute.
      real_mirrored_creator: callable `(**kwargs)` returning the list of
        component variables.
      class_mapping: mapping looked up by synchronization mode, or by the key
        "VariableClass" when variable policies are in use.
      policy_mapping: mapping from synchronization mode to a policy class;
        only consulted when variable policies are in use.
      **kwargs: variable creation arguments. "collections" and
        "caching_device" are consumed here; "synchronization" is rewritten to
        the validated value.

    Returns:
      The created distributed variable wrapper.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        # Copy so that appending TRAINABLE_VARIABLES below does not modify
        # the list object owned by the caller.
        var_collections = list(var_collections)
    kwargs["collections"] = []

    synchronization = _validate_synchronization(kwargs)
    # Update synchronization in kwargs in case it's AUTO, which is converted to
    # ON_WRITE.
    kwargs["synchronization"] = synchronization
    aggregation = _validate_aggregation(kwargs)
    use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        value_list = real_mirrored_creator(**kwargs)
        if use_var_policy:
            var_policy_cls = policy_mapping.get(synchronization)
            var_policy = var_policy_cls(aggregation=aggregation)
            var_cls = class_mapping.get("VariableClass")
            result = var_cls(strategy, value_list, aggregation,
                             var_policy=var_policy)
        else:
            var_cls = class_mapping.get(synchronization)
            result = var_cls(strategy, value_list, aggregation)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for value in value_list:
                # Compare by identity; stop right after deleting so the list
                # is never iterated past a mutation.
                for i, trainable_variable in enumerate(l):
                    if value is trainable_variable:
                        del l[i]
                        break

        g.add_to_collections(var_collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in var_collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`.

  One component variable is created per device through `next_creator`, and
  the components are wrapped in a `values.MirroredVariable`, or in a
  `values.TowerLocalVariable` when `synchronization` is `ON_READ`.

  Args:
    next_creator: Callable invoked as `next_creator(*args, **kwargs)` once per
      device to build each component variable.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. `collections`,
      `colocate_with`, `aggregation` and `caching_device` are consumed here;
      `collections` is replaced by an empty list, `trainable` is forced to
      `False` for `ON_READ` variables, and `name` / `initial_value` are
      overwritten for every device after the first.

  Returns:
    The `values.MirroredVariable` or `values.TowerLocalVariable` wrapping the
    per-device variables.

  Raises:
    ValueError: If `synchronization` is `NONE` or not a recognized mode, or if
      `aggregation` is not one of `NONE`, `SUM` or `MEAN`.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  colocate_with = kwargs.pop("colocate_with", None)
  devices = self._get_devices_from(colocate_with)

  # Only used in error messages. Read with `.get` and formatted with `%s` so
  # that a missing or `None` name cannot replace the intended ValueError with
  # a KeyError or TypeError.
  var_name = kwargs.get("name")

  # Get synchronization value
  synchronization = kwargs.get(
      "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not "
        "supported with `Mirrored` distribution strategy. Please"
        " change the `synchronization` for variable: %s" % (var_name,))
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    # `%s` rather than `+`: the offending mode need not be a `str` (e.g. an
    # enum member), and concatenating it would raise TypeError instead of
    # this ValueError.
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, var_name))

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in [
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN
  ]:
    raise ValueError(
        "Invalid variable aggregation mode: %s for variable: %s" %
        (aggregation, var_name))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = index[devices[0]].name.split(":")[0]
          # We append a / to variable names created on towers with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          # Initialize replicas with the same value:
          if context.executing_eagerly():
            kwargs["initial_value"] = array_ops.identity(
                index[devices[0]].value())
          else:
            # `device=d` binds the current device at definition time, so the
            # callable does not see later values of the loop variable.
            def initial_value_fn(device=d):
              with ops.device(device):
                return array_ops.identity(index[devices[0]].initial_value)
            kwargs["initial_value"] = initial_value_fn
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          v = next_creator(*args, **kwargs)
        assert not isinstance(v, values.DistributedVariable)
        index[d] = v

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  return result
def __call__(self, *args, **kwds):
  """Calls the graph function.

  Dispatch depends on the state left behind by earlier calls:

  * If `self._created_variables` is non-empty, `self._stateless_fn` is called.
  * Otherwise, if `self._stateful_fn` is already set, it is called, and a
    `ValueError` is raised if that call populated `self._created_variables`.
  * Otherwise this is treated as the first call: `self._initialize` runs, and
    the result comes from `self._stateless_fn`, from the concrete stateful
    function, or from a `defun`-wrapped function that picks between the two
    with a `cond` on whether every created variable is initialized.

  Args:
    *args: Positional arguments passed through to the underlying function.
    **kwds: Keyword arguments passed through to the underlying function.

  Returns:
    Whatever the selected underlying function returns.

  Raises:
    ValueError: If variables are created on a non-first call, or if an entry
      of `self._created_variables` resolves to `None` while the
      initialization check is being built.
  """
  if self._created_variables:
    # In this case we have created variables on the first call, so we run the
    # defunned version which is guaranteed to never create variables.
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  elif self._stateful_fn is not None:
    # In this case we have not created variables on the first call. So we can
    # run the first trace but we should fail if variables are created.
    results = self._stateful_fn(*args, **kwds)
    if self._created_variables:
      raise ValueError(
          "Creating variables on a non-first call to a function"
          " decorated with tf.function.")
    return results

  # This is the first call of __call__, so we have to initialize.
  self._initialize(args, kwds)

  if self._lifted_all_initializers and self._lifted_placeholders:
    with ops.init_scope():
      # `_lifted_placeholders` is unzipped into two parallel tuples.
      handles, placeholders = zip(*self._lifted_placeholders)
      if context.executing_eagerly():
        # Wrap the lifted initializer graph in a uniquely named function and
        # call it with the handles, without recording the call on the tape.
        lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
            "initializer" + str(ops.uid()),
            self._lifted_initializer_graph,
            placeholders, [], {})
        with tape.stop_recording():
          lifted_fn.call(context.context(), list(handles))
    # NOTE(review): nesting was reconstructed from flattened source; this
    # return is assumed to belong to the outer `if` -- confirm.
    return self._stateless_fn(*args, **kwds)
  canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

  if not self._created_variables:
    # If we did not create any variables the trace we have is good enough.
    return self._concrete_stateful_fn._filtered_call(
        canon_args, canon_kwds)  # pylint: disable=protected-access

  def fn_with_cond(*inner_args, **inner_kwds):
    """Conditionally runs initialization if it's needed."""
    # Starts as Python `True` and is AND-ed with one "is initialized" op per
    # created variable.
    condition = True
    for wr in self._created_variables:
      # Each entry is called to obtain the variable; presumably these are
      # weak references, given the garbage-collection error below.
      variable = wr()
      if variable is None:
        raise ValueError(
            "A tf.Variable created inside your tf.function has been"
            " garbage-collected. Your code needs to keep Python references"
            " to variables created inside `tf.function`s.\n"
            "\n"
            "A common way to raise this error is to create and return a"
            " variable only referenced inside your function:\n"
            "\n"
            "@tf.function\n"
            "def f():\n"
            " v = tf.Variable(1.0)\n"
            " return v\n"
            "\n"
            "v = f() # Crashes with this error message!\n"
            "\n"
            "The reason this crashes is that @tf.function annotated"
            " function returns a **`tf.Tensor`** with the **value** of the"
            " variable when the function is called rather than the"
            " variable instance itself. As such there is no code holding a"
            " reference to the `v` created inside the function and Python"
            " garbage collects it.\n"
            "\n"
            "The simplest way to fix this issue is to create variables"
            " outside the function and capture them:\n"
            "\n"
            "v = tf.Variable(1.0)\n"
            "\n"
            "@tf.function\n"
            "def f():\n"
            " return v\n"
            "\n"
            "f() # <tf.Tensor: ... numpy=1.>\n"
            "v.assign_add(1.)\n"
            "f() # <tf.Tensor: ... numpy=2.>")
      condition = math_ops.logical_and(
          condition, resource_variable_ops.var_is_initialized_op(
              variable.handle))
    # We want to call stateless_fn if possible because it avoids recomputing
    # potentially expensive initializers.
    return control_flow_ops.cond(
        condition,
        lambda: self._stateless_fn(*inner_args, **inner_kwds),
        functools.partial(
            self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
            inner_args, inner_kwds))

  return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
  """Creates a mirrored or tower-local variable spanning `devices`.

  The per-device component variables are produced by
  `real_mirrored_creator(devices, *args, **kwargs)` and wrapped in a
  `values.MirroredVariable`, or in a `values.TowerLocalVariable` when
  `synchronization` is `ON_READ`. The wrapper, rather than its components, is
  added to the requested collections.

  Args:
    devices: Sequence of devices; `devices[0]` is the key of the primary
      component in the mapping returned by `real_mirrored_creator`.
    real_mirrored_creator: Callable invoked as
      `real_mirrored_creator(devices, *args, **kwargs)`; returns a mapping
      from device to component variable.
    *args: Positional arguments forwarded to `real_mirrored_creator`.
    **kwargs: Keyword arguments forwarded to `real_mirrored_creator`.
      `collections`, `aggregation` and `caching_device` are consumed here;
      `collections` is replaced by an empty list and `trainable` is forced to
      `False` for `ON_READ` variables.

  Returns:
    The `values.MirroredVariable` or `values.TowerLocalVariable` wrapping the
    component variables.

  Raises:
    ValueError: If `synchronization` is `NONE` or not a recognized mode, or if
      `aggregation` is not one of `NONE`, `SUM`, `MEAN` or
      `ONLY_FIRST_TOWER`.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # Only used in error messages. Read with `.get` and formatted with `%s` so
  # that a missing or `None` name cannot replace the intended ValueError with
  # a KeyError or TypeError.
  var_name = kwargs.get("name")

  # Get synchronization value
  synchronization = kwargs.get(
      "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not "
        "supported with `Mirrored` distribution strategy. Please"
        " change the `synchronization` for variable: %s" % (var_name,))
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    # `%s` rather than `+`: the offending mode need not be a `str` (e.g. an
    # enum member), and concatenating it would raise TypeError instead of
    # this ValueError.
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, var_name))

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in (
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_TOWER
  ):
    raise ValueError(
        "Invalid variable aggregation mode: %s for variable: %s" %
        (aggregation, var_name))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = real_mirrored_creator(devices, *args, **kwargs)

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`.

  One component variable is created per device through `next_creator`, and
  the components are wrapped in a `values.MirroredVariable`, or in a
  `values.TowerLocalVariable` when `synchronization` is `ON_READ`.

  Args:
    next_creator: Callable invoked as `next_creator(*args, **kwargs)` once per
      device to build each component variable.
    *args: Positional arguments forwarded to `next_creator`.
    **kwargs: Keyword arguments forwarded to `next_creator`. `collections`,
      `colocate_with`, `aggregation` and `caching_device` are consumed here;
      `collections` is replaced by an empty list, `trainable` is forced to
      `False` for `ON_READ` variables, and `name` / `initial_value` are
      overwritten for every device after the first.

  Returns:
    The `values.MirroredVariable` or `values.TowerLocalVariable` wrapping the
    per-device variables.

  Raises:
    ValueError: If `synchronization` is `NONE` or not a recognized mode, or if
      `aggregation` is not one of `NONE`, `SUM` or `MEAN`.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  colocate_with = kwargs.pop("colocate_with", None)
  devices = self._get_devices_from(colocate_with)

  # Only used in error messages. Read with `.get` and formatted with `%s` so
  # that a missing or `None` name cannot replace the intended ValueError with
  # a KeyError or TypeError.
  var_name = kwargs.get("name")

  # Get synchronization value
  synchronization = kwargs.get(
      "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not "
        "supported with `Mirrored` distribution strategy. Please"
        " change the `synchronization` for variable: %s" % (var_name,))
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    # `%s` rather than `+`: the offending mode need not be a `str` (e.g. an
    # enum member), and concatenating it would raise TypeError instead of
    # this ValueError.
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, var_name))

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in [
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN
  ]:
    raise ValueError(
        "Invalid variable aggregation mode: %s for variable: %s" %
        (aggregation, var_name))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = index[devices[0]].name.split(":")[0]
          # We append a / to variable names created on towers with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          # Initialize replicas with the same value:
          if context.executing_eagerly():
            kwargs["initial_value"] = array_ops.identity(
                index[devices[0]].value())
          else:
            # `device=d` binds the current device at definition time, so the
            # callable does not see later values of the loop variable.
            def initial_value_fn(device=d):
              with ops.device(device):
                return array_ops.identity(
                    index[devices[0]].initial_value)
            kwargs["initial_value"] = initial_value_fn
        with context.context().device_policy(
            context.DEVICE_PLACEMENT_SILENT):
          v = next_creator(*args, **kwargs)
        assert not isinstance(v, values.DistributedVariable)
        index[d] = v

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  return result