def test_behler_parrinello_network_neighbor_list(self, N_types, dtype):
    """Smoke-test the neighbor-list Behler-Parrinello network on three particles.

    Initializes the network in a periodic box, evaluates it and its gradient
    with respect to argument 1 (the coordinates array) under `jit`, and checks
    that neither result contains NaNs and that the gradient has shape [3, 3].
    """
    rng_key = random.PRNGKey(1)
    positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)

    # Species labels are only supplied when more than one type is requested.
    if N_types > 1:
        species = np.array([1, 1, N_types])
    else:
        species = None

    box_size = f32(1.5)
    displacement, _ = space.periodic(box_size)
    network = energy.behler_parrinello_neighbor_list(
        displacement, box_size, species)
    neighbor_fn, nn_init, nn_apply = network

    neighbors = neighbor_fn(positions)
    params = nn_init(rng_key, positions, neighbors)

    gradient_fn = grad(nn_apply, argnums=1)
    nn_force = jit(gradient_fn)(params, positions, neighbors)
    nn_energy = jit(nn_apply)(params, positions, neighbors)

    for result in (nn_energy, nn_force):
        self.assertAllClose(np.any(np.isnan(result)), False)
    self.assertAllClose(nn_force.shape, [3, 3])
def testGammaGrad(self, alpha):
    """Compare the gradient of gamma sampling with a finite-difference estimate.

    The expected gradient is -(d cdf / d alpha) / pdf evaluated at the drawn
    samples, where the cdf derivative is approximated by a central difference
    using scipy.
    """
    rng = random.PRNGKey(0)
    alphas = onp.full((100, ), alpha)
    samples = random.gamma(rng, alphas)

    def summed_samples(a):
        return random.gamma(rng, a).sum()

    actual_grad = api.grad(summed_samples)(alphas)

    # Central difference of the CDF with respect to the shape parameter.
    eps = 0.01 * alpha / (1.0 + onp.sqrt(alpha))
    cdf_above = scipy.stats.gamma.cdf(samples, alpha + eps)
    cdf_below = scipy.stats.gamma.cdf(samples, alpha - eps)
    cdf_dot = (cdf_above - cdf_below) / (2 * eps)
    pdf = scipy.stats.gamma.pdf(samples, alpha)
    expected_grad = -cdf_dot / pdf

    # A looser relative tolerance is used on TPU.
    if jtu.device_under_test() == "tpu":
        rtol = 2e-2
    else:
        rtol = 5e-4
    self.assertAllClose(
        actual_grad, expected_grad, check_dtypes=True, rtol=rtol)
