def var_creator(*args, **kwargs):
                """Create an AggregatingVariable and fix up collections."""
                # Record what collections this variable should be added to.
                collections = kwargs.pop("collections", None)
                if collections is None:
                    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
                kwargs["collections"] = []

                # Create and wrap the variable.
                v = next_creator(*args, **kwargs)
                wrapped = values.AggregatingVariable(
                    self._container_strategy(), v, aggregation)

                # Add the wrapped variable to the requested collections.
                # The handling of eager mode and the global step matches
                # ResourceVariable._init_from_args().
                if not context.executing_eagerly():
                    g = ops.get_default_graph()
                    # If "trainable" is True, next_creator() will add the contained
                    # variable to the TRAINABLE_VARIABLES collection, so we manually
                    # remove it and replace with the wrapper. We can't set "trainable"
                    # to False for next_creator() since that causes functions like
                    # implicit_gradients to skip those variables.
                    if kwargs.get("trainable", True):
                        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
                        l = g.get_collection_ref(
                            ops.GraphKeys.TRAINABLE_VARIABLES)
                        if v in l:
                            l.remove(v)
                    g.add_to_collections(collections, wrapped)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

                return wrapped
예제 #2
0
    def test_supports_distributed_variables(self):
        mirrored = distributed_values.MirroredVariable(
            None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
        tpu = tpu_values.TPUMirroredVariable(strategy=None,
                                             values=[variables.Variable(42.)],
                                             aggregation=None)
        aggregating = distributed_values.AggregatingVariable(
            strategy=None, v=variables.Variable(1.), aggregation=None)

        m = module.Module()
        m.a = mirrored
        m.b = tpu
        m.c = aggregating
        self.assertEqual(m.variables, (mirrored, tpu, aggregating))