def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    raise ValueError(
        "Invalid variable aggregation mode: {} for variable: {}".format(
            aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(var_collections, result)
  return result
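# A minimal sketch (hypothetical, not part of the original source) of the
# `real_mirrored_creator` contract assumed by `_create_tpu_mirrored_variable`:
# given the actual devices for the logical device, it returns one underlying
# variable per device, with replicas past the first renamed and initialized
# from replica 0 so all copies start identical. `next_creator` stands for the
# next creator in a `tf.variable_creator_scope` chain.
#
#   def _real_mirrored_creator(devices, *args, **kwargs):
#     value_list = []
#     for i, d in enumerate(devices):
#       with ops.device(d):
#         if i > 0:
#           kwargs["name"] = "%s/replica_%d" % (
#               value_list[0].name.split(":")[0], i)
#           kwargs["initial_value"] = value_list[0].initial_value
#         value_list.append(next_creator(*args, **kwargs))
#     return value_list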
def testFetchOnFrozenGraph(self):
  with context.graph_mode():
    v = values.TPUMirroredVariable(
        strategy=None,
        device_map=values.SingleDeviceMap("/cpu:0"),
        values=[variables_lib.Variable(42.)],
        aggregation=None)
    self.evaluate(variables_lib.global_variables_initializer())
    # Fetching the variable after the graph is finalized must not add any
    # new ops to the graph.
    ops.get_default_graph().finalize()
    self.assertEqual(42., self.evaluate(v))
def test_supports_distributed_variables(self):
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = distributed_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = distributed_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)

  # Module should surface distributed values through its `variables`
  # property just like plain variables.
  m = module.Module()
  m.a = mirrored
  m.b = tpu
  m.c = aggregating
  self.assertEqual(m.variables, (mirrored, tpu, aggregating))