def testMaxLearningRate(self, train_shape, network, out_logits,
                        fn_and_kernel):
  """Checks that `predict.max_learning_rate` bounds the stable SGD step size.

  Trains with SGD for a fixed number of steps on an MSE loss, using a step
  size of 0.5x and 3x the value returned by `predict.max_learning_rate`.
  Below the bound, the loss must shrink by more than 10x. Above the bound, it
  must grow by more than 10x; a NaN loss ratio above the bound is tolerated,
  and the assertion is skipped in that case.

  Args:
    train_shape: base training input shape. 2-D shapes are enlarged
      (5x examples, 10x features); any other shape is replaced by
      `(16, 8, 8, 3)`.
    network: forwarded unchanged to `fn_and_kernel`.
    out_logits: number of outputs per training example.
    fn_and_kernel: callable returning `(params, f, ntk)` for the network.
  """
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = np.asarray(normal(train_shape, seed=split))

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  # Binary {0, 1} targets.
  y_train = np.asarray(
      stateless_uniform(shape=(train_shape[0], out_logits), seed=split,
                        minval=0, maxval=1) < 0.5, np.float32)

  steps = 20
  for lr_factor in [0.5, 3.]:
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Regress to an MSE loss. The loss and its jitted gradient are built per
    # iteration, after `f` is bound. A jitted function captures closed-over
    # values when it is traced, so building it once outside the loop could
    # keep using a stale `f`.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
    grad_loss = jit(grad(loss))

    g_dd = ntk(x_train, None, 'ntk')
    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size) * lr_factor
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    init_loss = loss(get_params(opt_state), x_train)
    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)
    trained_loss = loss(get_params(opt_state), x_train)

    # The epsilon guards against division by an exactly-zero initial loss.
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor == 3.:
      # Above the bound, training should diverge. A NaN ratio skips the
      # assertion instead of failing it.
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
    else:
      self.assertLess(loss_ratio, 0.1)
