Exemplo n.º 1
0
  def testVariableOnAnotherDevice(self):
    # A MirroredVariable wrapping a single component variable must report
    # the same name, dtype and shape as that component.
    component = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mean = variable_scope.VariableAggregation.MEAN
    mirrored = values_lib.MirroredVariable(None, (component,), mean)

    for attr in ("name", "dtype", "shape"):
      self.assertEqual(getattr(component, attr), getattr(mirrored, attr))
Exemplo n.º 2
0
 def testFetchAMirroredVariable(self, distribution):
   # Builds a one-component MirroredVariable on GPU:0 inside the scope of
   # `distribution` and fetches it through a dict handed to `sess.run`.
   # NOTE(review): `distribution` is presumably injected by a parameterized
   # test decorator outside this view -- confirm.
   gpu = "/device:GPU:0"
   with self.session(graph=ops.Graph()) as sess, distribution.scope():
     with ops.device(gpu):
       component = variable_scope.get_variable(
           name="v", initializer=1., use_resource=True)
     device_map = values.ReplicaDeviceMap((gpu,))
     mirrored = values.MirroredVariable(
         distribution, device_map, (component,),
         variable_scope.VariableAggregation.MEAN)
     init_op = variables_lib.global_variables_initializer()
     sess.run(init_op)
     # The fetch is a dict whose value is the MirroredVariable itself.
     sess.run({"complicated": mirrored})
Exemplo n.º 3
0
  def testVariableOnAnotherDevice(self):
    # The MirroredVariable is built with a device map that names a CPU in
    # "/job:foo"; its name, dtype and shape must still equal the component's.
    component = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values.MirroredVariable(
        None,
        values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)),
        (component,),
        variable_scope.VariableAggregation.MEAN)

    self.assertEqual(component.name, mirrored.name)
    self.assertEqual(component.dtype, mirrored.dtype)
    self.assertEqual(component.shape, mirrored.shape)
Exemplo n.º 4
0
  def testVariableOnAnotherDevice(self):
    # The MirroredVariable is built from a {device: variable} index keyed by
    # a CPU in "/job:foo", with the same variable passed as second argument;
    # name, dtype and shape must equal the component's.
    primary = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values.MirroredVariable(
        {"/job:foo/device:CPU:0": primary}, primary,
        variable_scope.VariableAggregation.MEAN)

    for attr in ("name", "dtype", "shape"):
      self.assertEqual(getattr(primary, attr), getattr(mirrored, attr))
Exemplo n.º 5
0
def _make_mirrored():
  """Builds a two-component MirroredVariable with SUM aggregation.

  One resource variable is created under each of "/device:GPU:0" (named "v",
  initialized to 1.) and "/device:CPU:0" (named "v/replica", initialized
  to 2.).

  Returns:
    A tuple `(components, device_map, mirrored)`: the list of component
    variables, the `ReplicaDeviceMap` over both devices, and the
    `MirroredVariable` wrapping the components.
  """
  devices = ["/device:GPU:0", "/device:CPU:0"]
  specs = (("v", 1.), ("v/replica", 2.))
  components = []
  for device, (var_name, initial_value) in zip(devices, specs):
    with ops.device(device):
      component = variable_scope.get_variable(
          name=var_name, initializer=initial_value, use_resource=True)
      components.append(component)
  device_map = values.ReplicaDeviceMap(devices)
  mirrored = values.MirroredVariable(
      None, device_map, components, variable_scope.VariableAggregation.SUM)
  return components, device_map, mirrored
Exemplo n.º 6
0
def _make_mirrored():
  """Builds a two-component MirroredVariable with SUM aggregation.

  One resource variable is created under each of "/device:GPU:0" (named "v",
  initialized to 1.) and "/device:CPU:0" (named "v/replica", initialized
  to 2.). The MirroredVariable is constructed from a {device: variable}
  index, with the first component passed as the second argument.

  Returns:
    A tuple `(components, devices, mirrored)`: the list of component
    variables, the list of device strings, and the `MirroredVariable`.
  """
  devices = ["/device:GPU:0", "/device:CPU:0"]
  specs = (("v", 1.), ("v/replica", 2.))
  components = []
  by_device = {}
  for device, (var_name, initial_value) in zip(devices, specs):
    with ops.device(device):
      component = variable_scope.get_variable(
          name=var_name, initializer=initial_value, use_resource=True)
      components.append(component)
      by_device[device] = component
  mirrored = values.MirroredVariable(
      by_device, components[0], variable_scope.VariableAggregation.SUM)
  return components, devices, mirrored
Exemplo n.º 7
0
  def test_supports_distributed_variables(self):
    # A Module holding a MirroredVariable, a TPUMirroredVariable and an
    # AggregatingVariable as attributes a, b, c must list exactly those three
    # wrappers, in that order, in `variables`.
    mirrored_var = distributed_values.MirroredVariable(
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
    tpu_var = tpu_values.TPUMirroredVariable(
        strategy=None, values=[variables.Variable(42.)], aggregation=None)
    aggregating_var = ps_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)
    expected = (mirrored_var, tpu_var, aggregating_var)

    holder = module.Module()
    for attr, var in zip("abc", expected):
      setattr(holder, attr, var)
    self.assertEqual(holder.variables, expected)
