Beispiel #1
0
    def test_get_state_no_shape_raises(self):
        """get_state with an init but no shape/dtype must raise ValueError.

        Checked in a context created with no arguments and in one created
        with an explicit, empty state structure.
        """
        expected_error = "provide shape and dtype"
        for context_kwargs in ({}, {"state": {"~": {}}}):
            with base.new_context(**context_kwargs):
                with self.assertRaisesRegex(ValueError, expected_error):
                    base.get_state("i", init=jnp.zeros)
Beispiel #2
0
    def test_get_state_no_init_raises(self):
        """get_state for a missing entry with no init must raise ValueError.

        Checked in a context created with no arguments and in one created
        with an explicit, empty state structure.
        """
        expected_error = "set an init function"
        for context_kwargs in ({}, {"state": {"~": {}}}):
            with base.new_context(**context_kwargs):
                with self.assertRaisesRegex(ValueError, expected_error):
                    base.get_state("i")
Beispiel #3
0
  def initialize(self, shape, dtype=jnp.float32):
    """Requests the ``hidden`` and ``average`` state with a zeros initializer.

    Args:
      shape: Shape of the state entries. Passing an object that has a
        ``.shape`` attribute instead is deprecated; in that case its shape
        and dtype are used and a ``DeprecationWarning`` is emitted.
      dtype: Dtype of the state entries.
    """
    if hasattr(shape, "shape"):
      # Legacy call style: a value was passed where a shape is expected.
      legacy_value = shape
      warnings.warn(
          "Passing a value into initialize instead of a shape/dtype "
          "is deprecated. Update your code to use: "
          "`ema.initialize(v.shape, v.dtype)`.",
          category=DeprecationWarning)
      shape = legacy_value.shape
      dtype = legacy_value.dtype

    for state_name in ("hidden", "average"):
      base.get_state(state_name, shape, dtype, init=jnp.zeros)
Beispiel #4
0
 def test_difference_update_state(self):
   """Overwriting one state entry shows up in the diff as a StatePair."""
   for state_name in ("a", "b"):
     base.get_state(state_name, [], init=jnp.zeros)
   snapshot_before = stateful.internal_state()
   base.set_state("b", jnp.ones([]))
   snapshot_after = stateful.internal_state()

   delta = stateful.difference(snapshot_before, snapshot_after)

   # Only "b" was modified between the two snapshots.
   expected_state = {"~": {"a": None, "b": base.StatePair(0., 1.)}}
   self.assertEmpty(delta.params)
   self.assertEqual(delta.state, expected_state)
   self.assertIsNone(delta.rng)
Beispiel #5
0
    def test_set_then_get(self):
        """Set-then-get yields a stable initial state across context re-entry."""
        expected_initial = {"~": {"i": 1}}

        with base.new_context() as ctx:
            base.set_state("i", 1)
            base.get_state("i")
        self.assertEqual(ctx.collect_initial_state(), expected_initial)

        # Re-entering the same context must not alter the recorded initial state.
        repeats = 10
        for _ in range(repeats):
            with ctx:
                base.set_state("i", 1)
                fetched = base.get_state("i")
                self.assertEqual(fetched, 1)
            self.assertEqual(ctx.collect_initial_state(), expected_initial)
Beispiel #6
0
    def test_do_not_store(self):
        """Nothing is collected when creators and setters return DO_NOT_STORE."""

        def skip_creation(next_creator, shape, dtype, init, context):
            del next_creator, shape, dtype, init, context  # Unused.
            return base.DO_NOT_STORE

        def rebuild_value(next_getter, value, context):
            assert value is base.DO_NOT_STORE
            # Recreate the value from the original request.
            fresh = context.original_init(context.original_shape,
                                          context.original_dtype)
            return next_getter(fresh)

        def skip_update(next_setter, value, context):
            del next_setter, value, context  # Unused.
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(skip_creation, state=True):
                with base.custom_getter(rebuild_value, state=True):
                    with base.custom_setter(skip_update):
                        param = base.get_parameter("w", [], init=jnp.ones)
                        self.assertEqual(param, 1)
                        state = base.get_state("s1", [], init=jnp.ones)
                        self.assertEqual(state, 1)
                        base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
