def testPredictOnCPU(self):
    """Compare `None` vs `np.inf` predictions and check where results live.

    For every combination of `store_on_device`, `device_count`, requested
    kernel(s) and test input (`None` or `x_test`), a predictor is built from
    a batched kernel function and called with `None` and with `np.inf` as
    its first argument. Both outputs must be close to each other, and when a
    test input is given both must be on CPU exactly when results are not
    stored on device or the default backend is CPU.
    """
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))
    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    layers = (
        stax.Conv(1, (3, 3)),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1),
    )
    _, _, kernel_fn = stax.serial(*layers)

    get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]

    for store_on_device in [False, True]:
        for device_count in [0, 1]:
            for get in get_options:
                for x_choice in [None, 'x_test']:
                    with self.subTest(
                            store_on_device=store_on_device,
                            device_count=device_count,
                            get=get,
                            x=x_choice):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_ensemble(
                            kernel_fn_batched, x_train, y_train)

                        # The string placeholder stands in for the real
                        # test inputs; `None` is passed through unchanged.
                        x = None if x_choice is None else x_test

                        predict_none = predictor(
                            None, x, get, compute_cov=True)
                        predict_inf = predictor(
                            np.inf, x, get, compute_cov=True)
                        self.assertAllClose(predict_none, predict_inf)

                        if x is not None:
                            # Short-circuits: the backend is only queried
                            # when results are stored on device.
                            on_cpu = not (
                                store_on_device and
                                xla_bridge.get_backend().platform != 'cpu')
                            for prediction in (predict_inf, predict_none):
                                self.assertEqual(
                                    on_cpu, utils.is_on_cpu(prediction))
