Example #1
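All of the snippets on this page exercise `moving_averages.ExponentialMovingAverage` and appear to assume roughly the imports below (a sketch: the `haiku._src` module paths and the `DType`/`Optional` aliases are assumptions inferred from the identifiers used, not confirmed by the page):

from typing import Any, Optional

import jax
import jax.numpy as jnp
import numpy as np

from haiku._src import base
from haiku._src import initializers
from haiku._src import moving_averages
from haiku._src import utils

DType = Any  # Alias assumed; used below only as a type annotation.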
    def __init__(self,
                 create_scale,
                 create_offset,
                 decay_rate,
                 eps=1e-5,
                 scale_init=None,
                 offset_init=None,
                 axis=None,
                 cross_replica_axis=None,
                 data_format="channels_last",
                 name=None):
        """Constructs a BatchNorm module.

        Args:
          create_scale: Whether to include a trainable scaling factor.
          create_offset: Whether to include a trainable offset.
          decay_rate: Decay rate for EMA.
          eps: Small epsilon to avoid division by zero variance. Defaults to
            1e-5, as in the paper and Sonnet.
          scale_init: Optional initializer for gain (aka scale). Can only be set
            if `create_scale=True`. By default, one.
          offset_init: Optional initializer for bias (aka offset). Can only be
            set if `create_offset=True`. By default, zero.
          axis: Which axes to reduce over. The default (None) signifies that all
            but the channel axis should be normalized. Otherwise this is a list
            of axis indices over which normalization statistics are calculated.
          cross_replica_axis: If not None, a string representing the axis name
            over which this module is being run within a `jax.pmap`. Supplying
            this argument means that batch statistics are calculated across all
            replicas on that axis.
          data_format: The data format of the input. Can be either
            `channels_first`, `channels_last`, `N...C` or `NC...`. By default it
            is `channels_last`.
          name: The module name.
        """
        super(BatchNorm, self).__init__(name=name)
        self._create_scale = create_scale
        self._create_offset = create_offset
        if not self._create_scale and scale_init is not None:
            raise ValueError("Cannot set `scale_init` if `create_scale=False`")
        self._scale_init = scale_init or jnp.ones
        if not self._create_offset and offset_init is not None:
            raise ValueError(
                "Cannot set `offset_init` if `create_offset=False`")
        self._offset_init = offset_init or jnp.zeros
        self._eps = eps

        self._cross_replica_axis = cross_replica_axis
        self._data_format = data_format
        self._channel_index = utils.get_channel_index(data_format)
        self._axis = axis

        self._mean_ema = moving_averages.ExponentialMovingAverage(
            decay_rate, name="mean_ema")
        self._var_ema = moving_averages.ExponentialMovingAverage(
            decay_rate, name="var_ema")
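A minimal usage sketch, assuming Haiku's public API (`hk.BatchNorm`, `hk.transform_with_state`); the function and variable names here are illustrative:

import haiku as hk
import jax.numpy as jnp

def forward(x, is_training):
  bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
  return bn(x, is_training=is_training)

forward_fn = hk.without_apply_rng(hk.transform_with_state(forward))
x = jnp.ones([4, 8])  # [batch, channels]; channels_last by default.
params, state = forward_fn.init(None, x, is_training=True)  # state holds the EMAs.
out, state = forward_fn.apply(params, state, x, is_training=True)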
Example #2
  def test_fast_slow_decay_without_update(self):
    ema_fast = moving_averages.ExponentialMovingAverage(0.5)
    ema_slow = moving_averages.ExponentialMovingAverage(0.8)
    # Calls with `update_stats=False` should not affect the stored statistics.
    np.testing.assert_allclose(ema_fast(1., update_stats=False),
                               ema_slow(1., update_stats=False),
                               rtol=1e-4)
    np.testing.assert_allclose(ema_fast(1.), ema_slow(1.), rtol=1e-4)
    self.assertGreater(ema_fast(2.), ema_slow(2.))

  def test_zero_decay(self):
    ema = moving_averages.ExponentialMovingAverage(0.)
    random_input = jax.random.uniform(jax.random.PRNGKey(428), shape=(2, 3, 4))

    # The ema should be equal to the input with decay=0.
    np.testing.assert_allclose(random_input[0], ema(random_input[0]))
    np.testing.assert_allclose(random_input[1], ema(random_input[1]))
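This follows from the usual EMA update: with `hidden = decay * hidden + (1 - decay) * value`, setting `decay=0.` reduces the update to `hidden = value`, so the average always equals the latest input.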
Example #4
  def __init__(self,
               embedding_dim,
               num_embeddings,
               commitment_cost,
               decay,
               epsilon: float = 1e-5,
               dtype: DType = jnp.float32,
               name: Optional[str] = None):
    """Initializes a VQ-VAE EMA module.

    Args:
      embedding_dim: integer representing the dimensionality of the tensors in
        the quantized space. Inputs to the modules must be in this format as
        well.
      num_embeddings: integer, the number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      decay: float between 0 and 1, controls the speed of the Exponential Moving
        Averages.
      epsilon: small constant to aid numerical stability, default 1e-5.
      dtype: dtype for the embeddings variable, defaults to jnp.float32.
      name: name of the module.
    """
    super(VectorQuantizerEMA, self).__init__(name=name)
    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    if not 0 <= decay <= 1:
      raise ValueError('decay must be in range [0, 1]')
    self.decay = decay
    self.commitment_cost = commitment_cost
    self.epsilon = epsilon

    embedding_shape = [embedding_dim, num_embeddings]
    initializer = initializers.VarianceScaling(distribution='uniform')
    embeddings = base.get_state('embeddings', embedding_shape, dtype,
                                init=initializer)

    self.ema_cluster_size = moving_averages.ExponentialMovingAverage(
        decay=self.decay, name='ema_cluster_size')
    self.ema_cluster_size.initialize(jnp.zeros([num_embeddings], dtype=dtype))

    self.ema_dw = moving_averages.ExponentialMovingAverage(
        decay=self.decay, name='ema_dw')
    self.ema_dw.initialize(embeddings)
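For context, these two EMAs implement the codebook update from the VQ-VAE paper (appendix A.1): `ema_cluster_size` tracks how many encoder outputs are assigned to each code, `ema_dw` tracks the sum of encoder outputs assigned to each code, and the embeddings are refreshed as m_i / N_i, with `epsilon` used to Laplace-smooth the counts. (This is the standard formulation; the actual update lives outside this snippet.)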
Example #5
    def test_initialize(self):
        ema = moving_averages.ExponentialMovingAverage(0.99)
        ema.initialize(jnp.ones([]))
        self.assertEqual(ema.average, 0.)
        ema(jnp.array(100.))

        # Matching the behavior of Sonnet 2, `initialize` only sets the value
        # to zero if the EMA has not already been initialized.
        ema.initialize(jnp.ones([]))
        self.assertNotEqual(ema.average, 0.)
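(The first assertion holds because, as in Sonnet 2, `initialize` creates zero-valued state matching the shape and dtype of its argument, regardless of the argument's values.)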
Example #6
  def test_maybe_initialize(self):
    ema = moving_averages.ExponentialMovingAverage(0.99)
    ema.maybe_initialize([], jnp.float32)
    self.assertEqual(ema.average, 0.)
    ema(jnp.array(100.))

    # Matching the behavior of Sonnet 2, `maybe_initialize` only sets the value
    # if the EMA has not already been initialized.
    ema.maybe_initialize([], jnp.float32)
    self.assertNotEqual(ema.average, 0.)

  def test_warmup(self):
    ema = moving_averages.ExponentialMovingAverage(
        0.5, warmup_length=2, zero_debias=False)
    random_input = jax.random.uniform(jax.random.PRNGKey(428), shape=(2, 3, 4))

    # The ema should be equal to the input for the first two calls.
    np.testing.assert_allclose(random_input[0], ema(random_input[0]))
    np.testing.assert_allclose(random_input[0], ema(random_input[0]))

    # After the warmup period, with decay=0.5 the result should be halfway
    # between the warmed-up average (random_input[0]) and the new input.
    np.testing.assert_allclose(
        (random_input[0] + random_input[1]) / 2, ema(random_input[1]))
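The expectations here suggest how `warmup_length` behaves: for the first `warmup_length` updates the average simply tracks the input (an effective decay of zero), and only afterwards does the configured decay apply, which is why the third call mixes the warmed-up average and the new input equally.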
Example #8
def f(x):
  return moving_averages.ExponentialMovingAverage(0.5)(x)
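Because the module creates Haiku state, `f` has to be run through a transform before it can be called. A minimal sketch, assuming the public Haiku API:

import haiku as hk
import jax.numpy as jnp

f_t = hk.without_apply_rng(hk.transform_with_state(f))
params, state = f_t.init(None, jnp.array(3.))
out, state = f_t.apply(params, state, jnp.array(3.))  # Debiased average of the inputs seen so far.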
Example #9
  def test_fast_slow_decay(self):
    ema_fast = moving_averages.ExponentialMovingAverage(0.2)
    ema_slow = moving_averages.ExponentialMovingAverage(0.8)
    np.testing.assert_allclose(ema_fast(1.), ema_slow(1.), rtol=1e-4)
    # Expect the fast EMA (lower decay) to move toward the new value more
    # quickly than the slow one.
    self.assertGreater(ema_fast(2.), ema_slow(2.))
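Note the first assertion holds for any pair of decays: with zero-debiasing enabled (the default), the first call always returns the input exactly. The decays only start to matter from the second update onward.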
Example #10
  def test_call(self):
    ema = moving_averages.ExponentialMovingAverage(0.5)
    self.assertAlmostEqual(ema(3.), 3.)
    self.assertAlmostEqual(ema(6.), 5.)
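Assuming the standard zero-debiased update (`hidden = decay * hidden + (1 - decay) * value`, `average = hidden / (1 - decay**step)`), the numbers check out: after `ema(3.)`, hidden = 1.5 and average = 1.5 / 0.5 = 3; after `ema(6.)`, hidden = 0.5 * 1.5 + 0.5 * 6 = 3.75 and average = 3.75 / 0.75 = 5.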
Example #11
  def test_warmup_length_and_zero_debias(self):
    with self.assertRaises(ValueError):
      moving_averages.ExponentialMovingAverage(
          0.5, warmup_length=2, zero_debias=True)
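Presumably these two options are mutually exclusive because warmup and zero-debiasing are alternative ways of correcting the bias from the zero-initialized state.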
Example #12
  def test_invalid_warmup_length(self):
    with self.assertRaises(ValueError):
      moving_averages.ExponentialMovingAverage(
          0.5, warmup_length=-1, zero_debias=False)