Beispiel #7
0
    def test_setter_tree(self):
        """A custom setter receives a pytree value and can replace it wholesale.

        The setter checks the context metadata it is handed (per-leaf shapes
        and dtypes, names, module) and returns a different tree, which is
        what a subsequent get_state must yield.
        """
        witness = []
        x = {"a": jnp.ones([]), "b": jnp.zeros([123])}
        # `jax.tree_map` is deprecated and absent from newer JAX releases;
        # `jax.tree_util.tree_map` is the stable spelling of the same function.
        y = jax.tree_util.tree_map(lambda leaf: leaf + 1, x)

        def my_setter(next_setter, value, ctx):
            self.assertIs(value, x)
            self.assertEqual(ctx.original_shape, {"a": (), "b": (123, )})
            self.assertEqual(ctx.original_dtype, {
                "a": jnp.float32,
                "b": jnp.float32
            })
            self.assertEqual(ctx.full_name, "~/x")
            self.assertEqual(ctx.name, "x")
            self.assertIsNone(ctx.module)
            witness.append(None)
            # The next setter is deliberately bypassed; `y` is returned as is.
            del next_setter
            return y

        with base.new_context():
            with base.custom_setter(my_setter):
                base.set_state("x", x)
                # Use a separate name so the closure variable `x` read by
                # `my_setter` is not rebound.
                stored = base.get_state("x")
                self.assertIs(stored, y)

        # The setter must actually have been invoked.
        self.assertNotEmpty(witness)
Beispiel #8
0
 def test_get_state_no_init_raises(self):
   """Reading a missing state without an init fails in both init and apply."""
   def read_state():
     return base.get_state("i")

   init_fn, apply_fn = base.transform_with_state(read_state)
   expected_error = "set an init function"
   with self.assertRaisesRegex(ValueError, expected_error):
     init_fn(None)

   # Params and state deliberately share one empty structure.
   params = state = {"~": {}}
   with self.assertRaisesRegex(ValueError, expected_error):
     apply_fn(params, state, None)
Beispiel #9
0
 def test_get_state_no_init(self):
     """State passed to apply can be read without an init and is returned equal."""
     def read_state():
         return base.get_state("i")

     _, apply_fn = transform.transform_with_state(read_state)
     for value in range(10):
         state_in = {"~": {"i": value}}
         unused_out, state_out = apply_fn({}, state_in, None)
         del unused_out  # Only the returned state is checked.
         self.assertEqual(state_in, state_out)
