Example #1
def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
        strategy, device_map, logical_device, real_mirrored_creator, *args,
        **kwargs):
    # Figure out what collections this variable should be added to.
    # We'll add the TPUMirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    # TODO(jhseu): Should we have different behavior for different
    # synchronization settings?

    # Get aggregation value
    # TODO(jhseu): Support aggregation in a replica context.
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
    if aggregation not in [
            vs.VariableAggregation.NONE,
            vs.VariableAggregation.SUM,
            vs.VariableAggregation.MEAN,
            vs.VariableAggregation.ONLY_FIRST_REPLICA,
    ]:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, kwargs["name"]))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)
        result = values.TPUMirroredVariable(strategy,
                                            device_map,
                                            value_list,
                                            aggregation,
                                            logical_device=logical_device)

    if not (context.executing_eagerly() or ops.inside_function()):
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                l.remove(v)
        g.add_to_collections(var_collections, result)
    return result
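For context, _create_tpu_mirrored_variable follows TensorFlow's variable-creator hook pattern: a creator receives the keyword arguments of a variable that is about to be built, may adjust them (here: "collections", "aggregation", "caching_device"), and delegates the actual construction, with real_mirrored_creator producing one component variable per device. Below is a minimal, self-contained sketch of that hook pattern using the public tf.variable_creator_scope API; the creator name logging_creator and its print call are illustrative assumptions, not part of the original code.

import tensorflow as tf

def logging_creator(next_creator, **kwargs):
    # A creator hook sees the variable's kwargs before construction. It can
    # adjust them (as the helper above pops "collections", "aggregation",
    # and "caching_device") and then delegate to the next creator in the chain.
    kwargs.pop("caching_device", None)
    print("creating variable:", kwargs.get("name"))
    return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
    v = tf.Variable(1.0, name="demo")  # routed through logging_creator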
Example #2
    def testFetchOnFrozenGraph(self):
        with context.graph_mode():
            v = values.TPUMirroredVariable(
                strategy=None,
                device_map=values.SingleDeviceMap("/cpu:0"),
                values=[variables_lib.Variable(42.)],
                aggregation=None)

            self.evaluate(variables_lib.global_variables_initializer())
            ops.get_default_graph().finalize()
            self.assertEqual(42., self.evaluate(v))
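The key step in this test is ops.get_default_graph().finalize(): once a graph is finalized, adding any new op raises a RuntimeError, so the final self.evaluate(v) only succeeds if fetching the TPUMirroredVariable reuses ops that already exist. A small sketch of that finalize() behavior with a plain graph-mode variable, assuming the tf.compat.v1 API (the names g, init, and sess are illustrative):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    v = tf.Variable(42.)
    init = tf.global_variables_initializer()

g.finalize()  # the graph is now read-only

with tf.Session(graph=g) as sess:
    sess.run(init)
    print(sess.run(v))  # fetching an existing variable adds no ops: prints 42.0

with g.as_default():
    try:
        tf.constant(1.0)  # attempting to add a new op to a finalized graph
    except RuntimeError as err:
        print("rejected:", err)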
Example #3
  def test_supports_distributed_variables(self):
    mirrored = distributed_values.MirroredVariable(
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
    tpu = distributed_values.TPUMirroredVariable(
        strategy=None,
        values=[variables.Variable(42.)],
        aggregation=None)
    aggregating = distributed_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)

    m = module.Module()
    m.a = mirrored
    m.b = tpu
    m.c = aggregating
    self.assertEqual(m.variables, (mirrored, tpu, aggregating))
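The assertion above relies on tf.Module's attribute tracking: any variable assigned to a module attribute, including the distributed wrappers built here, shows up in module.variables, and trainable ones also appear in module.trainable_variables. A minimal sketch with ordinary variables (the class name Demo is purely illustrative):

import tensorflow as tf

class Demo(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)                   # tracked and trainable
        self.b = tf.Variable(0.0, trainable=False)  # tracked, not trainable

m = Demo()
print(m.variables)            # contains both w and b
print(m.trainable_variables)  # contains only w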