def _make_replica_local(method, strategy=None):
  """Builds component variables and wraps them in a `SyncOnReadVariable`.

  One resource variable is created per entry of the module-level `_devices`
  (at most two, because the names and initial values below are zipped with
  it). The variables are named "v" and "v/replica" and are initialized to 1.0
  and 2.0 respectively, each under the matching `ops.device` scope.

  Args:
    method: Forwarded as the fourth positional argument of
      `values.SyncOnReadVariable` (the aggregation slot).
    strategy: Forwarded as the first argument of `values.SyncOnReadVariable`;
      defaults to None.

  Returns:
    A `(components, replica_local)` tuple: the list of per-device variables
    and the `SyncOnReadVariable` wrapping that same list.
  """
  component_names = ("v", "v/replica")
  initial_values = (1., 2.)

  device_map = values.ReplicaDeviceMap(_devices)

  components = []
  for device, var_name, initial_value in zip(
      _devices, component_names, initial_values):
    with ops.device(device):
      component = variable_scope.get_variable(
          name=var_name, initializer=initial_value, use_resource=True)
    components.append(component)

  replica_local = values.SyncOnReadVariable(
      strategy, device_map, components, method)
  return components, replica_local
def testVariableOnAnotherDevice(self):
  """Checks that a `SyncOnReadVariable` exposes its component's metadata.

  The device map names a device in another job ("/job:foo/device:CPU:0")
  while the single component variable is created without an explicit device
  scope. The wrapper must still report the component's name, dtype and shape,
  plus the aggregation it was constructed with.
  """
  mean_aggregation = variable_scope.VariableAggregation.MEAN

  component = variable_scope.get_variable(
      name="v", initializer=[1.], use_resource=True)
  remote_device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
  replica_local = values.SyncOnReadVariable(
      None, remote_device_map, (component,), mean_aggregation)

  # Metadata is expected to be forwarded from the component variable.
  for attr in ("name", "dtype", "shape"):
    self.assertEqual(getattr(component, attr), getattr(replica_local, attr))
  self.assertEqual(mean_aggregation, replica_local.aggregation)
def _create_mirrored_variable(
    strategy, device_map, logical_device,
    real_mirrored_creator, *args, **kwargs):
  """Creates a distributed variable wrapper and registers it in collections.

  The per-device component variables are produced by `real_mirrored_creator`
  and wrapped in either a `values.MirroredVariable` or a
  `values.SyncOnReadVariable`, depending on the requested synchronization.
  The wrapper (rather than its components) is then added to the requested
  graph collections.

  Args:
    strategy: Forwarded as the first argument of the wrapper's constructor.
    device_map: Object whose `logical_to_actual_devices(logical_device)`
      yields the devices handed to `real_mirrored_creator`; also forwarded to
      the wrapper's constructor.
    logical_device: Logical device passed to `device_map` and to the wrapper.
    real_mirrored_creator: Callable invoked as
      `real_mirrored_creator(devices, *args, **kwargs)`; its result is used as
      the list of component variables.
    *args: Forwarded unchanged to `real_mirrored_creator`.
    **kwargs: Forwarded to `real_mirrored_creator` after these adjustments:
      * "collections" is removed and replaced by an empty list; the removed
        value (default `[GLOBAL_VARIABLES]`) is applied to the wrapper.
      * "synchronization" is read (default `ON_WRITE`) and left in place.
      * "trainable" is forced to False when synchronization is `ON_READ`.
      * "aggregation" is removed (default `VariableAggregation.NONE`) and
        passed to the wrapper instead.
      * "caching_device" is removed and ignored.
      * "name" is only read, to build error messages.

  Returns:
    A `values.SyncOnReadVariable` when synchronization is `ON_READ`,
    otherwise a `values.MirroredVariable`.

  Raises:
    ValueError: If synchronization is `NONE` or not a recognized mode, or if
      aggregation is not one of `NONE`, `SUM`, `MEAN`, `ONLY_FIRST_REPLICA`.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the wrapper variable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # Only used in error messages. Read with .get() so that a missing or None
  # name cannot turn the intended ValueError into a KeyError/TypeError.
  var_name = kwargs.get("name")

  # Get synchronization value
  synchronization = kwargs.get(
      "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not "
        "supported with `Mirrored` distribution strategy. Please"
        " change the `synchronization` for variable: %s" % (var_name,))
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are replica local.
    is_sync_on_read = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_sync_on_read = False
  else:
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, var_name))

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in (
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError(
        "Invalid variable aggregation mode: %s for variable: %s" %
        (aggregation, var_name))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)

    wrapper_cls = (
        values.SyncOnReadVariable if is_sync_on_read
        else values.MirroredVariable)
    result = wrapper_cls(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      # Build a new list instead of appending in place so the list the caller
      # passed as `collections=` is never mutated.
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
      trainable_ref = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        if v in trainable_ref:
          trainable_ref.remove(v)
    g.add_to_collections(collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result