Ejemplo n.º 1
0
 def testConstructNAdamWithLR(self):
   """Checks how the `lr` keyword interacts with `learning_rate`.

   When both keywords are given, the `lr` value is expected to be the one
   reported by the optimizer. When only `learning_rate` is given, that value
   is expected to appear in the `lr` attribute.
   """
   cases = [
       (dict(lr=1.0), 1.0),
       (dict(learning_rate=0.1, lr=1.0), 1.0),
       (dict(learning_rate=0.1), 0.1),
   ]
   for ctor_kwargs, expected_lr in cases:
     optimizer = nadam.Nadam(**ctor_kwargs)
     self.assertEqual(optimizer.lr, expected_lr)
Ejemplo n.º 2
0
    def testConstructNAdamWithEpsilonValues(self):
        """Verifies the `epsilon` value stored in the optimizer config.

        Passing ``epsilon=None`` is expected to yield 1e-7 in the config.
        An explicit value is expected to be stored unchanged.
        """
        for requested, expected in ((None, 1e-7), (1e-8, 1e-8)):
            cfg = nadam.Nadam(epsilon=requested).get_config()
            self.assertEqual(cfg["epsilon"], expected)
Ejemplo n.º 3
0
  def testBasic(self):
    """Runs three Nadam steps and compares them with a NumPy reference.

    For each of half, float32 and float64, two resource variables receive
    constant gradients. After every `update.run()` both variables must match
    the values produced by `nadam_update_numpy`. That reference is driven by
    an `mcache` value advanced with `update_m_cache` once per step.
    """
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.cached_session():
        np_dtype = dtype.as_numpy_dtype
        # Reference state for the NumPy implementation, one entry per variable.
        params_np = [np.array([1.0, 2.0], dtype=np_dtype),
                     np.array([3.0, 4.0], dtype=np_dtype)]
        grads_np = [np.array([0.1, 0.1], dtype=np_dtype),
                    np.array([0.01, 0.01], dtype=np_dtype)]
        moments = [[0.0, 0.0], [0.0, 0.0]]  # [m, v] for each variable.
        mcache = 1.0  # Shared by both variables; advanced once per step.

        tf_vars = [resource_variable_ops.ResourceVariable(p) for p in params_np]
        tf_grads = [constant_op.constant(g) for g in grads_np]
        optimizer = nadam.Nadam()
        update = optimizer.apply_gradients(zip(tf_grads, tf_vars))
        variables.global_variables_initializer().run()

        # The variables must hold their initial values before any step runs.
        self.assertAllClose([1.0, 2.0], tf_vars[0].eval())
        self.assertAllClose([3.0, 4.0], tf_vars[1].eval())

        for step in range(3):
          update.run()

          mcache = update_m_cache(mcache, step)
          for idx, grad_np in enumerate(grads_np):
            m, v = moments[idx]
            params_np[idx], m, v = nadam_update_numpy(
                params_np[idx], grad_np, step, m, v, mcache)
            moments[idx] = [m, v]

          # Compare the TF result of this step with the NumPy reference.
          for expected, tf_var in zip(params_np, tf_vars):
            self.assertAllCloseAccordingToType(expected, tf_var.eval())
Ejemplo n.º 4
0
    def testBasicWithLearningRateDecay(self):
        """Checks Nadam with a decaying learning rate against NumPy.

        At step ``t`` the reference update uses the learning rate
        ``learning_rate / (1 + decay * t)``. Before each step, the beta power
        accumulators are expected to equal ``0.9**(t + 1)`` and
        ``0.999**(t + 1)``.
        """
        learning_rate = 0.001
        decay = 0.5
        for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
            with self.cached_session():
                np_dtype = dtype.as_numpy_dtype
                # Reference state for the NumPy implementation.
                params_np = [
                    np.array([1.0, 2.0], dtype=np_dtype),
                    np.array([3.0, 4.0], dtype=np_dtype),
                ]
                grads_np = [
                    np.array([0.1, 0.1], dtype=np_dtype),
                    np.array([0.01, 0.01], dtype=np_dtype),
                ]
                moments = [[0.0, 0.0], [0.0, 0.0]]  # [m, v] for each variable.

                tf_vars = [
                    resource_variable_ops.ResourceVariable(p)
                    for p in params_np
                ]
                tf_grads = [constant_op.constant(g) for g in grads_np]
                optimizer = nadam.Nadam(learning_rate=learning_rate,
                                        decay=decay)
                update = optimizer.apply_gradients(zip(tf_grads, tf_vars))
                variables.global_variables_initializer().run()

                # The variables must hold their initial values before training.
                self.assertAllClose([1.0, 2.0], tf_vars[0].eval())
                self.assertAllClose([3.0, 4.0], tf_vars[1].eval())

                beta1_power, beta2_power = get_beta_accumulators(
                    optimizer, dtype)

                for step in range(3):
                    # The accumulators are checked before this step is applied.
                    self.assertAllCloseAccordingToType(0.9**(step + 1),
                                                       beta1_power.eval())
                    self.assertAllCloseAccordingToType(0.999**(step + 1),
                                                       beta2_power.eval())
                    update.run()

                    decayed_lr = learning_rate / (1 + decay * step)
                    for idx, grad_np in enumerate(grads_np):
                        m, v = moments[idx]
                        params_np[idx], m, v = nadam_update_numpy(
                            params_np[idx], grad_np, step, m, v,
                            alpha=decayed_lr)
                        moments[idx] = [m, v]

                    # Compare the TF result of this step with the reference.
                    for expected, tf_var in zip(params_np, tf_vars):
                        self.assertAllCloseAccordingToType(expected,
                                                           tf_var.eval())