def main(unused_argv):
    """Train a Dense-Erf-Dense network on MNIST and compare with its NTK prediction.

    Runs full-batch SGD on an MSE loss, computes the function-space prediction
    from the empirical NTK evaluated at the initial parameters, and prints a
    summary for the train and test sets.

    Args:
      unused_argv: Ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    dataset = datasets.mnist(FLAGS.train_size, FLAGS.test_size)
    x_train, y_train, x_test, y_test = dataset

    # Build the network.
    network_layers = (
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05),
    )
    init_fn, apply_fn, _ = stax.serial(*network_layers)
    _, params = init_fn(random.PRNGKey(0), (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    opt_state = opt_init(params)

    # Create an MSE loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)

    def network_loss(params, x, y):
        return loss(apply_fn(params, x), y)

    grad_loss = jit(grad(network_loss))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = batch(get_ntk_fun_empirical(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network. `params` is read before each update, so after the
    # loop it holds the parameters from before the final optimizer step.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))
    for step in range(train_steps):
        params = get_params(opt_state)
        gradients = grad_loss(params, x_train, y_train)
        opt_state = opt_apply(step, gradients, opt_state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    summaries = (
        ('train', y_train, x_train, fx_train),
        ('test', y_test, x_test, fx_test),
    )
    for name, labels, inputs, linear_fx in summaries:
        util.print_summary(name, labels, apply_fn(params, inputs), linear_fx,
                           loss)
def test_custom_root_scalar(self):
    """Check `lax.custom_root` on a scalar root found by bisection.

    Solves y ** 2 - x ** 3 = 0 for y at x = 5 and compares the value, first
    derivative and higher-order gradients with x ** 1.5, eagerly and under jit.
    """
    # TODO(shoyer): Figure out why this fails and re-enable it, if possible. My
    # best guess is that TPUs use less stable numerics for pow().
    if jtu.device_under_test() == "tpu":
        raise SkipTest("Test fails on TPU")

    def scalar_solve(f, y):
        return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
        del x0  # unused

        def interval_too_wide(bounds):
            lo, hi = bounds
            return hi - lo > tolerance

        def bisect(bounds):
            lo, hi = bounds
            midpoint = 0.5 * (lo + hi)
            update_upper = func(midpoint) > 0
            new_lo = np.where(update_upper, lo, midpoint)
            new_hi = np.where(update_upper, midpoint, hi)
            return (new_lo, new_hi)

        solution, _ = lax.while_loop(interval_too_wide, bisect, (low, high))
        return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
        def residual(y):
            return y ** 2. - np.array(x) ** 3.

        return lax.custom_root(residual, 0.0, binary_search, tangent_solve)

    value, derivative = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    self.assertAllClose(derivative, api.grad(pow)(5.0, 1.5),
                        check_dtypes=False, rtol=1e-7)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    jitted_result = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(jitted_result, 5.0 ** 1.5, check_dtypes=False,
                        rtol={onp.float64:1e-7})
def main(unused_argv):
    """Train a Dense-Tanh-Dense network on MNIST and compare with its NTK prediction.

    Runs full-batch SGD on an MSE loss, builds the analytic MSE predictor from
    `tangents.ntk` evaluated at the initial parameters, and prints a summary
    for the train and test sets.

    Args:
      unused_argv: Ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    dataset = datasets.mnist(FLAGS.train_size, FLAGS.test_size)
    x_train, y_train, x_test, y_test = dataset

    # Build the network.
    init_fn, apply_fn = stax.serial(
        layers.Dense(4096), stax.Tanh, layers.Dense(10))
    _, params = init_fn(random.PRNGKey(0), (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    opt_state = opt_init(params)

    # Create an MSE loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)

    def network_loss(params, x, y):
        return loss(apply_fn(params, x), y)

    grad_loss = jit(grad(network_loss))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(apply_fn, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network. `params` is read before each update, so after the
    # loop it holds the parameters from before the final optimizer step.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))
    for step in range(train_steps):
        params = get_params(opt_state)
        gradients = grad_loss(params, x_train, y_train)
        opt_state = opt_apply(step, gradients, opt_state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    summaries = (
        ('train', y_train, x_train, fx_train),
        ('test', y_test, x_test, fx_test),
    )
    for name, labels, inputs, linear_fx in summaries:
        util.print_summary(name, labels, apply_fn(params, inputs), linear_fx,
                           loss)
def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                              dnums, slice_sizes, rng_factory,
                              rng_idx_factory):
    """Check vmap of a gather gradient when operand and indices are both batched.

    The batched gradient is compared with the stack of gradients computed one
    slice at a time along `op_axis` / `idxs_axis`.
    """
    rng = rng_factory()
    # The index RNG is built but not otherwise used in this test.
    rng_idx = rng_idx_factory()
    gather = partial(lax.gather, dimension_numbers=dnums,
                     slice_sizes=slice_sizes)

    def summed_sin(x, idx):
        return np.sum(np.sin(gather(x, idx)))

    gfun = grad(summed_sin)
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)

    # Index tuples that select element i along the batched axis of each input.
    whole_axis = (slice(None), )
    per_example = []
    for i in range(idxs.shape[idxs_axis]):
        operand_i = operand[whole_axis * op_axis + (i, )]
        idxs_i = idxs[whole_axis * idxs_axis + (i, )]
        per_example.append(gfun(operand_i, idxs_i))
    expected = onp.stack(per_example)

    self.assertAllClose(ans, expected, check_dtypes=False)
def testNpMaximumPerExampleGrad(self): R = onp.random.RandomState(0).randn x = R(10, 5) W = R(5, 5) fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0) ** 2) ans = vmap(partial(grad(fun), W))(x) W_t = np.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * np.dot( np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex) expected_ans = np.transpose(expected_ans) self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
def hvp(loss, params, batch, v):
    """Computes the hessian vector product Hv.

    Uses forward-over-reverse mode: the gradient of the loss (reverse mode) is
    differentiated in direction `v` with `jvp` (forward mode).

    Args:
      loss: function computing the loss with signature loss(params, batch).
      params: pytree for the parameters of the model.
      batch: A batch of data. Any format is fine as long as it is a valid input
        to loss(params, batch).
      v: pytree of the same structure as params.

    Returns:
      The tangent output of `jvp` applied to the gradient of the loss at
      `params` in direction `v`, i.e. Hv where H is the hessian.
    """
    def loss_on_batch(p):
        return loss(p, batch)

    gradient_fn = grad(loss_on_batch)
    _, hessian_vector_product = jvp(gradient_fn, [params], [v])
    return hessian_vector_product
