Example 1
  def test_add_noise_has_correct_variance_scaling(self):
    # Prepare to compare noise with a rescaled unit-variance substitute.
    eta = 0.3
    gamma = 0.55
    seed = 314
    noise = transform.add_noise(eta, gamma, seed)
    noise_unit = transform.add_noise(1.0, 0.0, seed)

    params = self.init_params
    state = noise.init(params)
    state_unit = noise_unit.init(params)

    # Check the noise itself by adding it to zeros.
    updates = jax.tree_map(jnp.zeros_like, params)

    for i in range(1, STEPS + 1):
      updates_i, state = self.variant(noise.update)(updates, state)
      updates_i_unit, state_unit = noise_unit.update(updates, state_unit)

      scale = jnp.sqrt(eta / i**gamma)

      updates_i_rescaled = jax.tree_map(
          lambda g, s=scale: g * s, updates_i_unit)

      chex.assert_tree_all_close(updates_i, updates_i_rescaled, rtol=1e-4)
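The test above isolates the injected noise by feeding zero updates and checks that, at step i, it equals unit-variance noise rescaled by sqrt(eta / i**gamma). The sketch below reproduces that setup outside the test harness; it is illustrative only and assumes the public `optax.add_noise` alias of `transform.add_noise` plus a toy parameter tree.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}
noise = optax.add_noise(eta=0.3, gamma=0.55, seed=314)
state = noise.init(params)

# Feeding zero updates isolates the injected noise, as in the test above.
zero_updates = jax.tree_util.tree_map(jnp.zeros_like, params)
for step in range(1, 4):
  noisy_updates, state = noise.update(zero_updates, state)
  # At step i the noise standard deviation is sqrt(eta / i**gamma).
  print(step, noisy_updates['w'])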
Example 2
def noisy_sgd(learning_rate: ScalarOrSchedule,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> base.GradientTransformation:
    r"""A variant of SGD with added noise.

  It has been found that adding noise to the gradients can improve
  both the training error and the generalisation error in very deep networks.

  References:
    [Neelakantan et al., 2015](https://arxiv.org/abs/1511.06807)

  Args:
    learning_rate: this is a fixed global scaling factor.
    eta: the initial variance for the gaussian noise added to gradients.
    gamma: a parameter controlling the annealing of noise over time,
      the variance decays according to `(1+t)^-\gamma`.
    seed: the seed for the pseudo-random generation process.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      _scale_by_learning_rate(learning_rate),
      transform.add_noise(eta, gamma, seed),
  )
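The docstring describes the annealing schedule but not how the optimiser is driven, so here is a hedged usage sketch of a training step; the toy quadratic loss and the public `optax.noisy_sgd` / `optax.apply_updates` entry points are assumptions for illustration rather than part of the snippet above.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
  return jnp.sum(params['w'] ** 2)

params = {'w': jnp.ones((4,))}
opt = optax.noisy_sgd(learning_rate=0.1, eta=0.01, gamma=0.55, seed=0)
opt_state = opt.init(params)

for _ in range(5):
  grads = jax.grad(loss_fn)(params)
  # As chained above, the update rescales the gradients by the (negated)
  # learning rate and then adds Gaussian noise whose variance anneals.
  updates, opt_state = opt.update(grads, opt_state)
  params = optax.apply_updates(params, updates)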
Example 3
# Noisy SGD assembled from explicit transforms: a momentum-free trace,
# scaling by the negated learning rate, then annealed Gaussian noise.
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=0., nesterov=False),
        transform.scale(-learning_rate),
        transform.add_noise(eta, gamma, seed),
    )
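This older-style construction spells the descent step out explicitly: a momentum-free trace, a sign-flipped scale, then the same noise transform. The small check below, assuming the public `optax.trace` mirrors `transform.trace`, shows why the `trace(decay=0.)` link is a pass-through, so the chain reduces to scaling by `-learning_rate` plus annealed noise, as in the previous example.

import jax.numpy as jnp
import optax

grads = {'w': jnp.array([1.0, -2.0, 3.0])}
tracer = optax.trace(decay=0., nesterov=False)
state = tracer.init(grads)
out, state = tracer.update(grads, state)
# With decay=0 the accumulated trace equals the current gradients, so the
# updates pass through unchanged.
print(out['w'])  # [ 1. -2.  3.]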