def main(unused_argv):
  """Trains an Erf MLP on MNIST and compares it with its NTK prediction.

  Runs full-batch SGD on an MSE loss for
  `int(FLAGS.train_time // FLAGS.learning_rate)` steps. It then prints
  summaries comparing the trained network's outputs with the closed-form
  function-space prediction of `nt.predict.gradient_descent_mse`, evaluated
  at time `FLAGS.train_time`.

  Args:
    unused_argv: command-line arguments (ignored).
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset(
      'mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the network.
  init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05),
                                     stax.Erf(),
                                     stax.Dense(10, 1., 0.05))

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  _, params = init_fn(key, (1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)
  # Inside the loop, `params` is read *before* each update. Fetch the
  # parameters produced by the final step; otherwise the network compared
  # below is one step behind the analytic prediction.
  params = get_params(state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  """Compares a trained ensemble's output statistics with the NTK prediction.

  Trains 1024 independently initialized copies of a Dense-Erf-Dense network
  with full-batch SGD on an MSE loss. It then checks that the empirical mean
  and covariance of their outputs match
  `predict.gradient_descent_mse_ensemble(..., 'ntk', compute_cov=True)`, up
  to a relative tolerance. The check runs once on the training inputs
  (`x=None`) and once on the test inputs.
  """
  n_steps = 1000
  lr = 0.1
  n_members = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(lr)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(
      out_logits, test_shape, train_shape)
  ensemble_predictor = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, learning_rate=lr, diag_reg=0.)

  member_keys = tf_random_split(key, n_members)

  @jit
  def mse(params, inputs, targets):
    return 0.5 * np.mean((apply_fn(params, inputs) - targets)**2)

  @jit
  def mse_grad(opt_state, inputs, targets):
    return grad(mse)(get_params(opt_state), inputs, targets)

  def fit_one(member_key):
    # Initialize one ensemble member and train it on the full training set.
    _, init_params = init_fn(member_key, (-1, ) + train_shape[1:])
    opt_state = opt_init(init_params)
    for step in range(n_steps):
      opt_state = opt_update(step, mse_grad(opt_state, x_train, y_train),
                             opt_state)
    return get_params(opt_state)

  trained_params = vmap(fit_one)(member_keys)

  rtol = 0.08
  # `label` names the subtest; `x` is what the ensemble predictor receives
  # (`None` selects the training set).
  for label, x in ((None, None), ('x_test', x_test)):
    with self.subTest(x=label):
      x_eval = x_train if x is None else x_test
      member_fx = vmap(apply_fn, (0, None))(trained_params, x_eval)
      mean_emp = np.mean(member_fx, axis=0)
      centered = member_fx - mean_emp
      # Contract over members (i) and outputs (k); normalize by both counts.
      n_samples = centered.shape[0] * centered.shape[-1]
      cov_emp = np.einsum(
          'ijk,ilk->jl', centered, centered, optimize=True) / n_samples
      prediction = ensemble_predictor(n_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, prediction.mean, rtol)
      self._assertAllClose(cov_emp, prediction.covariance, rtol)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                        fn_and_kernel, momentum, learning_rate, t, loss):
  """Checks NTK gradient-descent predictions against actually training `f`.

  Builds a predictor from `predict.gradient_descent_mse` (when
  `loss == 'mse_analytic'`) or from `predict.gradient_descent` (when
  `loss == 'mse'`) and runs the zero-time, multi-step and (analytic only)
  infinite-time helper checks. It then trains `f` for `t` steps with SGD
  (or momentum SGD) and asserts that the network outputs match the
  predictor's outputs at step `t`.
  """
  key, x_test, x_train, y_train = self._get_inputs(
      out_logits, test_shape, train_shape)
  params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

  k_train_train = ntk(x_train, None, 'ntk')
  k_test_train = ntk(x_test, x_train, 'ntk')

  # Regress to an MSE loss.
  def mse(y, y_hat):
    return 0.5 * np.mean((y - y_hat)**2)

  grad_loss = jit(grad(lambda p, x: mse(f(p, x), y_train)))

  # A 4-D train-train kernel is used with no traced axes. Otherwise the last
  # output axis is treated as traced.
  trace_axes = () if k_train_train.ndim == 4 else (-1, )

  if loss == 'mse_analytic':
    if momentum is not None:
      raise absltest.SkipTest(momentum)
    predictor = predict.gradient_descent_mse(
        k_train_train, y_train, learning_rate=learning_rate,
        trace_axes=trace_axes)
  elif loss == 'mse':
    predictor = predict.gradient_descent(
        mse, k_train_train, y_train, learning_rate=learning_rate,
        momentum=momentum, trace_axes=trace_axes)
  else:
    raise NotImplementedError(loss)

  predictor = jit(predictor)

  fx_train_0 = f(params, x_train)
  fx_test_0 = f(params, x_test)

  self._test_zero_time(predictor, fx_train_0, fx_test_0, k_test_train,
                       momentum)
  self._test_multi_step(predictor, fx_train_0, fx_test_0, k_test_train,
                        momentum)
  if loss == 'mse_analytic':
    self._test_inf_time(predictor, fx_train_0, fx_test_0, k_test_train,
                        y_train)

  optimizer = (optimizers.sgd(learning_rate) if momentum is None else
               optimizers.momentum(learning_rate, momentum))
  opt_init, opt_update, get_params = optimizer

  opt_state = opt_init(params)
  for step in range(t):
    opt_state = opt_update(step, grad_loss(get_params(opt_state), x_train),
                           opt_state)
  trained_params = get_params(opt_state)

  nn_train = f(trained_params, x_train)
  nn_test = f(trained_params, x_test)
  ntk_train, ntk_test = predictor(t, fx_train_0, fx_test_0, k_test_train)
  self.assertAllClose(nn_train, ntk_train, rtol=RTOL, atol=ATOL)
  self.assertAllClose(nn_test, ntk_test, rtol=RTOL, atol=ATOL)
def main(unused_argv):
  """Trains an Erf MLP and its linearization on MNIST and compares them.

  The network `f` and its linearization `f_lin` (taken about the initial
  parameters) are both trained over minibatches on a cross-entropy loss. The
  same momentum optimizer functions (coefficient 0.9) are applied to two
  independent optimizer states. Both losses are printed once per epoch. At
  the end, summaries are printed on the first 10000 training examples and on
  the test set.

  Args:
    unused_argv: command-line arguments (ignored).
  """
  # Build data.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05),
                              stax.Erf(),
                              stax.Dense(10, 1., 0.05))

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  _, params = init_fn(key, (1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin. Each model keeps
  # its own optimizer state.
  opt_init, opt_apply, get_params = momentum(FLAGS.learning_rate, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  # NOTE(review): assumes 50000 training examples per epoch. Confirm this
  # matches the size of `x_train` returned by `datasets.get_dataset`.
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(
      datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                         FLAGS.train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Inside the loop, `params` / `params_lin` are read *before* each update.
  # Fetch the parameters produced by the final step. This also defines
  # `params_lin` when the minibatch iterator yields nothing.
  params = get_params(state)
  params_lin = get_params(state_lin)

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', y_test, f(params, x_test),
                     f_lin(params_lin, x_test), loss)
def gradient_descent(
    loss: Callable[[np.ndarray, np.ndarray], float],
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    momentum: float = None,
    trace_axes: Axes = (-1, )
) -> Callable[[
    ArrayOrScalar, Union[ArrayOrScalar, ODEState], ArrayOrScalar,
    Optional[np.ndarray]
], Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]]:
  r"""Predicts the outcome of function space training using gradient descent.

  Uses an ODE solver. If `momentum != None`, solves a continuous-time version of
  gradient descent with momentum (note: this case uses standard momentum as
  opposed to Nesterov momentum).

  Solves the function space ODE for [momentum] gradient descent with a given
  `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s]
  at arbitrary time[s] (step[s]) `t`. Note that for gradient descent
  `absolute_time = learning_rate * t` and the scales of the learning rate and
  query step[s] `t` are interchangeable. However, the momentum gradient descent
  ODE is solved in the units of `learning_rate**0.5`, and therefore
  `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and
  training time[s] (step[s]) `t` scales are not interchangeable.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>>
    >>> t = 1e-7
    >>> learning_rate = 1e-2
    >>> momentum = 0.9
    >>>
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
    >>>                                       y_train, learning_rate, momentum)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    loss:
      a loss function whose signature is `loss(f(x_train), y_train)`. Note:
      the loss function should treat the batch and output dimensions
      symmetrically.
    k_train_train:
      kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train:
      targets for the training data.
    learning_rate:
      learning rate, step size.
    momentum:
      momentum scalar. `None` means plain (momentum-free) gradient descent.
    trace_axes:
      `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to save
      space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield most accurate
      result.

  Returns:
    A function that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
  _, odd, _, _ = _get_axes(k_train_train)
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  # Output axes that are not traced. These are contracted against the kernel's
  # `odd` axes (as returned by `_get_axes`) in `tensordot` below.
  non_t_axes = tuple(a for a in range(y_train.ndim) if a not in trace_axes)
  # `tensordot` leaves the uncontracted (traced) axes of the gradient last.
  # `moveaxis` returns them from these positions to `trace_axes`.
  last_t_axes = range(-len(trace_axes), 0)

  dtype = k_train_train.dtype
  grad_loss = grad(lambda fx: loss(fx, y_train))

  if momentum is not None:
    # The momentum ODE is solved in units of `learning_rate**0.5` (see the
    # docstring). Rescale the time unit and the momentum coefficient to match.
    learning_rate **= 0.5
    momentum = (momentum - 1.0) / learning_rate

  def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
    # Build the initial `ODEState`, either from raw outputs or from a passed
    # state. Missing entries are zero-filled and given entries are broadcast.
    if isinstance(fx_train_or_state_0, ODEState):
      fx_train_0 = fx_train_or_state_0.fx_train
      fx_test_0 = fx_train_or_state_0.fx_test
      qx_train_0 = fx_train_or_state_0.qx_train
      qx_test_0 = fx_train_or_state_0.qx_test
    else:
      fx_train_0 = fx_train_or_state_0
      qx_train_0 = qx_test_0 = None

    if fx_train_0 is None:
      fx_train_0 = np.zeros_like(y_train, dtype)
    else:
      fx_train_0 = np.broadcast_to(fx_train_0, y_train.shape)

    if fx_test_0 is not None:
      fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

    if momentum is None:
      if qx_train_0 is not None or qx_test_0 is not None:
        raise ValueError('Got passed momentum state variables, while '
                         '`momentum is None`.')
    else:
      qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None else
                    np.broadcast_to(qx_train_0, y_train.shape))
      qx_test_0 = (None if fx_test_0 is None else
                   (np.zeros(fx_test_shape, dtype) if qx_test_0 is None
                    else np.broadcast_to(qx_test_0, fx_test_shape)))

    return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count

  def get_dstate_dt(k_test_train):
    def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
      fx_train_t, fx_test_t, qx_train_t, qx_test_t = (
          state_t.fx_train, state_t.fx_test, state_t.qx_train,
          state_t.qx_test)

      dy_df_t = grad_loss(fx_train_t)

      # Function-space gradient flow: -K @ dLoss/df.
      fx_train_t = -np.moveaxis(
          np.tensordot(k_train_train, dy_df_t, (odd, non_t_axes)),
          last_t_axes, trace_axes)
      if fx_test_t is not None:
        fx_test_t = -np.moveaxis(
            np.tensordot(k_test_train, dy_df_t, (odd, non_t_axes)),
            last_t_axes, trace_axes)

      if momentum is None:
        return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

      fx_train_t += momentum * qx_train_t
      if qx_test_t is not None:
        fx_test_t += momentum * qx_test_t

      # With momentum: d(fx)/dt = qx, and d(qx)/dt is the quantity computed
      # above.
      return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count
    return dstate_dt

  def predict_fn(
      t: ArrayOrScalar = None,
      fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
      fx_test_0: ArrayOrScalar = None,
      k_test_train: np.ndarray = None
  ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent
        of training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer
        with auxiliary variables from a specific state). Note that only the
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed,
        an ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

    t = np.array(t if t is not None else np.inf, dtype) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1, ))

    # The ODE solver requires `t[0]` to be the time where `fx_train_0` [and
    # `fx_test_0`] are evaluated. It also requires a strictly increasing
    # sequence of timesteps. We therefore always temporarily prepend an
    # [almost] `0`, nudged slightly negative when the first query is exactly 0.
    t0 = np.where(t[0] == 0,
                  np.full((1, ), -1e-24, t.dtype),
                  np.zeros((1, ), t.dtype))
    t = np.concatenate([t0, t])

    # Solve the ODE.
    fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
    state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

    # Remove the added `t0` and restore the caller's time shape.
    trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
    trim_tree = lambda tree: tree_map(trim, tree)
    state_t = trim_tree(state_t)

    # `ODEState` -> `ODEState`
    if isinstance(fx_train_or_state_0, ODEState):
      return state_t

    # `np.ndarray` -> `np.ndarray`
    fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

    if fx_train_or_state_0 is not None and fx_test_0 is None:
      return fx_train_t
    if fx_test_0 is not None and fx_train_or_state_0 is None:
      return fx_test_t
    return fx_train_t, fx_test_t

  return predict_fn