Example #1
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a mirrored variable. See `DistributionStrategy.scope`."""
        # Figure out what collections this variable should be added to.
        # We'll add the MirroredVariable to those collections instead.
        collections = kwargs.pop("collections", None)
        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        colocate_with = kwargs.pop("colocate_with", None)
        devices = self._get_devices_from(colocate_with)

        tower_local = kwargs.pop("tower_local_reduce_method", None)
        if tower_local is not None:
            kwargs["trainable"] = False

        # TODO(josh11b,apassos): It would be better if variable initialization
        # was never recorded on the tape instead of having to do this manually
        # here.
        with tape.stop_recording():
            index = {}
            for i, d in enumerate(devices):
                with ops.device(d):
                    if i > 0:
                        # Give replicas meaningful distinct names:
                        var0name = index[devices[0]].name.split(":")[0]
                        kwargs["name"] = "%s/replica_%d" % (var0name, i)
                        # Initialize replicas with the same value:
                        if context.executing_eagerly():
                            initial_value = index[devices[0]].value()
                        else:
                            initial_value = index[devices[0]].initial_value
                        kwargs["initial_value"] = array_ops.identity(
                            initial_value)
                    with context.context().device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        v = next_creator(*args, **kwargs)
                    assert not isinstance(v, values.DistributedVariable)
                    index[d] = v

            if tower_local is None:
                result = values.MirroredVariable(index, index[devices[0]])
            else:
                result = values.TowerLocalVariable(index, index[devices[0]],
                                                   tower_local)

        if not context.executing_eagerly():
            g = ops.get_default_graph()
            # If "trainable" is True, next_creator() will add the member variables
            # to the TRAINABLE_VARIABLES collection, so we manually remove
            # them and replace with the MirroredVariable. We can't set
            # "trainable" to False for next_creator() since that causes functions
            # like implicit_gradients to skip those variables.
            if kwargs.get("trainable", True):
                collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
                l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
                for v in index.values():
                    l.remove(v)
            g.add_to_collections(collections, result)
        return result
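Example #1 works because `DistributionStrategy.scope()` pushes `_create_variable` onto TensorFlow's variable-creator stack, so every variable created inside the scope is routed through it before the default creator runs. A minimal sketch of that mechanism, assuming the TF 1.x `tf.variable_creator_scope` API (the `logging_creator` name is ours, not part of the snippet above):

import tensorflow as tf  # assumes a TF 1.x release with variable creators

def logging_creator(next_creator, *args, **kwargs):
    # A creator may inspect or rewrite kwargs, then delegate to the next
    # creator on the stack (ultimately the default variable constructor),
    # just as _create_variable delegates via next_creator above.
    print("creating variable:", kwargs.get("name"))
    return next_creator(*args, **kwargs)

with tf.variable_creator_scope(logging_creator):
    v = tf.get_variable("w", initializer=1., use_resource=True)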
Example #2
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    index = {"/job:foo/device:CPU:0": v}
    mirrored = values.MirroredVariable(index, v)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)
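The point of this test is that `MirroredVariable` forwards metadata lookups (`name`, `dtype`, `shape`) to its primary component, even when that component lives on another device. A simplified sketch of that delegation (not the real `values.MirroredVariable` implementation, just the pattern it follows):

class _MirroredLike(object):
    """Toy wrapper: a device->variable index plus a primary component."""

    def __init__(self, index, primary):
        self._index = index
        self._primary = primary

    @property
    def name(self):
        return self._primary.name

    @property
    def dtype(self):
        return self._primary.dtype

    @property
    def shape(self):
        return self._primary.shape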
Example #3
def _make_mirrored():
  v = []
  index = {}
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))
      index[d] = v[-1]
  mirrored = values.MirroredVariable(index, v[0])
  return v, devices, mirrored
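Usage note for the helper above (hedged: this assumes the `DistributedValues.get(device)` accessor that `tf.contrib.distribute` exposed in this era):

v, devices, mirrored = _make_mirrored()
# One component per device, with the first device holding the primary:
assert mirrored.get(devices[0]) is v[0]
assert mirrored.get(devices[1]) is v[1]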
Example #4
  def testFetchAMirroredVariable(self):
    if context.num_gpus() < 1 or context.executing_eagerly():
      self.skipTest("No GPU is available for this test, or running in eager mode.")

    with self.test_session(
        graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
            ["/device:GPU:0"]).scope():
      with ops.device("/device:GPU:0"):
        v = variable_scope.get_variable(
            name="v", initializer=1., use_resource=True)
      mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
      sess.run(variables_lib.global_variables_initializer())
      sess.run({"complicated": mirrored})
Example #5
  def testOneDevice(self):
    result = values.regroup({_device_str(0): _nested_value("1")})
    # On one device regroup() and select_device() are basically identity.
    self.assertEqual(_nested_value("1"), result)
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), result))

    # The one exception has to do with MirroredVariables.
    d = "/device:CPU:0"
    with ops.device(d):
      v = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)
      index = {d: v}
    mirrored = values.MirroredVariable(index, v)
    result = values.regroup(index)
    self.assertIs(mirrored, result)
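This test relies on two module-level helpers that the listing does not show. Plausible reconstructions, inferred only from how the test uses them (hypothetical, not the actual test-file definitions):

def _device_str(d):
    # Index -> canonical device string, e.g. _device_str(0) == "/device:GPU:0".
    return "/device:GPU:%d" % d

def _nested_value(d):
    # An arbitrary nested tuple/list/dict structure tagged with suffix `d`,
    # so regroup()/select_device() round-trips can be compared for equality.
    return ("a" + d, ["b" + d, {"c": "c" + d, "e": "e" + d}], "f" + d)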
Example #6
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    # Get synchronization value
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: " + kwargs["name"])
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local.
        is_replica_local = True
        kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE
          or synchronization == variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_replica_local = False
    else:
        raise ValueError("Invalid variable synchronization mode: " +
                         synchronization + " for variable: " + kwargs["name"])

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        index = real_mirrored_creator(devices, *args, **kwargs)

        if is_replica_local:
            result = values.ReplicaLocalVariable(index, index[devices[0]],
                                                 aggregation)
        else:
            result = values.MirroredVariable(index, index[devices[0]],
                                             aggregation)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in index.values():
                if v in l:
                    l.remove(v)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
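Example #6 is the refactor of Examples #1 and #7: the per-device creation loop moves behind the `real_mirrored_creator` callback, which must return a `{device: variable}` index. A simplified sketch of how a strategy could wire it up (graph mode only; the naming and initial-value handling mirrors Example #7 below):

def _create_variable(self, next_creator, *args, **kwargs):
    devices = self._get_devices_from(kwargs.pop("colocate_with", None))

    def _real_mirrored_creator(devices, *args, **kwargs):
        index = {}
        for i, d in enumerate(devices):
            with ops.device(d):
                if i > 0:
                    var0 = index[devices[0]]
                    # Distinct name per replica; same initial value as the
                    # primary (eager mode would use var0.value() instead).
                    kwargs["name"] = "%s/replica_%d/" % (
                        var0.name.split(":")[0], i)
                    kwargs["initial_value"] = array_ops.identity(
                        var0.initial_value)
                index[d] = next_creator(*args, **kwargs)
        return index

    return _create_mirrored_variable(
        devices, _real_mirrored_creator, *args, **kwargs)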
Example #7
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a mirrored variable. See `DistributionStrategy.scope`."""
        # Figure out what collections this variable should be added to.
        # We'll add the MirroredVariable to those collections instead.
        collections = kwargs.pop("collections", None)
        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        colocate_with = kwargs.pop("colocate_with", None)
        devices = self._get_devices_from(colocate_with)

        # Get synchronization value
        synchronization = kwargs.get(
            "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
        if synchronization == variable_scope.VariableSynchronization.NONE:
            raise ValueError(
                "`NONE` variable synchronization mode is not "
                "supported with `Mirrored` distribution strategy. Please"
                " change the `synchronization` for variable: " +
                kwargs["name"])
        elif synchronization == variable_scope.VariableSynchronization.ON_READ:
            # Variables that are to be synced on read are tower local.
            is_tower_local = True
            kwargs["trainable"] = False
        elif (synchronization
              == variable_scope.VariableSynchronization.ON_WRITE or
              synchronization == variable_scope.VariableSynchronization.AUTO):
            # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
            is_tower_local = False
        else:
            raise ValueError("Invalid variable synchronization mode: " +
                             synchronization + " for  variable: " +
                             kwargs["name"])

        # Get aggregation value
        aggregation = kwargs.pop("aggregation",
                                 variable_scope.VariableAggregation.NONE)
        if aggregation not in list(variable_scope.VariableAggregation):
            raise ValueError("Invalid variable aggregation mode: " +
                             aggregation + " for variable: " + kwargs["name"])

        # Ignore user-specified caching device, not needed for mirrored variables.
        kwargs.pop("caching_device", None)

        # TODO(josh11b,apassos): It would be better if variable initialization
        # was never recorded on the tape instead of having to do this manually
        # here.
        with tape.stop_recording():
            index = {}
            for i, d in enumerate(devices):
                with ops.device(d):
                    if i > 0:
                        # Give replicas meaningful distinct names:
                        var0name = index[devices[0]].name.split(":")[0]
                        # We append a / to variable names created on towers with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)
                        # Initialize replicas with the same value:
                        if context.executing_eagerly():
                            kwargs["initial_value"] = array_ops.identity(
                                index[devices[0]].value())
                        else:

                            def initial_value_fn(device=d):
                                with ops.device(device):
                                    return array_ops.identity(
                                        index[devices[0]].initial_value)

                            kwargs["initial_value"] = initial_value_fn
                    with context.context().device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        v = next_creator(*args, **kwargs)
                    assert not isinstance(v, values.DistributedVariable)
                    index[d] = v

            if is_tower_local:
                result = values.TowerLocalVariable(index, index[devices[0]],
                                                   aggregation)
            else:
                result = values.MirroredVariable(index, index[devices[0]],
                                                 aggregation)

        if not context.executing_eagerly():
            g = ops.get_default_graph()
            # If "trainable" is True, next_creator() will add the member variables
            # to the TRAINABLE_VARIABLES collection, so we manually remove
            # them and replace with the MirroredVariable. We can't set
            # "trainable" to False for next_creator() since that causes functions
            # like implicit_gradients to skip those variables.
            if kwargs.get("trainable", True):
                collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
                l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
                for v in index.values():
                    l.remove(v)
            g.add_to_collections(collections, result)
        return result
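To exercise the branches above, a hedged usage sketch: inside a `MirroredStrategy` scope, the `synchronization` and `aggregation` arguments of `get_variable` select which wrapper `_create_variable` builds (this reuses the `tf.contrib.distribute`-era APIs from Example #4):

strategy = mirrored_strategy.MirroredStrategy(["/device:CPU:0"])
with strategy.scope():
    # ON_READ + SUM -> the is_tower_local branch -> TowerLocalVariable.
    counter = variable_scope.get_variable(
        name="counter", initializer=0., use_resource=True,
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)
    # The default ON_WRITE/AUTO path -> MirroredVariable.
    weight = variable_scope.get_variable(
        name="weight", initializer=1., use_resource=True)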