def test_defvjp_all_multiple_arguments(self):
    """Check `ad.defvjp_all` for a two-argument primitive.

    The registered VJP is deliberately not the true derivative of the output,
    so the assertions confirm that the custom rule is the one being used.
    """
    # also tests passing in symbolic zero tangents b/c we differentiate wrt only
    # the first argument in one case
    foo_p = Primitive('foo')

    def foo(x, y):
        return foo_p.bind(x, y)

    def vjpfun(x, y):
        def pullback(g):
            return (g + x + y, g * x * 9.)

        return x**2 + y**3, pullback

    ad.defvjp_all(foo_p, vjpfun)

    # Differentiate with respect to the first argument only.
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    # Differentiate with respect to both arguments.
    both_grads = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(both_grads, (1. + 3. + 4., 1. * 3. * 9.),
                        check_dtypes=False)
def test_sparse_grad(self):
    """Compare `sparse.grad` on a BCOO matrix with the dense gradient.

    The dense gradient is masked down to the stored indices of the sparse
    matrix before the comparison.
    """
    rng_sparse = rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    Xsp = sparse.BCOO.fromdense(X)

    def f(X, y):
        return jnp.sum(X @ y)

    grad_dense = api.grad(f, argnums=0)(X, y)
    grad_sparse = sparse.grad(f, argnums=0)(Xsp, y)

    # extract sparse gradient from dense gradient
    indices = tuple(Xsp.indices)
    stored_values = grad_dense[indices]
    masked_dense = jnp.zeros_like(grad_dense).at[indices].set(stored_values)

    self.assertArraysEqual(grad_sparse.todense(), masked_dense)
def test_remat_grad_python_control_flow(self):
    """Check value and gradient through `api.remat(concrete=True)` with a Python `if`."""

    @partial(api.remat, concrete=True)
    def branch_on_sign(x):
        if x > 0:
            return lax.sin(x), 3.
        return lax.cos(x), 4.

    def first_output(x):
        result, _ = branch_on_sign(x)
        return result

    value = first_output(2.)
    self.assertAllClose(value, onp.sin(2.), check_dtypes=False)

    derivative = api.grad(first_output)(2.)
    self.assertAllClose(derivative, onp.cos(2.), check_dtypes=False)
def testNpMaximumPerExampleGrad(self):
    """Check per-example gradients of a squared-ReLU layer obtained via vmap.

    For every row of `x`, the vmapped gradient with respect to `W` is compared
    with the closed-form expression 2 * relu(W^T x^T) x, transposed. An
    absolute tolerance for float32 is supplied only on TPU.
    """
    randn = np.random.RandomState(0).randn
    x = randn(10, 5)
    W = randn(5, 5)

    def fun(W, x):
        rectified = jnp.maximum(jnp.dot(x, W), 0.0)
        return jnp.sum(rectified ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    for i in range(10):
        x_ex = x[i:i + 1]
        rectified = jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0)
        expected_ans = jnp.transpose(2.0 * jnp.dot(rectified, x_ex))

        if jtu.device_under_test() == "tpu":
            atol = {np.float32:5e-2}
        else:
            atol = None
        self.assertAllClose(
            ans[i], expected_ans, check_dtypes=False, atol=atol)
