Example #1
0
def _make_replica_local(method):
    """Build a ReplicaLocalVariable from one resource variable per device.

    The names "v" and "v/replica" (initialized to 1. and 2.) are paired in
    order with the entries of `_devices`; each variable is created under its
    paired device scope.

    Args:
      method: aggregation value forwarded to `values.ReplicaLocalVariable`.

    Returns:
      A `(components, replica_local)` tuple: the list of created variables and
      the `ReplicaLocalVariable` wrapping them, whose primary is the first one.
    """
    name_init_pairs = zip(["v", "v/replica"], [1., 2.])
    components = []
    per_device = {}
    for device, (var_name, init_value) in zip(_devices, name_init_pairs):
        with ops.device(device):
            component = variable_scope.get_variable(
                name=var_name, initializer=init_value, use_resource=True)
        components.append(component)
        per_device[device] = component
    replica_local = values.ReplicaLocalVariable(
        per_device, components[0], method)
    return components, replica_local
Example #2
0
    def testVariableOnAnotherDevice(self):
        """Wrap a single resource variable keyed under a foo-job CPU device.

        Checks that the resulting ReplicaLocalVariable reports the wrapped
        variable's name, dtype and shape, and the MEAN aggregation it was
        constructed with.
        """
        v = variable_scope.get_variable(name="v",
                                        initializer=[1.],
                                        use_resource=True)
        # No device scope is set when creating `v`; the index key below is
        # presumably a device other than where `v` was placed (per the test
        # name) — confirm against the test's device setup.
        index = {"/job:foo/device:CPU:0": v}
        replica_local = values.ReplicaLocalVariable(
            index, v, variable_scope.VariableAggregation.MEAN)

        # Use assertEqual: the deprecated assertEquals alias was removed from
        # unittest.TestCase in Python 3.12.
        self.assertEqual(v.name, replica_local.name)
        self.assertEqual(v.dtype, replica_local.dtype)
        self.assertEqual(v.shape, replica_local.shape)
        self.assertEqual(variable_scope.VariableAggregation.MEAN,
                         replica_local.aggregation)
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
    """Create per-device component variables and wrap them for distribution.

    `real_mirrored_creator(devices, *args, **kwargs)` is called with tape
    recording stopped. It must return a mapping keyed by device (it is indexed
    with `devices[0]` and iterated via `.values()`); the entry for
    `devices[0]` becomes the primary variable of the wrapper.

    Args:
      devices: sequence of devices; `devices[0]` selects the primary component.
      real_mirrored_creator: callable that creates the component variables.
      *args: forwarded unchanged to `real_mirrored_creator`.
      **kwargs: variable-creation keyword arguments. `collections`,
        `aggregation` and `caching_device` are consumed here (the creator
        receives `collections=[]` instead). `synchronization` selects the
        wrapper type; for `ON_READ`, `trainable` is forced to False. `name` is
        also read for error messages.

    Returns:
      A `values.ReplicaLocalVariable` when `synchronization` is `ON_READ`,
      otherwise a `values.MirroredVariable`.

    Raises:
      ValueError: if `synchronization` is `NONE` or unrecognized, or if
        `aggregation` is not one of NONE, SUM, MEAN or ONLY_FIRST_REPLICA.
    """
    # The wrapper (not the per-device components) is what gets added to the
    # requested collections, so components are created with an empty list.
    # Copy the caller's list: TRAINABLE_VARIABLES may be appended below, and
    # mutating the caller's object would make it grow on every call.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    else:
        collections = list(collections)
    kwargs["collections"] = []

    # Only used in error messages. `name` may be absent, and indexing kwargs
    # would raise KeyError instead of the intended ValueError.
    var_name = kwargs.get("name")

    # Get synchronization value.
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: {}".format(var_name))
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local. Forcing
        # trainable=False is forwarded to the creator and also keeps the
        # wrapper out of TRAINABLE_VARIABLES below.
        is_replica_local = True
        kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE
          or synchronization == variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_replica_local = False
    else:
        # format() instead of `+`: `synchronization` is generally not a str,
        # and str + non-str raises TypeError instead of this ValueError.
        raise ValueError(
            "Invalid variable synchronization mode: {} for variable: {}".format(
                synchronization, var_name))

    # Get aggregation value.
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, var_name))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        index = real_mirrored_creator(devices, *args, **kwargs)

        if is_replica_local:
            result = values.ReplicaLocalVariable(index, index[devices[0]],
                                                 aggregation)
        else:
            result = values.MirroredVariable(index, index[devices[0]],
                                             aggregation)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the wrapped variable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            trainable_ref = g.get_collection_ref(
                ops.GraphKeys.TRAINABLE_VARIABLES)
            for component in index.values():
                if component in trainable_ref:
                    trainable_ref.remove(component)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result