Exemplo n.º 8
0
    def testFetchAMirroredVariable(self):
        # Fetches a one-component MirroredVariable through a dict handed to
        # `sess.run`. Skipped when no GPU is reported or when running eagerly.
        if context.num_gpus() < 1 or context.executing_eagerly():
            self.skipTest(
                "A GPU is not available for this test or it's eager mode.")

        gpu = "/device:GPU:0"
        with self.session(graph=ops.Graph()) as sess:
            # The strategy is constructed only after the session context has
            # been entered, as in a combined `with a, b:` statement.
            with mirrored_strategy.MirroredStrategy([gpu]).scope():
                with ops.device(gpu):
                    component = variable_scope.get_variable(
                        name="v", initializer=1., use_resource=True)
                mirrored = values.MirroredVariable(
                    {gpu: component}, component,
                    variable_scope.VariableAggregation.MEAN)
                init_op = variables_lib.global_variables_initializer()
                sess.run(init_op)
                sess.run({"complicated": mirrored})
Exemplo n.º 9
0
  def testOneDevice(self):
    # With a single device, regroup() and select_device() should hand back
    # the nested value unchanged.
    regrouped = values.regroup({_device_str(0): _nested_value("1")})
    self.assertEqual(_nested_value("1"), regrouped)
    expected = _nested_value("1")
    selected = values.select_device(_device_str(0), regrouped)
    self.assertEqual(expected, selected)

    # Once a MirroredVariable has been built over the component, regrouping
    # the same index must return that very MirroredVariable object.
    cpu = "/device:CPU:0"
    with ops.device(cpu):
      component = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      by_device = {cpu: component}
    mirrored = values.MirroredVariable(
        by_device, component, variable_scope.VariableAggregation.SUM)
    self.assertIs(mirrored, values.regroup(by_device))
Exemplo n.º 10
0
    def test_supports_distributed_variables(self):
        # A Module holding a MirroredVariable, a TPUMirroredVariable and an
        # AggregatingVariable as attributes a, b, c must list exactly those
        # three wrappers, in that order, in `variables`.
        cpu_map = distributed_values.SingleDeviceMap("/CPU:0")
        mirrored_var = distributed_values.MirroredVariable(
            None, cpu_map, [variables.Variable(1.)],
            variables.VariableAggregation.SUM)
        tpu_var = distributed_values.TPUMirroredVariable(
            strategy=None,
            device_map=cpu_map,
            values=[variables.Variable(42.)],
            aggregation=None)
        aggregating_var = distributed_values.AggregatingVariable(
            strategy=None, v=variables.Variable(1.), aggregation=None)
        expected = (mirrored_var, tpu_var, aggregating_var)

        holder = module.Module()
        holder.a, holder.b, holder.c = expected
        self.assertEqual(holder.variables, expected)
Exemplo n.º 11
0
  def testOneDevice(self):
    # With a single replica, regroup() and select_replica() should hand back
    # the nested value unchanged.
    single_map = values.ReplicaDeviceMap((_device_str(0),))
    regrouped = values.regroup(single_map, (_nested_value("1"),))
    self.assertEqual(_nested_value("1"), regrouped)
    expected = _nested_value("1")
    self.assertEqual(expected, values.select_replica(0, regrouped))

    # Once a MirroredVariable has been built over the component, regrouping
    # that component must return that very MirroredVariable object.
    cpu = "/device:CPU:0"
    with ops.device(cpu):
      component = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      cpu_map = values.ReplicaDeviceMap((cpu,))
    mirrored = values.MirroredVariable(
        None, cpu_map, (component,), variable_scope.VariableAggregation.SUM)
    self.assertIs(mirrored, values.regroup(cpu_map, (component,)))
def _create_mirrored_variable(
        strategy,
        device_map,
        logical_device,
        real_mirrored_creator,
        *args,
        **kwargs):
    """Create per-device component variables and wrap them in one object.

    Args:
      strategy: Passed through unchanged to the wrapper's constructor.
      device_map: Object providing `logical_to_actual_devices`; also passed
        through to the wrapper's constructor.
      logical_device: Logical device used to look up the actual devices;
        also passed to the wrapper as `logical_device`.
      real_mirrored_creator: Callable invoked as
        `real_mirrored_creator(devices, *args, **kwargs)`. Its return value
        is iterated as the collection of component variables.
      *args: Forwarded to `real_mirrored_creator`.
      **kwargs: Forwarded to `real_mirrored_creator` after these adjustments:
        "collections" is replaced by an empty list, "aggregation" and
        "caching_device" are removed, and "trainable" is forced to False
        when synchronization is `ON_READ`.

    Returns:
      A `values.SyncOnReadVariable` when synchronization is `ON_READ`,
      otherwise a `values.MirroredVariable`.

    Raises:
      ValueError: If synchronization is `NONE` or not a recognized mode, or
        if aggregation is not one of `NONE`, `SUM`, `MEAN` or
        `ONLY_FIRST_REPLICA`.
    """
    # Only used in error messages. Looked up with .get() and formatted with
    # %s so that a missing or None name still yields the intended ValueError
    # rather than a KeyError/TypeError.
    name = kwargs.get("name")

    # Figure out what collections this variable should be added to.
    # We'll add the wrapper to those collections instead of the components.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    # Get synchronization value.
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: %s" % (name,))
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local.
        is_sync_on_read = True
        kwargs["trainable"] = False
    elif synchronization in (
            variable_scope.VariableSynchronization.ON_WRITE,
            variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_sync_on_read = False
    else:
        raise ValueError(
            "Invalid variable synchronization mode: %s for variable: %s" %
            (synchronization, name))

    # Get aggregation value.
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, name))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)

        wrapper_cls = (values.SyncOnReadVariable if is_sync_on_read
                       else values.MirroredVariable)
        result = wrapper_cls(strategy,
                             device_map,
                             value_list,
                             aggregation,
                             logical_device=logical_device)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the wrapper. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            # Build a new list instead of appending in place so the
            # caller's `collections` argument is never mutated.
            collections = (
                list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES])
            trainable_ref = g.get_collection_ref(
                ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                if v in trainable_ref:
                    trainable_ref.remove(v)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result