def test_grad_simple(self):
    """Check what `hcb.id_print` emits when the function is differentiated.

    `func` prints `x * 2` and `y * 3` to `testing_stream`; the gradient at 5.0
    is compared with 2 * 5 * 6 and the accumulated stream output is compared
    with the expected text.
    """
    def func(x):
        y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
        return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)

    grad_func = api.grad(func)
    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))

    # The gradient is evaluated while the outfeed receiver is active.
    # NOTE(review): the assertions are placed after the `with` block on the
    # assumption that the receiver must exit before the stream is read - confirm.
    with hcb.outfeed_receiver():
        res_grad = grad_func(jnp.float32(5.))
    self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
    # NOTE(review): the expected text below is kept exactly as it appears in
    # this file, on a single line; the helper's name suggests a line-by-line
    # comparison, so confirm the literal's line breaks are as intended.
    assertMultiLineStrippedEqual(self, """ what: x * 2 10.00 what: y * 3 30.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3 5.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00""", testing_stream.output)
    # Clear the shared stream so its contents do not carry over.
    testing_stream.reset()
def test_soft_sphere(self, spatial_dimension, alpha, dtype):
    """Check `energy.soft_sphere` at zero separation and at contact.

    For random `sigma` and `epsilon` the energy at dr = 0 must equal
    epsilon / alpha and the energy at dr = sigma must be zero. For alpha == 3
    the gradient at dr = sigma is also required to be zero.
    """
    key = random.PRNGKey(0)
    alpha = f32(alpha)
    zero = np.array(0.0, dtype=dtype)

    for _ in range(STOCHASTIC_SAMPLES):
        key, split_sigma, split_epsilon = random.split(key, 3)
        sigma_draw = random.uniform(
            split_sigma, (1,), minval=0.0, maxval=3.0)
        sigma = np.array(sigma_draw[0], dtype=dtype)
        epsilon_draw = random.uniform(
            split_epsilon, (1,), minval=0.0, maxval=4.0)
        epsilon = np.array(epsilon_draw[0], dtype=dtype)

        energy_at_origin = energy.soft_sphere(
            dtype(0), sigma, epsilon, alpha)
        self.assertAllClose(energy_at_origin, epsilon / alpha, True)

        energy_at_contact = energy.soft_sphere(
            dtype(sigma), sigma, epsilon, alpha)
        self.assertAllClose(energy_at_contact, zero, True)

        if alpha == 3.0:
            g = grad(energy.soft_sphere)(dtype(sigma), sigma, epsilon, alpha)
            self.assertAllClose(g, np.array(0, dtype=dtype), True)
def test_root_scalar(self):
    """Check `lax.root` on a scalar root found by bisection.

    Solves y ** 2 - x ** 3 = 0 for y at x = 5 and compares the value, first
    derivative and higher-order gradients with x ** 1.5, eagerly and under jit.
    """

    def scalar_solve(f, y):
        return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
        del x0  # unused

        def interval_too_wide(bounds):
            lo, hi = bounds
            return hi - lo > tolerance

        def bisect(bounds):
            lo, hi = bounds
            midpoint = 0.5 * (lo + hi)
            update_upper = func(midpoint) > 0
            new_lo = np.where(update_upper, lo, midpoint)
            new_hi = np.where(update_upper, midpoint, hi)
            return (new_lo, new_hi)

        solution, _ = lax.while_loop(interval_too_wide, bisect, (low, high))
        return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
        def residual(y):
            return y ** 2 - x ** 3

        return lax.root(residual, 0.0, binary_search, tangent_solve)

    value, derivative = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    self.assertAllClose(derivative, api.grad(pow)(5.0, 1.5),
                        check_dtypes=False)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    jitted_result = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(jitted_result, 5.0 ** 1.5, check_dtypes=False)
