Code example #1
0
def _make_replica_local(method, strategy=None):
  """Builds two per-device variables and wraps them as a SyncOnReadVariable.

  Args:
    method: the aggregation method passed to `values.SyncOnReadVariable`.
    strategy: optional distribution strategy owning the wrapped variable.

  Returns:
    A `(component_vars, replica_local)` pair: the list of per-device
    component variables and the `SyncOnReadVariable` wrapping them.
  """
  device_map = values.ReplicaDeviceMap(_devices)
  component_vars = []
  # Pair each device with a distinct name/initializer; zip stops at the
  # shorter of the three sequences.
  for device, var_name, init_value in zip(_devices,
                                          ["v", "v/replica"],
                                          [1., 2.]):
    with ops.device(device):
      component_vars.append(
          variable_scope.get_variable(
              name=var_name, initializer=init_value, use_resource=True))
  wrapped = values.SyncOnReadVariable(
      strategy, device_map, component_vars, method)
  return component_vars, wrapped
Code example #2
0
  def testVariableOnAnotherDevice(self):
    """A SyncOnReadVariable exposes its component's name/dtype/shape."""
    component = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",))
    replica_local = values.SyncOnReadVariable(
        None, device_map, (component,),
        variable_scope.VariableAggregation.MEAN)

    # The wrapper should forward basic variable metadata unchanged.
    for attr in ("name", "dtype", "shape"):
      self.assertEqual(getattr(component, attr),
                       getattr(replica_local, attr))
    self.assertEqual(variable_scope.VariableAggregation.MEAN,
                     replica_local.aggregation)
Code example #3
0
def _create_mirrored_variable(
        strategy,
        device_map,
        logical_device,  # pylint: disable=missing-docstring
        real_mirrored_creator,
        *args,
        **kwargs):
    """Creates per-device component variables and wraps them in one variable.

    Args:
      strategy: the distribution strategy the resulting variable belongs to.
      device_map: object mapping a logical device to actual device strings
        (via `logical_to_actual_devices`).
      logical_device: key passed to `device_map.logical_to_actual_devices`
        to select the devices the components are placed on.
      real_mirrored_creator: callable invoked as
        `real_mirrored_creator(devices, *args, **kwargs)`; must return the
        list of per-device component variables.
      *args: forwarded unchanged to `real_mirrored_creator`.
      **kwargs: variable-creation keyword args. `synchronization`,
        `aggregation`, `collections` and `caching_device` are consumed or
        rewritten here before the remainder is forwarded.

    Returns:
      A `values.SyncOnReadVariable` when `synchronization` is ON_READ,
      otherwise a `values.MirroredVariable`, wrapping the component
      variables returned by `real_mirrored_creator`.

    Raises:
      ValueError: if `synchronization` is NONE or unrecognized, or if
        `aggregation` is not one of the supported modes.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    # Components are created with no collections; the wrapper is added to
    # the requested collections at the end instead.
    kwargs["collections"] = []

    # Get synchronization value
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: " + kwargs["name"])
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local.
        is_sync_on_read = True
        kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE
          or synchronization == variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_sync_on_read = False
    else:
        raise ValueError(
            "Invalid variable synchronization mode: %s for variable: %s" %
            (synchronization, kwargs["name"]))

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)

        if is_sync_on_read:
            result = values.SyncOnReadVariable(strategy,
                                               device_map,
                                               value_list,
                                               aggregation,
                                               logical_device=logical_device)
        else:
            result = values.MirroredVariable(strategy,
                                             device_map,
                                             value_list,
                                             aggregation,
                                             logical_device=logical_device)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                if v in l:
                    l.remove(v)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        # Eager mode: only the global step is tracked via collections.
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result