def testPredictOnCPU(self):
    """Compare `None` vs `np.inf` predictions and check where results live.

    Training and test data are drawn with stateless random ops seeded from
    identically constructed keys. For every combination of
    `store_on_device`, `device_count`, requested kernel(s) and test input
    (`None` or `x_test`), a predictor built from a batched kernel function
    is called with `None` and with `np.inf` as its first argument. Both
    outputs must be close, and when a test input is given both must be on
    CPU exactly when results are not stored on device or the default
    backend is CPU.
    """

    def make_key():
        # Same seed every time, so all three keys hold the same values.
        return stateless_uniform(
            shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32)

    key1 = make_key()
    key2 = make_key()
    key3 = make_key()

    x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
    x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))
    y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

    layers = (
        stax.Conv(1, (3, 3)),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1),
    )
    _, _, kernel_fn = stax.serial(*layers)

    get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]

    for store_on_device in [False, True]:
        for device_count in [0, 1]:
            for get in get_options:
                for x_choice in [None, 'x_test']:
                    with self.subTest(
                            store_on_device=store_on_device,
                            device_count=device_count,
                            get=get,
                            x=x_choice):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_ensemble(
                            kernel_fn_batched, x_train, y_train)

                        # The string placeholder stands in for the real
                        # test inputs; `None` is passed through unchanged.
                        x = None if x_choice is None else x_test

                        predict_none = predictor(
                            None, x, get, compute_cov=True)
                        predict_inf = predictor(
                            np.inf, x, get, compute_cov=True)
                        self.assertAllClose(predict_none, predict_inf)

                        if x is not None:
                            # Short-circuits: the backend is only queried
                            # when results are stored on device.
                            on_cpu = not (
                                store_on_device and
                                xla_bridge.get_backend().platform != 'cpu')
                            for prediction in (predict_inf, predict_none):
                                self.assertEqual(
                                    on_cpu, utils.is_on_cpu(prediction))
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

    The network is assumed to be trained using SGD or full-batch GD with mean
    squared loss. The loss is assumed to have the form
    `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. The maximal
    feasible learning rate is the largest `\eta` such that the operator
    `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
    `2 * batch_size * output_size / lambda_max(NTK)`.

    Args:
      ntk_train_train:
        analytic or empirical NTK on the training data. It is converted with
        `utils.make_2d` before the eigenvalue computation.
      y_train_size:
        total training set output size, i.e.
        `f(x_train).size == y_train.size`. If `y_train_size=None` it is
        taken to be the first dimension of the 2D NTK (per the original
        documentation this assumes `trace_axes=()`).
      eps:
        a float added to the largest eigenvalue to avoid a zero divisor.

    Returns:
      The maximal feasible learning rate for infinite width NNs, computed as
      `2 * y_train_size / (lambda_max(NTK) + eps)`.
    """
    ntk_train_train = utils.make_2d(ntk_train_train)
    dim = ntk_train_train.shape[0]
    factor = dim if y_train_size is None else y_train_size

    if utils.is_on_cpu(ntk_train_train):
        # SciPy can return only the requested eigenvalue(s); ask for the
        # single largest one (index `dim - 1` in ascending order).
        # `subset_by_index` replaces the `eigvals` keyword, which was
        # deprecated in SciPy 1.5 and removed in later releases.
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            subset_by_index=(dim - 1, dim - 1))[-1]
    else:
        # Full spectrum in ascending order; the last entry is the largest.
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]

    lr = 2 * factor / (max_eva + eps)
    return lr
def testIsOnCPU(self):
    """Exercise `utils.is_on_cpu` on eager, jitted and `device_get` arrays.

    Arrays passed through `device_get` (eagerly or under a CPU-backed `jit`)
    must be reported as on CPU. Arrays produced on the default backend must
    be reported as on CPU only when that backend's platform is `'cpu'`.
    """
    for dtype in [np.float32, np.float64]:
        with self.subTest(dtype=dtype):

            def x():
                return random.normal(random.PRNGKey(1), (2, 3), dtype)

            def x_cpu():
                return device_get(x())

            x_jit = jit(x)
            x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

            self.assertTrue(utils.is_on_cpu(x_cpu()))
            # TODO(mattjj): re-enable this when device_put under jit works
            # x_cpu_jit = jit(x_cpu)
            # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
            self.assertTrue(utils.is_on_cpu(x_cpu_jit_cpu()))

            default_is_cpu = xla_bridge.get_backend().platform == 'cpu'
            check = self.assertTrue if default_is_cpu else self.assertFalse
            for make_array in (x, x_jit):
                check(utils.is_on_cpu(make_array()))
def max_learning_rate(ntk_train_train: np.ndarray,
                      y_train_size: int = None,
                      momentum: float = 0.,
                      eps: float = 1e-12) -> float:
    r"""Computes the maximal feasible learning rate for infinite width NNs.

    The network is assumed to be trained using mini-/full-batch GD + momentum
    with mean squared loss. The loss is assumed to have the form
    `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla
    SGD (i.e. `momentum = 0`) the maximal feasible learning rate is the
    largest `\eta` such that the operator
    `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is
    `2 * batch_size * output_size / lambda_max(NTK)`. When `momentum > 0`, we
    use (see `The Dynamics of Momentum` section in
    https://distill.pub/2017/momentum/)
    `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)`.

    Args:
      ntk_train_train:
        analytic or empirical NTK on the training data. It is converted with
        `utils.make_2d` before the eigenvalue computation.
      y_train_size:
        total training set output size, i.e.
        `f(x_train).size == y_train.size`. If `y_train_size=None` it is
        taken to be the first dimension of the 2D NTK (per the original
        documentation this assumes `trace_axes=()`).
      momentum:
        The `momentum` for momentum optimizers.
      eps:
        a float added to the largest eigenvalue to avoid a zero divisor.

    Returns:
      The maximal feasible learning rate for infinite width NNs, computed as
      `2 * (1 + momentum) * y_train_size / (lambda_max(NTK) + eps)`.
    """
    ntk_train_train = utils.make_2d(ntk_train_train)
    dim = ntk_train_train.shape[0]
    factor = dim if y_train_size is None else y_train_size

    if utils.is_on_cpu(ntk_train_train):
        # SciPy can return only the requested eigenvalue(s); ask for the
        # single largest one (index `dim - 1` in ascending order).
        # `subset_by_index` replaces the `eigvals` keyword, which was
        # deprecated in SciPy 1.5 and removed in later releases.
        max_eva = osp.linalg.eigvalsh(
            ntk_train_train,
            subset_by_index=(dim - 1, dim - 1))[-1]
    else:
        # Full spectrum in ascending order; the last entry is the largest.
        max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]

    lr = 2 * (1 + momentum) * factor / (max_eva + eps)
    return lr