def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  colocate_with = kwargs.pop("colocate_with", None)
  devices = self._get_devices_from(colocate_with)

  tower_local = kwargs.pop("tower_local_reduce_method", None)
  if tower_local is not None:
    kwargs["trainable"] = False

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = index[devices[0]].name.split(":")[0]
          kwargs["name"] = "%s/replica_%d" % (var0name, i)
          # Initialize replicas with the same value:
          if context.executing_eagerly():
            initial_value = index[devices[0]].value()
          else:
            initial_value = index[devices[0]].initial_value
          kwargs["initial_value"] = array_ops.identity(initial_value)
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          v = next_creator(*args, **kwargs)
        assert not isinstance(v, values.DistributedVariable)
        index[d] = v

    if tower_local is None:
      result = values.MirroredVariable(index, index[devices[0]])
    else:
      result = values.TowerLocalVariable(index, index[devices[0]],
                                         tower_local)

  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  return result
def testVariableOnAnotherDevice(self):
  v = variable_scope.get_variable(
      name="v", initializer=[1.], use_resource=True)
  index = {"/job:foo/device:CPU:0": v}
  mirrored = values.MirroredVariable(index, v)

  # assertEqual replaces the deprecated assertEquals alias.
  self.assertEqual(v.name, mirrored.name)
  self.assertEqual(v.dtype, mirrored.dtype)
  self.assertEqual(v.shape, mirrored.shape)
def _make_mirrored():
  v = []
  index = {}
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))
      index[d] = v[-1]
  mirrored = values.MirroredVariable(index, v[0])
  return v, devices, mirrored
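# A minimal companion sketch (not from the original test file; the test name
# is hypothetical) showing how _make_mirrored() might be exercised: the
# primary variable v[0] supplies the wrapper's name, dtype and shape.
def testMakeMirroredHelper(self):
  v, devices, mirrored = _make_mirrored()
  self.assertEqual(len(devices), len(v))
  self.assertEqual(v[0].name, mirrored.name)
  self.assertEqual(v[0].dtype, mirrored.dtype)
  self.assertEqual(v[0].shape, mirrored.shape)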
def testFetchAMirroredVariable(self):
  if context.num_gpus() < 1 or context.executing_eagerly():
    self.skipTest("A GPU is not available for this test or it's eager mode.")

  with self.test_session(
      graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
          ["/device:GPU:0"]).scope():
    with ops.device("/device:GPU:0"):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
    mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
    sess.run(variables_lib.global_variables_initializer())
    sess.run({"complicated": mirrored})
def testOneDevice(self):
  result = values.regroup({_device_str(0): _nested_value("1")})
  # On one device regroup() and select_device() are basically identity.
  self.assertEqual(_nested_value("1"), result)
  self.assertEqual(_nested_value("1"),
                   values.select_device(_device_str(0), result))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
    index = {d: v}
  mirrored = values.MirroredVariable(index, v)
  result = values.regroup(index)
  self.assertIs(mirrored, result)
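# A hedged sketch (hypothetical test name) of the multi-device counterpart:
# regroup() merges the per-device nests into one distributed nest, and
# select_device() projects that nest back onto a single device's values.
def testTwoDevices(self):
  result = values.regroup({_device_str(0): _nested_value("1"),
                           _device_str(1): _nested_value("2")})
  self.assertEqual(_nested_value("1"),
                   values.select_device(_device_str(0), result))
  self.assertEqual(_nested_value("2"),
                   values.select_device(_device_str(1), result))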
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # Get synchronization value
  synchronization = kwargs.get("synchronization",
                               variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError("`NONE` variable synchronization mode is not "
                     "supported with `Mirrored` distribution strategy. Please"
                     " change the `synchronization` for variable: " +
                     kwargs["name"])
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are replica local.
    is_replica_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_replica_local = False
  else:
    raise ValueError("Invalid variable synchronization mode: " +
                     synchronization + " for variable: " + kwargs["name"])

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in (
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: " + aggregation +
                     " for variable: " + kwargs["name"])

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = real_mirrored_creator(devices, *args, **kwargs)

    if is_replica_local:
      result = values.ReplicaLocalVariable(
          index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        if v in l:
          l.remove(v)
    g.add_to_collections(collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
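# A hedged, graph-mode-only sketch of how a strategy's _create_variable could
# wire up the helper above: the `real_mirrored_creator` argument is a closure
# that builds one variable per device via next_creator, while
# _create_mirrored_variable handles the synchronization/aggregation checks
# and collection bookkeeping. Names here are illustrative, not the exact
# library code (in particular, the eager initial-value branch is omitted).
def _create_variable(self, next_creator, *args, **kwargs):
  devices = self._get_devices_from(kwargs.pop("colocate_with", None))

  def _real_mirrored_creator(devices, *args, **kwargs):
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Reuse the first replica's name prefix and initial value, as in
          # the creators above (graph mode assumed).
          var0 = index[devices[0]]
          kwargs["name"] = "%s/replica_%d/" % (var0.name.split(":")[0], i)
          kwargs["initial_value"] = array_ops.identity(var0.initial_value)
        index[d] = next_creator(*args, **kwargs)
    return index

  return _create_mirrored_variable(
      devices, _real_mirrored_creator, *args, **kwargs)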
def _create_variable(self, next_creator, *args, **kwargs):
  """Create a mirrored variable. See `DistributionStrategy.scope`."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  colocate_with = kwargs.pop("colocate_with", None)
  devices = self._get_devices_from(colocate_with)

  # Get synchronization value
  synchronization = kwargs.get("synchronization",
                               variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError("`NONE` variable synchronization mode is not "
                     "supported with `Mirrored` distribution strategy. Please"
                     " change the `synchronization` for variable: " +
                     kwargs["name"])
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    raise ValueError("Invalid variable synchronization mode: " +
                     synchronization + " for variable: " + kwargs["name"])

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in [a for a in variable_scope.VariableAggregation]:
    raise ValueError("Invalid variable aggregation mode: " + aggregation +
                     " for variable: " + kwargs["name"])

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = {}
    for i, d in enumerate(devices):
      with ops.device(d):
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = index[devices[0]].name.split(":")[0]
          # We append a / to variable names created on towers with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          # Initialize replicas with the same value:
          if context.executing_eagerly():
            kwargs["initial_value"] = array_ops.identity(
                index[devices[0]].value())
          else:
            def initial_value_fn(device=d):
              with ops.device(device):
                return array_ops.identity(index[devices[0]].initial_value)
            kwargs["initial_value"] = initial_value_fn
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          v = next_creator(*args, **kwargs)
        assert not isinstance(v, values.DistributedVariable)
        index[d] = v

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]],
                                         aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  return result
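# A small usage sketch (hypothetical test name) of the creator above:
# variables created inside MirroredStrategy(...).scope() are routed through
# _create_variable and come back as a single MirroredVariable wrapping one
# replica per device.
def testCreateMirroredVariableInScope(self):
  if context.num_gpus() < 1:
    self.skipTest("A GPU is not available for this test.")
  devices = ["/device:GPU:0", "/device:CPU:0"]
  with mirrored_strategy.MirroredStrategy(devices).scope():
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
  self.assertIsInstance(v, values.MirroredVariable)
  self.assertEqual("v:0", v.name)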