def testPredictOnCPU(self):
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

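    # Sweep over batching configs, requested kernels, and test inputs;
    # `x = None` asks the predictor for outputs on the training set itself.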
    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
          for x in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

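              # Swap the 'x_test' placeholder for the actual array; `t=None`
              # and `t=np.inf` should both give the infinite-time
              # (fully trained) predictions.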
              x = x if x is None else x_test
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is not None:
                on_cpu = (not store_on_device or
                          xla_bridge.get_backend().platform == 'cpu')
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Example #2
    def testPredictOnCPU(self):
        key1 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key2 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key3 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
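        # The three stateless draws play the role of the JAX PRNG keys
        # used in the first example.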
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

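        # Same sweep as in the first example above.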
        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

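                            # As in the first example: swap the placeholder
                            # and check that t=None matches t=np.inf.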
                            x = x if x is None else x_test
                            predict_none = predictor(None,
                                                     x,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is not None:
                                on_cpu = (not store_on_device
                                          or xla_bridge.get_backend().platform
                                          == 'cpu')
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_inf))
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_none))
Example #3
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using SGD or full-batch GD with mean
  squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. The maximal
  feasible learning rate is the largest `\eta` such that the operator
  `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
  '2 * batch_size * output_size * lambda_max(NTK)'.

  Args:
    ntk_train_train: analytic or empirical NTK on the training data.
    y_train_size: total training set output size, i.e.
      `f(x_train).size ==  y_train.size`. If `output_size=None` it is inferred
      from `ntk_train_train.shape` assuming `trace_axes=()`.
    eps: a float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
    ntk_train_train = utils.make_2d(ntk_train_train)
    factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size

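    # On CPU, ask SciPy for just the top eigenvalue; otherwise compute the
    # full spectrum with `np.linalg.eigvalsh` and take the largest.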
    if utils.is_on_cpu(ntk_train_train):
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            eigvals=(ntk_train_train.shape[0] - 1,
                     ntk_train_train.shape[0] - 1))[-1]
    else:
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
    lr = 2 * factor / (max_eva + eps)
    return lr
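A minimal usage sketch (not from the original source): build an analytic NTK with `stax` layers, as in the test snippets above, and feed it to the `max_learning_rate` defined here. Data shapes and widths are illustrative assumptions, and the function's own dependencies (`utils`, `osp`, jax.numpy as `np`) are assumed in scope.

from jax import random
from neural_tangents import stax

key = random.PRNGKey(0)
x_train = random.normal(key, (8, 4))
_, _, kernel_fn = stax.serial(stax.Dense(32), stax.Relu(), stax.Dense(1))
ntk_train_train = kernel_fn(x_train, None, 'ntk')  # (8, 8) Gram matrix
lr = max_learning_rate(ntk_train_train)  # = 2 * 8 / (lambda_max + eps)
# Any gradient-descent step size below `lr` keeps MSE training stable.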
Example #4
    def testIsOnCPU(self):
        for dtype in [np.float32, np.float64]:
            with self.subTest(dtype=dtype):

                def x():
                    return random.normal(random.PRNGKey(1), (2, 3), dtype)

                def x_cpu():
                    return device_get(
                        random.normal(random.PRNGKey(1), (2, 3), dtype))

                x_jit = jit(x)
                # x_cpu_jit = jit(x_cpu)
                x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

                self.assertTrue(utils.is_on_cpu(x_cpu()))
                # TODO(mattjj): re-enable this when device_put under jit works
                # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
                self.assertTrue(utils.is_on_cpu(x_cpu_jit_cpu()))

                if xla_bridge.get_backend().platform == 'cpu':
                    self.assertTrue(utils.is_on_cpu(x()))
                    self.assertTrue(utils.is_on_cpu(x_jit()))
                else:
                    self.assertFalse(utils.is_on_cpu(x()))
                    self.assertFalse(utils.is_on_cpu(x_jit()))
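For reference, a minimal sketch of the behavior Example #4 asserts, assuming the test's own imports (`jax.random`, `jax.device_get`, and the library's `utils`) are in scope:

x = random.normal(random.PRNGKey(1), (2, 3))
x_host = device_get(x)          # transfer to host memory (NumPy array)
print(utils.is_on_cpu(x_host))  # True on every backend
print(utils.is_on_cpu(x))       # True only if the default backend is CPU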
Example #5
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      momentum: float = 0.,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

  The network is assumed to be trained using mini-/full-batch GD + momentum
  with mean squared loss. The loss is assumed to have the form
  `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla SGD
  (i.e. `momentum = 0`) the maximal feasible learning rate is the largest `\eta`
  such that the operator
                `(I - \eta / (batch_size * output_size) * NTK)`
  is a contraction, which is
                `2 * batch_size * output_size * lambda_max(NTK)`.
  When `momentum > 0`, we use (see `The Dynamics of Momentum` section in
  https://distill.pub/2017/momentum/)
                `2 * (1 + momentum) * batch_size * output_size * lambda_max(NTK)`.

  Args:
    ntk_train_train:
      analytic or empirical NTK on the training data.
    y_train_size:
      total training set output size, i.e.
      `f(x_train).size ==  y_train.size`. If `output_size=None` it is inferred
      from `ntk_train_train.shape` assuming `trace_axes=()`.
    momentum:
      The `momentum` for momentum optimizers.
    eps:
      a float to avoid zero divisor.

  Returns:
    The maximal feasible learning rate for infinite width NNs.
  """
    ntk_train_train = utils.make_2d(ntk_train_train)
    factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size

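    # As in Example #3: on CPU, SciPy computes only the top eigenvalue.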
    if utils.is_on_cpu(ntk_train_train):
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            eigvals=(ntk_train_train.shape[0] - 1,
                     ntk_train_train.shape[0] - 1))[-1]
    else:
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
    lr = 2 * (1 + momentum) * factor / (max_eva + eps)
    return lr
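The momentum term only rescales the bound by `1 + momentum`. A hedged numeric check (the matrix below is a hypothetical 2x2 PSD stand-in for an NTK, with `lambda_max = (3 + sqrt(2)) / 2 ~ 2.207`; `np` is jax.numpy as above):

ntk = np.array([[2.0, 0.5],
                [0.5, 1.0]])  # hypothetical PSD "NTK"; lambda_max ~ 2.207
lr_sgd = max_learning_rate(ntk)                # ~ 2 * 2 / 2.207 ~ 1.81
lr_mom = max_learning_rate(ntk, momentum=0.9)  # ~ 1.9 * lr_sgd ~ 3.44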