def test_morse(self, spatial_dimension, dtype):
    """Check `energy.morse` at its minimum and against the closed form.

    First part: at dr = sigma the energy must be -epsilon and its gradient
    zero. Second part: along dr = a / alpha + sigma the energy divided by
    epsilon must equal (1 - exp(-a)) ** 2 - 1.
    """
    key = random.PRNGKey(0)

    def draw_scalar(split, minval, maxval):
        # One uniform sample, converted to `dtype`.
        return dtype(
            random.uniform(split, (1, ), minval=minval, maxval=maxval)[0])

    for _ in range(STOCHASTIC_SAMPLES):
        key, split_sigma, split_epsilon, split_alpha = random.split(key, 4)
        sigma = draw_scalar(split_sigma, 0., 3.0)
        epsilon = draw_scalar(split_epsilon, 0.0, 4.0)
        alpha = draw_scalar(split_alpha, 1.0, 30.0)
        dr = dtype(sigma)

        minimum_energy = energy.morse(dr, sigma, epsilon, alpha)
        self.assertAllClose(minimum_energy, np.array(-epsilon, dtype=dtype))
        g = grad(energy.morse)(dr, sigma, epsilon, alpha)
        self.assertAllClose(g, np.array(0, dtype=dtype))

    # if dr = a/alpha + sigma, then V_morse(dr, sigma, epsilon, alpha)/epsilon
    # should be independent of sigma, epsilon, and alpha, depending only on a.
    key, split_sigma, split_epsilon, split_alpha = random.split(key, 4)
    sample_shape = (STOCHASTIC_SAMPLES, )
    sigmas = random.uniform(split_sigma, sample_shape, minval=0., maxval=3.0)
    epsilons = random.uniform(
        split_epsilon, sample_shape, minval=0.1, maxval=4.0)
    alphas = random.uniform(
        split_alpha, sample_shape, minval=1.0, maxval=30.0)

    for sigma, epsilon, alpha in zip(sigmas, epsilons, alphas):
        a = np.linspace(max(-2.5, -alpha * sigma), 8.0, 100)
        dr = np.array(a / alpha + sigma, dtype=dtype)
        U = energy.morse(dr, sigma, epsilon, alpha) / dtype(epsilon)
        closed_form = (dtype(1) - np.exp(-a))**dtype(2) - dtype(1)
        Ucomp = np.array(closed_form, dtype=dtype)
        self.assertAllClose(U, Ucomp)
