def __init__(self,
             create_scale,
             create_offset,
             decay_rate,
             eps=1e-5,
             scale_init=None,
             offset_init=None,
             axis=None,
             cross_replica_axis=None,
             data_format="channels_last",
             name=None):
  """Constructs a BatchNorm module.

  Args:
    create_scale: Whether to include a trainable scaling factor.
    create_offset: Whether to include a trainable offset.
    decay_rate: Decay rate for EMA.
    eps: Small epsilon to avoid division by zero variance. Defaults to 1e-5,
      as in the paper and Sonnet.
    scale_init: Optional initializer for gain (aka scale). Can only be set
      if `create_scale=True`. By default, one.
    offset_init: Optional initializer for bias (aka offset). Can only be set
      if `create_offset=True`. By default, zero.
    axis: Which axes to reduce over. The default (None) signifies that all
      but the channel axis should be normalized. Otherwise this is a list of
      axis indices which will have normalization statistics calculated.
    cross_replica_axis: If not None, it should be a string representing the
      axis name over which this module is being run within a jax.pmap.
      Supplying this argument means that batch statistics are calculated
      across all replicas on that axis.
    data_format: The data format of the input. Can be either
      `channels_first`, `channels_last`, `N...C` or `NC...`. By default it
      is `channels_last`.
    name: The module name.
  """
  super(BatchNorm, self).__init__(name=name)
  self._create_scale = create_scale
  self._create_offset = create_offset
  if not self._create_scale and scale_init is not None:
    raise ValueError("Cannot set `scale_init` if `create_scale=False`")
  self._scale_init = scale_init or jnp.ones
  if not self._create_offset and offset_init is not None:
    raise ValueError("Cannot set `offset_init` if `create_offset=False`")
  self._offset_init = offset_init or jnp.zeros
  self._eps = eps
  self._cross_replica_axis = cross_replica_axis
  self._data_format = data_format
  self._channel_index = utils.get_channel_index(data_format)
  self._axis = axis
  self._mean_ema = moving_averages.ExponentialMovingAverage(
      decay_rate, name="mean_ema")
  self._var_ema = moving_averages.ExponentialMovingAverage(
      decay_rate, name="var_ema")
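# A minimal usage sketch of the module above, not part of the original file.
# It assumes the public Haiku API: that this class is exported as
# `hk.BatchNorm` and is applied under `hk.transform_with_state`, so that the
# mean/var EMA state is threaded explicitly.
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x, is_training):
  bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
  return bn(x, is_training=is_training)


forward = hk.transform_with_state(forward)
x = jnp.ones([8, 32, 32, 3])  # NHWC, i.e. the default `channels_last`.
params, state = forward.init(jax.random.PRNGKey(0), x, is_training=True)
# `state` holds the running mean/variance and is updated each training step.
out, state = forward.apply(params, state, None, x, is_training=True)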
def test_fast_slow_decay_without_update(self):
  ema_fast = moving_averages.ExponentialMovingAverage(0.5)
  ema_slow = moving_averages.ExponentialMovingAverage(0.8)

  # Calls with `update_stats=False` shouldn't affect the stored average.
  np.testing.assert_allclose(ema_fast(1., update_stats=False),
                             ema_slow(1., update_stats=False),
                             rtol=1e-4)
  np.testing.assert_allclose(ema_fast(1.), ema_slow(1.), rtol=1e-4)
  self.assertGreater(ema_fast(2.), ema_slow(2.))
def test_zero_decay(self):
  ema = moving_averages.ExponentialMovingAverage(0.)
  random_input = jax.random.uniform(jax.random.PRNGKey(428), shape=(2, 3, 4))
  # The ema should be equal to the input with decay=0.
  np.testing.assert_allclose(random_input[0], ema(random_input[0]))
  np.testing.assert_allclose(random_input[1], ema(random_input[1]))
def __init__(self,
             embedding_dim,
             num_embeddings,
             commitment_cost,
             decay,
             epsilon: float = 1e-5,
             dtype: DType = jnp.float32,
             name: str = None):
  """Initializes a VQ-VAE EMA module.

  Args:
    embedding_dim: integer representing the dimensionality of the tensors in
      the quantized space. Inputs to the module must be in this format as
      well.
    num_embeddings: integer, the number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms
      (see equation 4 in the paper - this variable is Beta).
    decay: float between 0 and 1, controls the speed of the Exponential
      Moving Averages.
    epsilon: small constant to aid numerical stability, default 1e-5.
    dtype: dtype for the embeddings variable, defaults to jnp.float32.
    name: name of the module.
  """
  super(VectorQuantizerEMA, self).__init__(name=name)
  self.embedding_dim = embedding_dim
  self.num_embeddings = num_embeddings
  if not 0 <= decay <= 1:
    raise ValueError('decay must be in range [0, 1]')
  self.decay = decay
  self.commitment_cost = commitment_cost
  self.epsilon = epsilon

  embedding_shape = [embedding_dim, num_embeddings]
  initializer = initializers.VarianceScaling(distribution='uniform')
  embeddings = base.get_state('embeddings', embedding_shape, dtype,
                              init=initializer)

  self.ema_cluster_size = moving_averages.ExponentialMovingAverage(
      decay=self.decay, name='ema_cluster_size')
  self.ema_cluster_size.initialize(jnp.zeros([num_embeddings], dtype=dtype))

  self.ema_dw = moving_averages.ExponentialMovingAverage(
      decay=self.decay, name='ema_dw')
  self.ema_dw.initialize(embeddings)
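# A usage sketch of the module above, not part of the original file. It
# assumes this class is exported as `hk.nets.VectorQuantizerEMA` and, like
# the Sonnet 2 module it mirrors, returns a dict with (at least) `quantize`
# and `loss` entries when called.
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x, is_training):
  vq = hk.nets.VectorQuantizerEMA(
      embedding_dim=64, num_embeddings=512, commitment_cost=0.25, decay=0.99)
  return vq(x, is_training=is_training)


forward = hk.transform_with_state(forward)
x = jnp.ones([16, 8, 8, 64])  # Trailing dim must equal `embedding_dim`.
params, state = forward.init(jax.random.PRNGKey(0), x, is_training=True)
out, state = forward.apply(params, state, None, x, is_training=True)
quantized, vq_loss = out['quantize'], out['loss']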
def test_initialize(self):
  ema = moving_averages.ExponentialMovingAverage(0.99)
  ema.initialize(jnp.ones([]))
  self.assertEqual(ema.average, 0.)

  ema(jnp.array(100.))
  # Matching the behavior of Sonnet 2, initialize only sets the value to
  # zero if the EMA has not already been initialized.
  ema.initialize(jnp.ones([]))
  self.assertNotEqual(ema.average, 0.)
def test_maybe_initialize(self):
  ema = moving_averages.ExponentialMovingAverage(0.99)
  ema.maybe_initialize([], jnp.float32)
  self.assertEqual(ema.average, 0.)

  ema(jnp.array(100.))
  # Matching the behavior of Sonnet 2, maybe_initialize only sets the value
  # if the EMA has not already been initialized.
  ema.maybe_initialize([], jnp.float32)
  self.assertNotEqual(ema.average, 0.)
def test_warmup(self):
  ema = moving_averages.ExponentialMovingAverage(
      0.5, warmup_length=2, zero_debias=False)
  random_input = jax.random.uniform(jax.random.PRNGKey(428), shape=(2, 3, 4))

  # The ema should be equal to the input for the first two (warmup) calls.
  np.testing.assert_allclose(random_input[0], ema(random_input[0]))
  np.testing.assert_allclose(random_input[0], ema(random_input[0]))

  # After the warmup period, with decay=0.5 the ema should be halfway
  # between the warmed-up value (random_input[0]) and the new input.
  np.testing.assert_allclose(
      (random_input[0] + random_input[1]) / 2, ema(random_input[1]))
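# A pure-NumPy reference for the warmup semantics exercised by test_warmup
# above. This is my reading of the behavior, not the actual implementation:
# for the first `warmup_length` calls the hidden state simply tracks the raw
# input, after which the usual EMA update applies.
import numpy as np


class ReferenceEma:
  """Sketch of an EMA with warmup (and without zero debiasing)."""

  def __init__(self, decay, warmup_length=0):
    self.decay = decay
    self.warmup_length = warmup_length
    self.counter = 0
    self.hidden = 0.

  def __call__(self, value):
    self.counter += 1
    if self.counter <= self.warmup_length:
      self.hidden = np.asarray(value)  # Warmup: track the input exactly.
    else:
      self.hidden = self.decay * self.hidden + (1 - self.decay) * value
    return self.hidden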
def f(x):
  return moving_averages.ExponentialMovingAverage(0.5)(x)
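# `f` constructs a Haiku module, so it presumably runs inside a transform
# elsewhere in the test. A hedged sketch of that wrapping (the names below
# are illustrative, not from the original file):
import haiku as hk
import jax
import jax.numpy as jnp

f_transformed = hk.transform_with_state(f)
params, state = f_transformed.init(jax.random.PRNGKey(0), jnp.array(1.))
out, state = f_transformed.apply(params, state, None, jnp.array(1.))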
def test_fast_slow_decay(self):
  ema_fast = moving_averages.ExponentialMovingAverage(0.2)
  ema_slow = moving_averages.ExponentialMovingAverage(0.8)
  np.testing.assert_allclose(ema_fast(1.), ema_slow(1.), rtol=1e-4)
  # Expect the fast (low-decay) average to move towards new inputs more
  # quickly than the slow one.
  self.assertGreater(ema_fast(2.), ema_slow(2.))
def test_call(self):
  ema = moving_averages.ExponentialMovingAverage(0.5)
  self.assertAlmostEqual(ema(3.), 3.)
  self.assertAlmostEqual(ema(6.), 5.)
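# The expected values in test_call follow from zero debiasing, which is on
# by default. Assuming the Adam-style correction
# `average = hidden / (1 - decay**counter)`, the trace is:
#
#   call 1: hidden = 0.5 * 0   + 0.5 * 3 = 1.5 ;  1.5  / (1 - 0.5)    = 3.
#   call 2: hidden = 0.5 * 1.5 + 0.5 * 6 = 3.75;  3.75 / (1 - 0.5**2) = 5.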
def test_warmup_length_and_zero_debias(self):
  with self.assertRaises(ValueError):
    moving_averages.ExponentialMovingAverage(
        0.5, warmup_length=2, zero_debias=True)
def test_invalid_warmup_length(self):
  with self.assertRaises(ValueError):
    moving_averages.ExponentialMovingAverage(
        0.5, warmup_length=-1, zero_debias=False)