Esempio n. 1
  def test_add_noise_has_correct_variance_scaling(self):
    """Checks that `add_noise` scales its noise by `sqrt(eta / i**gamma)`.

    Two transformations are built from the same seed: the one under test
    (`eta=0.3`, `gamma=0.55`) and a reference with `eta=1.0`, `gamma=0.0`.
    Both are applied to all-zero updates so the outputs are the noise itself.
    At every step `i` (counted from 1) the output of the transformation under
    test must equal the reference output multiplied by `sqrt(eta / i**gamma)`.

    Relies on `self.init_params` and `self.variant` being provided by the
    enclosing test case.
    """
    # Prepare to compare noise with a rescaled unit-variance substitute.
    eta = 0.3
    gamma = 0.55
    seed = 314
    noise = transform.add_noise(eta, gamma, seed)
    # With eta=1 and gamma=0 the scale checked below, sqrt(eta / i**gamma),
    # is 1 at every step, so this serves as the unscaled reference.
    noise_unit = transform.add_noise(1.0, 0.0, seed)

    params = self.init_params
    state = noise.init(params)
    state_unit = noise_unit.init(params)

    # Check the noise itself by adding it to zeros.
    # `jax.tree_util.tree_map` is used instead of the deprecated top-level
    # `jax.tree_map` alias, which newer JAX releases no longer provide.
    updates = jax.tree_util.tree_map(jnp.zeros_like, params)

    # Wrapping the update function does not depend on the step, so do it once.
    noisy_update = self.variant(noise.update)

    for i in range(1, STEPS + 1):
      updates_i, state = noisy_update(updates, state)
      updates_i_unit, state_unit = noise_unit.update(updates, state_unit)

      # Expected standard-deviation ratio between the two noises at step i.
      scale = jnp.sqrt(eta / i**gamma)

      # `s=scale` binds the current value eagerly instead of closing over the
      # loop variable.
      updates_i_rescaled = jax.tree_util.tree_map(
          lambda g, s=scale: g * s, updates_i_unit)

      chex.assert_tree_all_close(updates_i, updates_i_rescaled, rtol=1e-4)
Esempio n. 2
def noisy_sgd(learning_rate: ScalarOrSchedule,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> base.GradientTransformation:
    r"""A variant of SGD with added noise.

  It has been found that adding noise to the gradients can improve
  both the training error and the generalisation error in very deep networks.

    [Neelakantan et al, 2014](

    learning_rate: this is a fixed global scaling factor.
    eta: the initial variance for the gaussian noise added to gradients.
    gamma: a parameter controlling the annealing of noise over time,
      the variance decays according to `(1+t)^-\gamma`.
    seed: the seed for the pseudo-random generation process.

    the corresponding `GradientTransformation`.
    return combine.chain(
        transform.add_noise(eta, gamma, seed),
Esempio n. 3
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=0., nesterov=False),
        transform.add_noise(eta, gamma, seed),