Beispiel #10
0
  def test_getter_types(self, params, state):
    """A custom getter fires only for the kinds it was enabled for.

    Args:
      params: Whether the getter is enabled for parameters.
      state: Whether the getter is enabled for state.
    """
    seen_names = []

    def logging_getter(next_getter, value, context):
      seen_names.append(context.full_name)
      return next_getter(value)

    with base.new_context():
      with base.custom_getter(logging_getter, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    # One log entry is expected per enabled kind.
    self.assertLen(seen_names, int(params) + int(state))
    for enabled, full_name in ((params, "~/params"), (state, "~/state")):
      if enabled:
        self.assertIn(full_name, seen_names)
Beispiel #11
0
    def __call__(self, value, update_stats=True, error_on_non_matrix=False):
        """Performs Spectral Normalization and returns the new value.

        Args:
          value: The array-like object for which you would like to perform a
            spectral normalization on. It must have at least two dimensions.
          update_stats: A boolean defaulting to True. Regardless of this arg,
            this function will return the normalized input. When
            `update_stats` is True, the ``u0`` and ``sigma`` state entries
            are also updated to reflect the input value. When `update_stats`
            is False the state is left unchanged.
          error_on_non_matrix: Spectral normalization is only defined on
            matrices. By default, inputs with more than two dimensions are
            flattened in their leading dimensions. Setting this flag to True
            will instead raise an error in that case.

        Returns:
          The input value divided by the estimate of its first singular
          value, reshaped back to the input's original shape.

        Raises:
          ValueError: If `value` has fewer than two dimensions (scalar or
            vector), or if `error_on_non_matrix` is True and `value` has more
            than two dimensions.
        """
        value = jnp.asarray(value)
        value_shape = value.shape

        # Scalars and vectors are rejected outright.
        if value.ndim <= 1:
            raise ValueError("Spectral normalization is not well defined for "
                             "scalar or vector inputs.")
        # Higher-order tensors are either rejected or flattened to a matrix.
        elif value.ndim > 2:
            if error_on_non_matrix:
                raise ValueError(
                    "Input is {}D but error_on_non_matrix is True".format(
                        value.ndim))
            else:
                value = jnp.reshape(value, [-1, value.shape[-1]])

        u0 = base.get_state("u0",
                            shape=[1, value.shape[-1]],
                            dtype=value.dtype,
                            init=initializers.RandomNormal())

        # Power iteration for the weight's singular value.
        # NOTE(review): `v0` is only bound inside this loop, so this presumably
        # relies on self._n_steps >= 1 - confirm against the constructor.
        for _ in range(self._n_steps):
            v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])),
                               eps=self._eps)
            u0 = _l2_normalize(jnp.matmul(v0, value), eps=self._eps)

        # The iterated vectors are excluded from differentiation.
        u0 = jax.lax.stop_gradient(u0)
        v0 = jax.lax.stop_gradient(v0)

        sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

        value /= sigma
        # Undo any flattening applied above.
        value_bar = value.reshape(value_shape)

        if update_stats:
            base.set_state("u0", u0)
            base.set_state("sigma", sigma)
        return value_bar
Beispiel #12
0
    def test_stateful(self):
        """Repeated get/set in one context tracks both initial and final state."""
        num_increments = 10
        with base.new_context() as ctx:
            for _ in range(num_increments):
                current = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", current + 1)

        initial_state = ctx.collect_initial_state()
        final_state = ctx.collect_state()
        self.assertEqual(initial_state, {"~": {"count": 0}})
        self.assertEqual(final_state, {"~": {"count": num_increments}})