def testStopGradient(self):
    """Check `lax.stop_gradient` under first and second derivatives.

    Also checks that it works on a dict input and that passing a function
    raises TypeError when core checks are skipped.
    """

    def f(x):
        return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
        return lax.sin(x) * lax.cos(y)

    x = 3.14

    # First derivative: f must match f2 differentiated in its first argument.
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    # Second derivative.
    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    # A dict wrapped in stop_gradient yields a zero gradient.
    def through_dict(x):
        return lax.stop_gradient({'foo':x})['foo']

    ans = api.grad(through_dict)(3.)
    self.assertAllClose(ans, onp.array(0.0), check_dtypes=False)

    with core.skipping_checks():
        with self.assertRaises(TypeError):
            lax.stop_gradient(lambda x: x)
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with an optionally auto-decayed L2 penalty.

    All configuration is read from `FLAGS`. Training is `pmap`-ed over the
    devices in use. Every `meas_step` steps the model is evaluated on the test
    set and on one training batch; every `10 * meas_step` steps the collected
    measurements are written to a CSV file under `wrnlogs/`. When
    `FLAGS.checkpointing` is set, flattened parameters are saved every
    `FLAGS.checkpointing` steps.

    NOTE(review): this function was reconstructed from whitespace-collapsed
    text. The nesting of the L2-schedule and CSV blocks inside the measurement
    branch of the training loop is inferred - confirm against the upstream
    script.

    Args:
      unused_argv: Ignored.
    """
    from jax.api import grad, jit, vmap, pmap, device_put
    "The following is required to use TPU Driver as JAX's backend."

    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    # Every pmap in this function uses the axis name 'i' (see lax.psum below).
    pmap = partial(pmap, axis_name='i')
    """Setup some experiment parameters."""
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # With FLAGS.physical the number of epochs is rescaled by the learning
    # rate (and by L2 when FLAGS.physicalL2 is also set).
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    "Filename from FLAGS"

    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    # From here on `filedir` is the full path of the CSV log file.
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))
    """Load CIFAR10 data and create a minimal pipeline."""

    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)
    """Create a Wide Resnet and replicate its parameters across the devices."""

    initparams, f, _ = utils.WideResnetnt(N, K, k)

    "Loss and optimizer definitions"

    # Sum of squared entries over every leaf of the parameter tree.
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    # Per-device copy of the current L2 coefficient.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        """Cross-entropy loss."""
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    # Full objective: bare loss plus the weighted L2 penalty.
    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        """Fraction of examples whose argmax prediction matches the label."""
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    "Define optimizer"

    if FLAGS.std_wrn_sch:
        # Piecewise-constant schedule: the rate is multiplied by 0.2 at three
        # evenly spaced boundaries.
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    # batch_fn, get_params, grad_loss and apply_fn are bound further down,
    # before the first call to update_step / evaluate.
    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        # Average the gradients across devices.
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    "Initialization and loading"

    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    # Step at which the L2 coefficient was last decayed.
    idel0 = i0
    start = time.time()

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    "Start training loop"
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1] and
                                        minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1] and
                                        maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if
                # the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except NameError:
                    # maxacc / minloss are not bound yet the first time this
                    # branch runs (UnboundLocalError is a NameError subclass);
                    # seed them from the previous measurement.
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    # Keep only the first device's copy of each leaf.
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        # Write the log path to a marker file named after the TPU address.
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
def update(i, opt_state):
    """Apply one optimizer step using the gradient of `objective`.

    Args:
      i: Step index; passed to both `objective` and `opt_update`.
      opt_state: Current optimizer state.

    Returns:
      The value returned by `opt_update` for this step.
    """
    current_params = minmax.get_params(opt_state)
    objective_grad = grad(objective)
    return opt_update(i, objective_grad(current_params, i), opt_state)
def apply_carry(carry, _):
    """Take one gradient-descent step of size 0.1 on `energy_fn`.

    Args:
      carry: Pair `(i, x)` of a counter and the current point.
      _: Second argument; returned unchanged as the second output.

    Returns:
      `((i + 1, new_x), _)` where `new_x = x - 0.1 * grad(energy_fn)(x)`.
    """
    counter, x = carry
    gradient = api.grad(energy_fn)(x)
    stepped = x - 0.1 * gradient
    return (counter + 1, stepped), _
def apply_carry(x, i):
    """Return the gradient of `fn` at `x` together with `i` unchanged.

    The gradient is requested with `argnums=(0, )`, so `api.grad` returns a
    one-element tuple; its single entry is what gets returned.
    """
    gradients = api.grad(fn, argnums=(0, ))(x)
    return gradients[0], i
# Section: train the finite-width network with SGD and plot the NTK losses.
st.header("Comparison against finite SGD-NNs")
# Bare string expression - presumably rendered by Streamlit as page text
# (TODO confirm); its content is runtime text and is left untouched.
""" Finally, let's bring back the practioner loved Neural Networks for a comparison. """
# Learning-rate slider: range 1e-4 to 1.0, default 0.1.
learning_rate = st.slider("Learning rate", 1e-4, 1.0, 0.1, step=1e-4, format="%.4f")
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)
# MSE loss of the network output, and its gradient taken from an optimizer state.
loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))
train_losses = []
test_losses = []
opt_state = opt_init(params)
# NOTE(review): the two loss-recording lines are placed inside the loop (one
# entry per step); the original indentation is not visible - confirm.
for i in range(training_steps):
    opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    train_losses += [loss(get_params(opt_state), *train)]
    test_losses += [loss(get_params(opt_state), *test)]
# NTK loss
plt.loglog(ts, ntk_train_loss_mean, linewidth=3)
plt.loglog(ts, ntk_test_loss_mean, linewidth=3)
def testScanRnn(self):
    """Exercise `lax.scan` with a small tanh RNN under jvp, linearize, grad and vmap."""
    r = npr.RandomState(0)

    n_in, n_hid, n_out = 4, 2, 1
    length = 3

    W_trans = r.randn(n_hid, n_hid + n_in)
    W_out = r.randn(n_out, n_hid + n_in)
    params = W_trans, W_out

    inputs = r.randn(length, n_in)
    targets = r.randn(length, n_out)

    def step(params, state, input):
        W_trans, W_out = params
        stacked = np.concatenate([state, input])
        output = np.tanh(np.dot(W_out, stacked))
        next_state = np.tanh(np.dot(W_trans, stacked))
        return next_state, output

    def rnn(params, inputs):
        init_state = np.zeros(n_hid)
        _, outputs = lax.scan(partial(step, params), init_state, inputs)
        return outputs

    def loss(params, inputs, targets):
        residual = rnn(params, inputs) - targets
        return np.sum(residual**2)

    all_args = (params, inputs, targets)

    # evaluation doesn't crash
    loss(*all_args)

    # jvp evaluation doesn't crash
    def loss_of_params(params):
        return loss(params, inputs, targets)

    api.jvp(loss_of_params, (params, ), (params, ))

    # jvp numerical check passes
    jtu.check_grads(loss, all_args, order=2, modes=["fwd"])

    # linearize works
    _, expected = api.jvp(loss, all_args, all_args)
    _, linfun = api.linearize(loss, *all_args)
    ans = linfun(*all_args)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # gradient evaluation doesn't crash
    api.grad(loss)(*all_args)

    # gradient check passes
    jtu.check_grads(loss, all_args, order=2)

    # we can vmap to batch things
    batch_size = 7
    batched_inputs = r.randn(batch_size, length, n_in)
    batched_targets = r.randn(batch_size, length, n_out)

    def loss_of_data(x, y):
        return loss(params, x, y)

    losses = api.vmap(loss_of_data)(batched_inputs, batched_targets)
    expected = onp.stack([
        loss_of_data(x, y) for x, y in zip(batched_inputs, batched_targets)
    ])
    self.assertAllClose(losses, expected, check_dtypes=False)
