def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel,
                        name):
  key = random.PRNGKey(0)

  key, split = random.split(key)
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = random.normal(split, train_shape)

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

  for lr_factor in [0.5, 3.]:
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Regress to an MSE loss.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    g_dd = ntk(x_train, None, 'ntk')

    steps = 20
    if name == 'theoretical':
      step_size = predict.max_learning_rate(
          g_dd, num_outputs=out_logits) * lr_factor
    else:
      step_size = predict.max_learning_rate(g_dd, num_outputs=-1) * lr_factor
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    def get_loss(opt_state):
      return loss(get_params(opt_state), x_train)

    init_loss = get_loss(opt_state)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)
    # Above the critical learning rate (lr_factor == 3.) training should
    # diverge; below it (lr_factor == 0.5) the loss should drop sharply.
    if lr_factor == 3.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
    else:
      self.assertLess(loss_ratio, 0.1)
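# A minimal sketch (not the library implementation) of the bound the test
# above exercises: for the linearized MSE dynamics under gradient descent,
# f_{t+1} = f_t - (lr / train_size) * g_dd @ (f_t - y), training converges
# iff lr < 2 * train_size / lambda_max(g_dd). The helper and its argument
# names below are illustrative assumptions; `predict.max_learning_rate`
# is the API actually under test.
import numpy as onp

def _critical_lr_sketch(ntk_matrix, train_size):
  # Assumes `ntk_matrix` is the symmetric 2-D train-train NTK; its largest
  # eigenvalue sets the fastest-moving mode of the linearized loss.
  lam_max = onp.linalg.eigvalsh(onp.asarray(ntk_matrix)).max()
  return 2.0 * train_size / lam_max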
def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel):
  # Build a [2] int32 stateless key and split it, mirroring JAX's PRNG.
  key = stateless_uniform(
      shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32)

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = np.asarray(normal(train_shape, seed=split))

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  # Random binary targets, equivalent to a Bernoulli(0.5) draw.
  y_train = np.asarray(
      stateless_uniform(
          shape=(train_shape[0], out_logits), seed=split, minval=0,
          maxval=1) < 0.5, np.float32)

  # Regress to an MSE loss. `f` and `get_params` are assigned inside the
  # loop below; the closures resolve them at call time.
  loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
  grad_loss = jit(grad(loss))

  def get_loss(opt_state):
    return loss(get_params(opt_state), x_train)

  steps = 20

  for lr_factor in [0.5, 3.]:
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')

    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size) * lr_factor
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    init_loss = get_loss(opt_state)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor == 3.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
    else:
      self.assertLess(loss_ratio, 0.1)
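# `tf_random_split` above plays the role of `jax.random.split` for TF
# stateless RNGs. A minimal sketch of one way such a helper could look,
# assuming a [2] int32 key as used in the test; this is an illustration,
# not the library's implementation.
import tensorflow as tf

def _tf_random_split_sketch(key, num=2):
  # Derive `num` fresh [2] int32 keys from the parent key using the
  # full-integer-range mode of tf.random.stateless_uniform
  # (minval=None, maxval=None with an integer dtype).
  return tf.random.stateless_uniform(
      shape=[num, 2], seed=key, minval=None, maxval=None, dtype=tf.int32)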
def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel,
                        lr_factor, momentum):
  key = random.PRNGKey(0)

  key, split = random.split(key)
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = random.normal(split, train_shape)

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

  # Regress to an MSE loss. `f` and `get_params` are assigned below; the
  # closures resolve them at call time.
  loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
  grad_loss = jit(grad(loss))

  def get_loss(opt_state):
    return loss(get_params(opt_state), x_train)

  steps = 30

  params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
  g_dd = ntk(x_train, None, 'ntk')

  step_size = predict.max_learning_rate(
      g_dd, y_train_size=y_train.size, momentum=momentum) * lr_factor
  opt_init, opt_update, get_params = optimizers.momentum(
      step_size, mass=momentum)
  opt_state = opt_init(params)

  init_loss = get_loss(opt_state)

  for i in range(steps):
    params = get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

  trained_loss = get_loss(opt_state)
  loss_ratio = trained_loss / (init_loss + 1e-12)
  if lr_factor < 1.:
    self.assertLess(loss_ratio, 0.1)
  elif lr_factor == 1.:
    # At the threshold, the loss decays slowly.
    self.assertLess(loss_ratio, 1.)
  if lr_factor > 2.:
    if not math.isnan(loss_ratio):
      self.assertGreater(loss_ratio, 10.)
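# A minimal sketch of the stability boundary the momentum test relies on:
# for the linearized MSE dynamics, the critical step size is assumed to
# scale as 2 * (1 + momentum) * y_train_size / lambda_max(g_dd), so runs
# with lr_factor < 1 should converge and lr_factor > 2 should diverge.
# The 1-D quadratic below (curvature `lam`) illustrates that boundary with
# a heavy-ball update; all names here are illustrative, not the tested API.
def _momentum_diverges(lam, lr, momentum, steps=200):
  x, v = 1.0, 0.0
  for _ in range(steps):
    v = momentum * v - lr * lam * x  # heavy-ball velocity update
    x = x + v
  return abs(x) > 1.0

_lam = 10.0
_lr_crit = 2.0 * (1.0 + 0.9) / _lam  # critical step size at momentum 0.9
assert not _momentum_diverges(_lam, 0.5 * _lr_crit, 0.9)  # stable below
assert _momentum_diverges(_lam, 3.0 * _lr_crit, 0.9)  # unstable above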