Beispiel #13
0
 def test_get_state_no_shape_raises(self):
     """An init without shape/dtype fails in both the init and apply functions."""
     def read_state():
         return base.get_state("i", init=jnp.zeros)

     init_fn, apply_fn = transform.transform_with_state(read_state)
     expected_error = "provide shape and dtype"
     with self.assertRaisesRegex(ValueError, expected_error):
         init_fn(None)

     # Params and state deliberately share one empty structure.
     params = state = {"~": {}}
     with self.assertRaisesRegex(ValueError, expected_error):
         apply_fn(params, state, None)
    def __call__(self, value, update_stats=True):
        """Folds ``value`` into the moving average and returns the average.

        Args:
          value: The array-like object to fold into the exponential moving
            average.
          update_stats: A Boolean. When True, the ``counter``, ``hidden`` and
            ``average`` state entries are written back; when False the state
            is left unchanged.

        Returns:
          The exponentially weighted average of the input value.
        """
        if not isinstance(value, jnp.ndarray):
            value = jnp.asarray(value)
        value_dtype = value.dtype

        # The counter is initialised to minus the warmup length.
        counter = base.get_state(
            "counter", (), jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        counter += 1

        decay = jax.lax.convert_element_type(self._decay, value_dtype)
        if self._warmup_length > 0:
            # A decay of zero is selected while the counter is not positive.
            decay = self._cond(counter <= 0, 0.0, decay, value_dtype)

        unit = jnp.ones([], value_dtype)
        previous_hidden = base.get_state(
            "hidden", value.shape, value_dtype, init=jnp.zeros)
        hidden = previous_hidden * decay + value * (unit - decay)

        if self._zero_debias:
            average = hidden / (unit - jnp.power(decay, counter))
        else:
            average = hidden

        if update_stats:
            for state_name, new_value in (("counter", counter),
                                          ("hidden", hidden),
                                          ("average", average)):
                base.set_state(state_name, new_value)

        return average
Beispiel #15
0
    def __call__(self, value, update_stats=True):
        """Folds ``value`` into the moving average and returns the average.

        Args:
          value: The array-like object to fold into the exponential moving
            average.
          update_stats: A Boolean. When True, the ``counter``, ``hidden`` and
            ``average`` state entries are written back; when False the state
            is left unchanged.

        Returns:
          The exponentially weighted average of the input value.
        """
        value = jnp.asarray(value)  # Guarantees `.shape` and `.dtype`.
        value_dtype = value.dtype

        # Read existing state; the counter starts at minus the warmup length.
        counter = base.get_state(
            "counter",
            shape=(),
            dtype=jnp.int32,
            init=initializers.Constant(-self._warmup_length)) + 1
        previous_hidden = base.get_state(
            "hidden", shape=value.shape, dtype=value_dtype, init=jnp.zeros)

        base_decay = jnp.asarray(self._decay).astype(value_dtype)
        # A decay of zero is selected while the counter is not positive.
        decay = self._cond(
            jnp.less_equal(counter, 0), 0.0, base_decay, value_dtype)
        hidden = previous_hidden * decay + value * (1 - decay)

        average = (hidden / (1. - jnp.power(decay, counter))
                   if self._zero_debias else hidden)

        if update_stats:
            for state_name, new_value in (("counter", counter),
                                          ("hidden", hidden),
                                          ("average", average)):
                base.set_state(state_name, new_value)
        return average
Beispiel #16
0
 def maybe_initialize(self, shape, dtype):
     """Requests the ``counter``, ``hidden`` and ``average`` state entries.

     Args:
       shape: Shape used for the ``hidden`` and ``average`` entries.
       dtype: Dtype used for the ``hidden`` and ``average`` entries.
     """
     # The counter is a scalar int32 starting at minus the warmup length.
     counter_init = initializers.Constant(-self._warmup_length)
     base.get_state("counter", (), jnp.int32, init=counter_init)
     for state_name in ("hidden", "average"):
         base.get_state(state_name, shape, dtype, init=jnp.zeros)
Beispiel #17
0
        def f():
            """Returns ``(init, (add, sub))`` where both branches re-read 'w' and 's'."""
            w = base.get_parameter('w', [], init=jnp.zeros)
            s = base.get_state('s', [], init=jnp.zeros)

            def init():
                return None

            def reread():
                # State is requested before the parameter, in that order.
                state_again = base.get_state('s', [], init=jnp.zeros)
                param_again = base.get_parameter('w', [], init=jnp.zeros)
                return param_again, state_again

            def add():
                w_add, s_add = reread()
                return w, w_add, s, s_add

            def sub():
                w_sub, s_sub = reread()
                return w, w_sub, s, s_sub

            return init, (add, sub)
Beispiel #18
0
  def __init__(self,
               embedding_dim,
               num_embeddings,
               commitment_cost,
               decay,
               epsilon: float = 1e-5,
               dtype: DType = jnp.float32,
               name: str = None):
    """Initializes a VQ-VAE EMA module.

    Args:
      embedding_dim: integer representing the dimensionality of the tensors in
        the quantized space. Inputs to the modules must be in this format as
        well.
      num_embeddings: integer, the number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      decay: float between 0 and 1, controls the speed of the Exponential Moving
        Averages.
      epsilon: small constant to aid numerical stability, default 1e-5.
      dtype: dtype for the embeddings variable, defaults to jnp.float32.
      name: name of the module.

    Raises:
      ValueError: If `decay` is outside the range [0, 1].
    """
    super().__init__(name=name)
    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    if not 0 <= decay <= 1:
      raise ValueError('decay must be in range [0, 1]')
    self.decay = decay
    self.commitment_cost = commitment_cost
    self.epsilon = epsilon

    # The embeddings are held as state (not parameters), one column per
    # embedding vector.
    embedding_shape = [embedding_dim, num_embeddings]
    initializer = initializers.VarianceScaling(distribution='uniform')
    embeddings = base.get_state('embeddings', embedding_shape, dtype,
                                init=initializer)

    # NOTE(review): both `initialize` calls below pass an array rather than a
    # (shape, dtype) pair - confirm this matches the signature of the
    # `ExponentialMovingAverage.initialize` version in use.
    self.ema_cluster_size = moving_averages.ExponentialMovingAverage(
        decay=self.decay, name='ema_cluster_size')
    self.ema_cluster_size.initialize(jnp.zeros([num_embeddings], dtype=dtype))

    self.ema_dw = moving_averages.ExponentialMovingAverage(
        decay=self.decay, name='ema_dw')
    self.ema_dw.initialize(embeddings)
Beispiel #19
0
    def test_setter_array(self):
        """A custom setter sees the array being stored and may substitute another."""
        calls = []
        x = jnp.ones([])
        y = x + 1

        def my_setter(next_setter, value, context):
            del next_setter  # Bypassed: the replacement is returned directly.
            self.assertIs(value, x)
            self.assertEqual(context.original_shape, value.shape)
            self.assertEqual(context.original_dtype, value.dtype)
            self.assertEqual(context.full_name, "~/x")
            self.assertEqual(context.name, "x")
            self.assertIsNone(context.module)
            calls.append(None)
            return y

        with base.new_context():
            with base.custom_setter(my_setter):
                base.set_state("x", x)
                stored = base.get_state("x")
                self.assertIs(stored, y)

        # The setter must actually have been invoked.
        self.assertNotEmpty(calls)
Beispiel #20
0
 def average(self):
     """Returns the value stored under the ``average`` state entry."""
     current_average = base.get_state("average")
     return current_average
Beispiel #21
0
 def __call__(self):
     """Increments the ``count`` state ten times.

     Returns:
       The value read on the final iteration, i.e. before the last increment
       was written back.
     """
     num_updates = 10
     for _ in range(num_updates):
         count = base.get_state("count", (), jnp.int32, jnp.zeros)
         incremented = count + 1
         base.set_state("count", incremented)
     return count
Beispiel #22
0
 def __call__(self):
     """Returns the ``w`` state, requested with shape ``[]`` and a zeros init."""
     w_state = base.get_state("w", [], init=jnp.zeros)
     return w_state
Beispiel #23
0
 def sigma(self):
   """Returns the ``sigma`` state, requested as a scalar with a ones init."""
   scalar_shape = ()
   return base.get_state("sigma", shape=scalar_shape, init=jnp.ones)
Beispiel #24
0
 def u0(self):
   """Returns the value stored under the ``u0`` state entry."""
   stored_u0 = base.get_state("u0")
   return stored_u0
Beispiel #25
0
 def net():
     """Stores ``1`` under the state entry ``i`` and returns what is read back."""
     base.set_state("i", 1)
     stored = base.get_state("i")
     return stored
Beispiel #26
0
 def f():
   """Requests a parameter and then a state entry that share the name ``w``."""
   for fetch in (base.get_parameter, base.get_state):
     fetch("w", [], init=jnp.zeros)
Beispiel #27
0
 def test_lift_raises_with_state(self):
     """Calling a plain ``lift`` of a stateful init raises a ValueError."""
     def make_state():
         return base.get_state("w", [], init=jnp.zeros)

     f = transform.transform_with_state(make_state)
     lifted = lift.lift(f.init)  # pytype: disable=wrong-arg-types
     expected_error = "use.*lift_with_state"
     with self.assertRaisesRegex(ValueError, expected_error):
         lifted(None)
Beispiel #28
0
 def inner():
     """Increments the ``w`` state by one and returns the new value."""
     value = base.get_state("w", [], init=jnp.zeros)
     value += 1
     base.set_state("w", value)
     return value
Beispiel #29
0
 def count(self):
   """Returns the ``count`` state, requested with shape ``[]`` and a zeros init."""
   counter_state = base.get_state("count", [], init=jnp.zeros)
   return counter_state
Beispiel #30
0
 def embeddings(self):
   """Returns the value stored under the ``embeddings`` state entry."""
   stored_embeddings = base.get_state('embeddings')
   return stored_embeddings