def update(_, i, opt_state, batch):
    """Apply one optimizer step using the gradient of `loss` on `batch`.

    Args:
      _: Ignored.
      i: Step index passed to `opt_update`.
      opt_state: Current optimizer state.
      batch: Data passed to `loss` together with the current parameters.

    Returns:
      The value returned by `opt_update` for this step.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    return opt_update(i, gradients, opt_state)
def update(params, batch):
    """Return new parameters after one plain gradient-descent step.

    Args:
      params: Sequence of `(w, b)` pairs.
      batch: Data passed to `loss` together with `params`.

    Returns:
      A list of `(w, b)` pairs, each moved by `step_size` times its gradient.
    """
    grads = grad(loss)(params, batch)
    stepped = []
    for (w, b), (dw, db) in zip(params, grads):
        stepped.append((w - step_size * dw, b - step_size * db))
    return stepped
def update(params, batch):
    """Return new parameters after one plain gradient-descent step.

    Args:
      params: Sequence of `(w, b)` pairs.
      batch: Data passed to `loss` together with `params`.

    Returns:
      A list of `(w, b)` pairs, each moved by `learning_rate` times its
      gradient.
    """
    grads = grad(loss)(params, batch)
    stepped = []
    for (w, b), (dw, db) in zip(params, grads):
        stepped.append((w - learning_rate * dw, b - learning_rate * db))
    return stepped
#new define of loss function def computation(params, inputs, targets): logits = predict(params, inputs) preds = stax.logsoftmax(logits) return -np.mean(np.sum(preds * targets, axis=1)) #set up of index tl = test_labels index7 = tl.tolist().index([0, 0, 0, 0, 0, 0, 0, 1, 0, 0]) print(test_labels[index7]) #computing process to the new x input_image, input_label = shape_as_image(test_images[index7], test_labels[index7]) grad_newx = grad(computation, 1)(params, input_image, input_label) newx = input_image + hyper * np.sign(grad_newx) #start plot and its predicted vector target_class = np.argmax(input_label) predicted_class = np.argmax(predict(params, newx)) #predicted vector predict_vector = predict(params, newx) print('the target class is :', target_class) print('the predict class is :', predicted_class) print('the predicted vector is :', predict_vector) image = np.array(newx) image = image * 255 image = image.reshape(28, 28) plt.imshow(image) """## From here is Part 2"""
def testGradOfXlog1pyAtZero(self):
    """The gradient of xlog1py(0., y) with respect to y at y = -1 must be 0."""
    xlog1py_with_zero_x = functools.partial(lsp_special.xlog1py, 0.)
    gradient = api.grad(xlog1py_with_zero_x)(-1.)
    self.assertAllClose(gradient, 0., check_dtypes=False)
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
    """Compare an SGD-trained ensemble with the NTK GP-inference prediction.

    Trains 50 independently initialized Dense-Erf-Dense networks with
    full-batch SGD, then compares the empirical mean and covariance of their
    test outputs with `predict.gp_inference(..., 'ntk', ...)`.
    """
    on_gpu = xla_bridge.get_backend().platform == 'gpu'
    if on_gpu and config.read('jax_enable_x64'):
        raise jtu.SkipTest('Not running GPU x64 to save time.')

    training_steps = 5000
    learning_rate = 1.0
    ensemble_size = 50

    init_fn, apply_fn, ker_fn = stax.serial(
        stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_update = jit(opt_update)

    # Random inputs and binary targets.
    key = random.PRNGKey(0)
    key, = random.split(key, 1)

    key, split = random.split(key)
    x_train = np.cos(random.normal(split, train_shape))

    key, split = random.split(key)
    labels = random.bernoulli(split, shape=(train_shape[0], out_logits))
    y_train = np.array(labels, np.float32)
    train = (x_train, y_train)

    key, split = random.split(key)
    x_test = np.cos(random.normal(split, test_shape))

    ensemble_key = random.split(key, ensemble_size)

    def mse(params, x, y):
        return 0.5 * np.mean((apply_fn(params, x) - y)**2)

    loss = jit(mse)

    def grad_from_state(state, x, y):
        return grad(loss)(get_params(state), x, y)

    grad_loss = jit(grad_from_state)

    def train_network(key):
        _, params = init_fn(key, (-1, ) + train_shape[1:])
        opt_state = opt_init(params)
        for i in range(training_steps):
            gradients = grad_loss(opt_state, *train)
            opt_state = opt_update(i, gradients, opt_state)
        return get_params(opt_state)

    params = vmap(train_network)(ensemble_key)
    ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)

    # The ensemble must have fit the training data.
    per_member_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
    ensemble_loss = np.mean(per_member_loss)
    self.assertLess(ensemble_loss, 1e-5, True)

    # Empirical mean and covariance across ensemble members and outputs.
    mean_emp = np.mean(ensemble_fx, axis=0)
    mean_subtracted = ensemble_fx - mean_emp
    normalizer = mean_subtracted.shape[0] * mean_subtracted.shape[-1]
    cov_emp = np.einsum(
        'ijk,ilk->jl', mean_subtracted, mean_subtracted,
        optimize=True) / normalizer

    reg = 1e-7
    ntk_predictions = predict.gp_inference(
        ker_fn, x_train, y_train, x_test, 'ntk', reg, compute_cov=True)

    self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
    self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                        ATOL)
def testNTKMomentumPrediction(self, train_shape, test_shape, network,
                              out_logits, fn_and_kernel):
    """Compare momentum training of a network with `predict.momentum`.

    Trains with the momentum optimizer on an MSE loss and checks that the
    deviation from the linearized prediction, normalized by how far the
    outputs moved from their initial values, is close to zero on both the
    train and test inputs.
    """
    key = random.PRNGKey(0)

    key, split = random.split(key)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    labels = random.bernoulli(split, shape=(train_shape[0], out_logits))
    y_train = np.array(labels, np.float32)

    key, split = random.split(key)
    x_test = random.normal(split, test_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Regress to an MSE loss.
    loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)

    def train_loss(params, x):
        return loss(f(params, x), y_train)

    grad_loss = jit(grad(train_loss))

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    if len(train_shape) > 2:
        # Hacky way to up the tolerance just for convolutions.
        atol, rtol, step_size = ATOL * 2, RTOL * 2, 0.1
    else:
        atol, rtol, step_size = ATOL, RTOL, 0.5

    train_time = 100.0
    steps = int(train_time / np.sqrt(step_size))

    init, predictor, get = predict.momentum(g_dd, y_train, loss, step_size,
                                            g_td)

    opt_init, opt_update, get_params = momentum(step_size, 0.9)
    opt_state = opt_init(params)

    fx_initial_train = f(params, x_train)
    fx_initial_test = f(params, x_test)

    # Before any training the linear state must reproduce the initial outputs.
    lin_state = init(fx_initial_train, fx_initial_test)
    fx_pred_train, fx_pred_test = get(lin_state)
    self.assertAllClose(fx_initial_train, fx_pred_train, True)
    self.assertAllClose(fx_initial_test, fx_pred_test, True)

    for i in range(steps):
        params = get_params(opt_state)
        opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    lin_state = predictor(lin_state, train_time)
    fx_pred_train, fx_pred_test = get(lin_state)

    # Normalize the errors by the RMS displacement from the initial outputs.
    fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
    fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

    fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
    fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

    self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), True,
                        rtol, atol)
    self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True,
                        rtol, atol)