def testUnpackPackRoundTrip(self):
    """Unpacking an optimizer state and re-packing it must reproduce it exactly."""
    init_fun, _, _ = optimizers.momentum(0.1, mass=0.9)
    initial_params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
    original_state = init_fun(initial_params)
    unpacked = optimizers.unpack_optimizer_state(original_state)
    round_tripped = optimizers.pack_optimizer_state(unpacked)
    self.assertEqual(round_tripped, original_state)
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    """Build an optimizer from its name.

    Each factory accepts the full (lr, mom, var, eps) argument tuple and
    ignores whatever it does not need, so all entries share one call shape.
    Raises KeyError for an unrecognized optimizer name.
    """
    factories = {
        "gd": lambda lr, *unused: op.sgd(lr),
        "sgd": lambda lr, *unused: op.sgd(lr),
        "momentum": lambda lr, mom, *unused: op.momentum(lr, mom),
        "adam": lambda lr, mom, var, eps: op.adam(lr, mom, var, eps),
    }
    make = factories[optimizer_string.lower()]
    return make(lr, mom, var, eps)
def main(unused_argv):
    """Train an MNIST classifier and its linearization side by side.

    Trains the full network f and its first-order Taylor expansion f_lin
    (about the initial parameters) with identical momentum updates, printing
    per-epoch losses and a final train/test comparison.
    """
    # Load the data.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)
    state = opt_init(params)
    # Both models start from the same parameters, hence the same state.
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
        datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                           FLAGS.train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            # Report both losses once per epoch on the current minibatch.
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                        fn_and_kernel, momentum, learning_rate, t, loss):
    """Check NTK-based GD predictions against actually training the network.

    Builds the empirical NTK predictor for the requested loss, then trains the
    real network for t optimizer steps and asserts the predicted outputs match.
    """
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Empirical NTK blocks: train/train and test/train.
    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
    grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

    # presumably: a 4-dim kernel keeps explicit output axes, otherwise the
    # last (logit) axis was traced out — TODO confirm against ntk() contract.
    trace_axes = () if g_dd.ndim == 4 else (-1,)
    if loss == 'mse_analytic':
        # The closed-form MSE solution has no momentum variant.
        if momentum is not None:
            raise absltest.SkipTest(momentum)
        predictor = predict.gradient_descent_mse(g_dd, y_train,
                                                 learning_rate=learning_rate,
                                                 trace_axes=trace_axes)
    elif loss == 'mse':
        predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                             learning_rate=learning_rate,
                                             momentum=momentum,
                                             trace_axes=trace_axes)
    else:
        raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
        self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    # Train the actual network with the optimizer matching the predictor.
    if momentum is None:
        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    else:
        opt_init, opt_update, get_params = optimizers.momentum(learning_rate,
                                                               momentum)

    opt_state = opt_init(params)
    for i in range(t):
        params = get_params(opt_state)
        opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)
    fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)

    # NTK prediction at time t must agree with the trained network's outputs.
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
    """Minimize scalar function f by momentum gradient descent from x.

    Runs num_steps jit-compiled momentum updates and returns the final
    parameter value.
    """
    init_fun, update_fun, read_params = optimizers.momentum(step_size, mass)

    @jit
    def take_step(step, state):
        grads = grad(f)(read_params(state))
        return update_fun(step, grads, state)

    state = init_fun(x)
    for step in range(num_steps):
        state = take_step(step, state)
    return read_params(state)
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
    """Return a (init, update, get_params) optimizer triple by name.

    `sched` is the step size (or schedule) forwarded to the optimizer
    constructor; b1/b2 are only used by adam. Raises for unknown names.
    """
    name = optimizer.lower()
    if name == 'adagrad':
        return optimizers.adagrad(sched)
    if name == 'adam':
        return optimizers.adam(sched, b1, b2)
    if name == 'rmsprop':
        return optimizers.rmsprop(sched)
    if name == 'momentum':
        return optimizers.momentum(sched, 0.9)
    if name == 'sgd':
        return optimizers.sgd(sched)
    raise Exception('Invalid optimizer: {}'.format(optimizer))
def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel,
                        lr_factor, momentum):
    """Validate predict.max_learning_rate as a divergence threshold.

    Trains with step size = (theoretical max) * lr_factor and asserts the
    loss shrinks below threshold and blows up above it.
    """
    key = random.PRNGKey(0)

    key, split = random.split(key)
    if len(train_shape) == 2:
        train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
        train_shape = (16, 8, 8, 3)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)),
        np.float32)

    # Regress to an MSE loss.
    # NOTE: `f` is assigned below; the lambda binds it lazily, so this is
    # only valid because `loss` is first called after `f` exists.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
        return loss(get_params(opt_state), x_train)

    steps = 30

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')

    # Scale the theoretical maximal stable learning rate by lr_factor.
    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size, momentum=momentum) * lr_factor
    opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                           mass=momentum)
    opt_state = opt_init(params)

    init_loss = get_loss(opt_state)

    for i in range(steps):
        params = get_params(opt_state)
        opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    # Small epsilon guards against division by a zero initial loss.
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor < 1.:
        self.assertLess(loss_ratio, 0.1)
    elif lr_factor == 1:
        # At the threshold, the loss decays slowly
        self.assertLess(loss_ratio, 1.)
    if lr_factor > 2.:
        # Above twice the max rate training should diverge (NaN also counts).
        if not math.isnan(loss_ratio):
            self.assertGreater(loss_ratio, 10.)
def omniglot():
    """Train a conv net on one 50-way Omniglot task with full-batch momentum.

    Prints train/test loss roughly every 1% of the run.
    """
    n_way, n_support, n_query = 50, 15, 5

    net_init, f = conv_net(n_output=n_way, n_conv_layer=4, n_filter=64,
                           bias_coef=1, activation='relu', norm='None')
    _, params_init = net_init(rng=random.PRNGKey(42),
                              input_shape=(-1, 28, 28, 1))

    def loss(params, batch):
        # Mean cross-entropy between log-softmax outputs and one-hot targets.
        inputs, targets = batch
        logits = f(params, inputs)
        outputs = logsoftmax(logits)
        return -np.sum(outputs * targets) / targets.shape[0]

    def accuracy(params, batch):
        # Fraction of argmax predictions matching the one-hot targets.
        inputs, targets = batch
        target_class = np.argmax(targets, axis=-1)
        predicted_class = np.argmax(f(params, inputs), axis=-1)
        return np.mean(predicted_class == target_class)

    splits = load_omniglot(n_support=n_support, n_query=n_query)
    task = omniglot_task(splits['train'], n_way=n_way, n_support=n_support,
                         n_query=n_query)

    opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-0,
                                                           mass=0.9)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    opt_state = opt_init(params_init)

    n_update = 10000
    for i in range(n_update):
        # Full-batch updates on the task's support set.
        opt_state = update(i, opt_state, (task['x_train'], task['y_train']))
        if i == 0 or (i + 1) % (n_update // 100) == 0:
            print(
                i,
                f"train loss: {loss(get_params(opt_state), (task['x_train'], task['y_train']))},"
                f"\ttest loss: {loss(get_params(opt_state), (task['x_test'], task['y_test']))}"
            )
    trained_params = get_params(opt_state)
def subset_train(seed, subset_ratio):
    """Train MNIST on a random subset of the training set.

    Args:
        seed: seed for both the jax PRNG and the numpy subset sampler.
        subset_ratio: fraction of the full training set to train on.

    Returns:
        (trainset_mask, train_correctness, test_correctness) where
        trainset_mask is a boolean array marking the sampled examples.
    """
    jrng = random.PRNGKey(seed)
    step_size = 0.1
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    num_train_total = mnist_data['train_images'].shape[0]
    num_train = int(num_train_total * subset_ratio)
    num_batches = int(np.ceil(num_train / batch_size))

    rng = npr.RandomState(seed)
    subset_idx = rng.choice(num_train_total, size=num_train, replace=False)
    train_images = mnist_data['train_images'][subset_idx]
    train_labels = mnist_data['train_labels'][subset_idx]

    def data_stream(shuffle=True):
        # Endless minibatch generator. FIX: `shuffle` was previously accepted
        # but ignored; it now controls whether epochs are permuted.
        while True:
            perm = rng.permutation(num_train) if shuffle else np.arange(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()
    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(jrng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    for epoch in range(num_epochs):
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))

    params = get_params(opt_state)
    trainset_correctness = batch_correctness(
        params, (mnist_data['train_images'], mnist_data['train_labels']))
    testset_correctness = batch_correctness(
        params, (mnist_data['test_images'], mnist_data['test_labels']))

    # FIX: the `np.bool` alias was removed from NumPy (>= 1.24); use the
    # builtin `bool` dtype instead.
    trainset_mask = np.zeros(num_train_total, dtype=bool)
    trainset_mask[subset_idx] = True
    return (trainset_mask, np.asarray(trainset_correctness),
            np.asarray(testset_correctness))
def optimizer(name="adam",
              momentum_mass=0.9,
              rmsprop_gamma=0.9,
              rmsprop_eps=1e-8,
              adam_b1=0.9,
              adam_b2=0.997,
              adam_eps=1e-8):
    """Return the optimizer, by name.

    Uses the module-level `learning_rate` for every optimizer; raises
    ValueError for an unrecognized name.
    """
    # NOTE(review): `learning_rate` is a free variable resolved at call time —
    # presumably a module-level global; confirm it is defined before use.
    factories = {
        "sgd": lambda: optimizers.sgd(learning_rate),
        "momentum": lambda: optimizers.momentum(learning_rate,
                                                mass=momentum_mass),
        "rmsprop": lambda: optimizers.rmsprop(
            learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps),
        "adam": lambda: optimizers.adam(
            learning_rate, b1=adam_b1, b2=adam_b2, eps=adam_eps),
    }
    if name not in factories:
        raise ValueError("Unknown optimizer %s" % str(name))
    return factories[name]()
def main():
    """Train an MLP on MNIST with momentum SGD, printing accuracy per epoch."""
    rng = random.PRNGKey(0)

    batch_size = 128
    step_size = 0.001
    num_epochs = 10
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # define data stream
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_indices = perm[i*batch_size:(i+1)*batch_size]
                # BUG FIX: images were indexed with the scalar `batch_size`,
                # so each "batch" was a single, mismatched image.
                yield train_images[batch_indices], train_labels[batch_indices]
    batches = data_stream()

    # define optimizer
    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28*28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print('\nStarting training...')
    for epoch in range(num_epochs):
        start_tm = time.time()
        # BUG FIX: the inner loop ran `num_epochs` (10) steps per epoch
        # instead of covering all `num_batches` minibatches.
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_tm = time.time() - start_tm

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(f'Epoch {epoch} in {epoch_tm:0.2f} sec')
        print(f'Training set accuracy {train_acc}')
        print(f'Test set accuracy {test_acc}')
    print('DONE')
def train_opt(loss_fn_xy, size, initial_params, lr, momentum):
    """Run `size` momentum steps on loss_fn_xy via jax.lax.scan.

    Returns (final_params, memo) where memo holds, per step, the loss at the
    pre-update parameters and the post-update parameters.
    """
    init_fun, update_fun, read_params = optimizers.momentum(lr, momentum)

    def scan_body(state, iteration):
        # Loss/grad are evaluated at the current parameters, then one
        # optimizer step is applied.
        loss_value, grads = jax.value_and_grad(loss_fn_xy)(read_params(state))
        new_state = update_fun(iteration, grads, state)
        record = {"loss": loss_value, "params": read_params(new_state)}
        return new_state, record

    final_state, memo = jax.lax.scan(
        scan_body, init_fun(initial_params), jnp.arange(size))
    return read_params(final_state), memo
def sinusoid():
    """Fit an MLP regressor to a sinusoid task with minibatch momentum SGD."""
    net_init, net_fn = mlp(n_output=1, n_hidden_layer=2, bias_coef=1.0,
                           n_hidden_unit=40, activation='relu',
                           norm='batch_norm')
    rng = random.PRNGKey(42)
    in_shape = (-1, 1)
    out_shape, net_params = net_init(rng, in_shape)

    def loss(params, batch):
        # Mean squared error over the batch.
        inputs, targets = batch
        predictions = net_fn(params, inputs)
        return np.mean((predictions - targets)**2)

    opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2,
                                                           mass=0.9)
    opt_update = jit(opt_update)

    @jit
    def step(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss)(params, batch)
        return opt_update(i, g, opt_state)

    task = sinusoid_task(n_support=1000, n_query=100)
    opt_state = opt_init(net_params)
    for i, (x, y) in enumerate(
        minibatch(task['x_train'], task['y_train'], batch_size=256,
                  train_epochs=1000)):
        opt_state = step(i, opt_state, batch=(x, y))
        # Report losses at the first step and then every 100 steps.
        if i == 0 or (i + 1) % 100 == 0:
            print(
                f"train loss: {loss(get_params(opt_state), (task['x_train'], task['y_train']))},"
                f"\ttest loss: {loss(get_params(opt_state), (task['x_test'], task['y_test']))}"
            )
def __init__(self, CNN="true", L=40, step_size=0.001, seed=0):
    """Defines neural network architecture, parameter initialization, and optimizer.

    Args:
        CNN: whether to build the convolutional variant.
            NOTE(review): the default is the *string* "true", and any
            non-empty string (even "false") is truthy — confirm callers pass
            a real boolean.
        L: lattice side length; inputs are L x L (x 1 channel for the CNN).
        step_size: momentum optimizer learning rate.
        seed: PRNG seed for parameter initialization.
    """
    self.CNN = CNN
    if CNN:
        self.input_shape = (-1, L, L, 1)
        # following network gave highest accuracy on test data of all the ones I tried
        self.init_random_params, self.predict = stax.serial(
            PeriodicConv(out_chan=10, filter_shape=(2, 2), strides=(1, 1),
                         padding='VALID'),
            Relu,
            #MaxPool(window_shape=(2, 2), strides=(2, 2), padding='VALID'),
            Flatten,
            Dense(100),
            Relu,
            # Dropout(0.4),  # doesn't work yet since prng key has to be passed to predict()
            Dense(1),
            Sigmoid)
    else:
        # Dense variant works on flattened L*L inputs.
        self.input_shape = (-1, L * L)
        self.init_random_params, self.predict = stax.serial(
            Dense(100),
            Relu,
            Dense(100),
            Relu,
            # Dropout(0.4),  # doesn't work yet since prng key has to be passed to predict()
            Dense(1),
            Sigmoid)

    momentum_mass = 0.9
    self.opt_init, self.opt_update, self.get_params = optimizers.momentum(
        step_size, mass=momentum_mass)
    #self.opt_init, self.opt_update, self.get_params = optimizers.adam(0.0001)

    rng = random.PRNGKey(seed)
    _, self.init_params = self.init_random_params(rng, self.input_shape)
    self.opt_state = self.opt_init(self.init_params)
    # Keep a live copy of the current parameters alongside the opt state.
    self.params = self.init_params
def get_optimizer(self, optim=None, stage='learn', step_size=None):
    """Build an optimizer triple selected by the integer code `optim`.

    Falls back to the instance's configured optimizer for the given stage
    and to the instance's step size when not provided; any unrecognized
    code yields adam.
    """
    if optim is None:
        optim = self.optim_learn if stage == 'learn' else self.optim_proj
    if step_size is None:
        step_size = self.step_size

    # Map each code to its verbose label and a zero-arg factory.
    builders = {
        1: ("With momentum optimizer",
            lambda: momentum(step_size=step_size, mass=0.95)),
        2: ("With rmsprop optimizer",
            lambda: rmsprop(step_size, gamma=0.9, eps=1e-8)),
        3: ("With adagrad optimizer",
            lambda: adagrad(step_size, momentum=0.9)),
        4: ("With Nesterov optimizer",
            lambda: nesterov(step_size, 0.9)),
        5: ("With SGD optimizer",
            lambda: sgd(step_size)),
    }
    label, build = builders.get(optim,
                                ("With adam optimizer",
                                 lambda: adam(step_size)))
    if self.verb > 2:
        print(label)
    opt_init, opt_update, get_params = build()
    return opt_init, opt_update, get_params
def ssvm_loss(params, x, y, lamb=0.01, max_steps=80, step_size=0.1,
              pretrain_global_energy=False):
    """Structured SVM loss with inner inference by momentum gradient descent.

    When y is None the function runs plain inference and returns the
    predicted labeling; otherwise it runs cost-augmented inference and
    returns the margin loss plus an L2 penalty on the parameters.
    """
    # y is None signals prediction (inference) mode.
    prediction = y is None
    x_hat = compute_feature_energy(params, x)
    if pretrain_global_energy:
        # Freeze feature energies so only downstream parameters get gradients.
        x_hat = lax.stop_gradient(x_hat)
    grad_fun = inference_step if prediction else cost_augmented_inference_step
    # Inner optimizer over the relaxed labeling; rates are fixed constants
    # independent of `step_size` — presumably intentional, confirm.
    opt_init, opt_update, get_params = momentum(0.01, 0.95)
    # opt_state = opt_init(np.full(x.shape[:-1] + (LABELS,), 1. / LABELS))
    opt_state = opt_init(np.zeros(x.shape[:-1] + (LABELS, )))
    prev_energy = None
    for step in range(max_steps):
        y_hat = project(get_params(opt_state))
        g, energy = grad_fun(y_hat, y, x_hat, params)
        opt_state = opt_update(step, g, opt_state)
        # Early exit once the inner optimization stops making progress.
        if step > 0 and check_saddle_point(step, get_params(opt_state), y_hat,
                                           energy, prev_energy):
            break
        prev_energy = energy
    # The inferred labeling is treated as a constant w.r.t. the outer grad.
    y_hat = lax.stop_gradient(project(get_params(opt_state)))
    if prediction:
        return y_hat
    y = lax.stop_gradient(y)
    pred_energy = compute_global_energy(params, x_hat, y_hat)
    true_energy = compute_global_energy(params, x_hat, y)
    # Hinge on (margin + true energy - predicted energy), clamped at zero.
    delta = np.square(y_hat - y).sum(axis=1)
    loss = np.mean(np.maximum(delta + true_energy - pred_energy, 0))
    return loss + lamb * l2_norm(params)
# MNIST training script fragment: stream shuffled minibatches and run
# momentum SGD for `num_epochs` epochs.
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
    # Endless stream of shuffled minibatches.
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

# FIX: optimizers.momentum returns an (init, update, get_params) triple in
# the current jax API, and the old module-level optimizers.get_params no
# longer exists; unpack and use the returned get_params instead.
opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                       mass=momentum_mass)

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
# Here we clone the rng used in computing the objective
# so that we can show exactly the same samples.
# NOTE(review): these first statements reference `t`, `params`, `ax` and
# appear to be the tail of a plotting callback whose `def` line precedes
# this chunk — confirm against the full file.
rngs = random.split(random.PRNGKey(t), num_samples)
samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
ax.plot(samples[:, 0], samples[:, 1], 'b.')
plt.draw()
plt.pause(1.0 / 60.0)

# Set up optimizer.
D = 2
init_mean = np.zeros(D)
init_std = np.zeros(D)
init_params = (init_mean, init_std)
# FIX: optimizers.momentum returns an (init, update, get_params) triple in
# the current jax API; the old module-level optimizers.get_params is gone.
opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
opt_state = opt_init(init_params)

@jit
def update(i, opt_state):
    # One momentum step on the variational objective.
    params = get_params(opt_state)
    gradient = grad(objective)(params, i)
    return opt_update(i, gradient, opt_state)

# Main loop.
print("Optimizing variational parameters...")
for t in range(100):
    opt_state = update(t, opt_state)
    params = get_params(opt_state)
    callback(params, t)
plt.show(block=True)
# MNIST training script fragment: stream shuffled minibatches and run
# momentum SGD for `num_epochs` epochs. Relies on `batch_size`, `step_size`,
# `momentum_mass`, `num_epochs`, `rng`, `loss` and `init_random_params`
# defined earlier in the file.
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
# A final partial batch counts as one extra batch.
num_batches = num_complete_batches + bool(leftover)

def data_stream():
    # Endless stream of shuffled minibatches.
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                       mass=momentum_mass)

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
# Global step counter shared across epochs.
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
def weight_space(train_embedding, test_embedding, data_set):
    """Train a small classifier and its linearization on fixed embeddings.

    Mirrors the neural-tangents weight-space example: the full network f and
    its first-order expansion f_lin get identical momentum updates, with
    per-epoch loss printouts and a final train/test summary.
    """
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # 2 denotes 2 type of classes
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # (-1, 135), 135 denotes the feature length, here is 9 * 15 = 135
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)
    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Use whole batch
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
        datasets.mini_batch(train_embedding, data_set['Y_train'], batch_size,
                            train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1
        # Stop once the requested number of epochs has been streamed.
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)
# Here we clone the rng used in computing the objective # so that we can show exactly the same samples. rngs = random.split(random.PRNGKey(t), num_samples) samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params) ax.plot(samples[:, 0], samples[:, 1], 'b.') plt.draw() plt.pause(1.0 / 60.0) # Set up optimizer. D = 2 init_mean = jnp.zeros(D) init_std = jnp.zeros(D) init_params = (init_mean, init_std) opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9) opt_state = opt_init(init_params) @jit def update(i, opt_state): params = get_params(opt_state) gradient = grad(objective)(params, i) return opt_update(i, gradient, opt_state) # Main loop. print("Optimizing variational parameters...") for t in range(100): opt_state = update(t, opt_state) params = get_params(opt_state) callback(params, t) plt.show(block=True)
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with momentum and (optionally
    scheduled) L2 regularization, logging metrics to CSV and optionally
    checkpointing weights.

    NOTE(review): reconstructed from whitespace-mangled source; statement
    nesting in the filename-building and measurement sections follows the
    most plausible reading — verify against the original file.
    """
    from jax.api import grad, jit, vmap, pmap, device_put
    "The following is required to use TPU Driver as JAX's backend."
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1
    # All pmapped functions share the axis name 'i' for cross-device psum.
    pmap = partial(pmap, axis_name='i')

    """Setup some experiment parameters."""
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # "Physical" time rescaling: epochs are measured in units of lr (and L2).
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
    if FLAGS.physicalL2:
        tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    "Filename from FLAGS"
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)

    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    """Load CIFAR10 data and create a minimal pipeline."""
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    """Create a Wide Resnet and replicate its parameters across the devices."""
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    "Loss and optimizer definitions"
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))

    currL2 = FLAGS.L2
    # Per-device copy of the current L2 coefficient.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    "Define optimizer"
    if FLAGS.std_wrn_sch:
        # Standard schedule: decay lr by 5x at 60/120/180 out of 200 epochs.
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        # Average gradients across devices.
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        # Returns (regularized loss, accuracy, bare loss, L2 penalty).
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    "Initialization and loading"
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)
    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    # Metric accumulators, one entry per measurement step.
    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0

    start = time.time()
    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    "Start training loop"
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))
    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)
            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and (
                        (minloss <= train_lm[-1] and minloss <= train_lm[-2]) or
                        (maxacc >= train_accuracy[-1] and
                         maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i
            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                # NOTE(review): bare except initializes maxacc/minloss on the
                # first NameError — confirm no other errors should be masked.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))
                start = time.time()

                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)
                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step
                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:
                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                # Flatten device-replicated params into one 1-D array.
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])
                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)
                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')
    if FLAGS.TPU:
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
    pass
def _JaxMomentum(machine, learning_rate, beta=0.9, l2reg=0):
    """Wrap `machine` with a jax momentum optimizer.

    NOTE(review): `l2reg` is accepted but unused here — confirm whether the
    wrapper is expected to apply it.
    """
    optimizer = jaxopt.momentum(learning_rate, beta)
    return Wrap(machine, optimizer)
# MNIST training script fragment driven by wandb.config hyperparameters.
# Relies on `rng`, `loss` and `init_random_params` defined earlier.
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, wandb.config.batch_size)
# A final partial batch counts as one extra batch.
num_batches = num_complete_batches + bool(leftover)

def data_stream():
    # Endless stream of shuffled minibatches.
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * wandb.config.batch_size:(i + 1) *
                             wandb.config.batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(
    wandb.config.step_size, mass=wandb.config.momentum_mass)

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
# Global step counter shared across epochs.
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(wandb.config.num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
def main(unused_argv):
    """Train a width-scaled LeNet on MNIST together with its linearization.

    Builds the network `f` and its first-order Taylor expansion `f_lin`
    (about the initial parameters), trains both with the same momentum
    optimizer under a step-decayed learning rate, prints accuracy/loss for
    both every 100 epochs, and ends with a train/test comparison summary.
    """
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            n_train=60000,
                                                            n_test=10000,
                                                            permute_train=True)
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print('Data loaded and reshaped')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Set random seed
    key = random.PRNGKey(0)

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network with NTK parameterization
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print(f'Network of width x{FLAGS.network_width} built.')

    # # Construct the kernel function
    # kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel)
    # print('Kernel constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Compute random initial parameters
    _, params = init_fn(key, (-1, 28, 28, 1))
    # Both models start from the exact same parameters.
    params_lin = params
    print('Initial parameters constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # # Save initial parameters
    # with open('init_params.npy', 'wb') as file:
    #     np.save(file, params)

    # Linearize the network about its initial parameters.
    # Use jit for faster GPU computation (only feasible for width < 25)
    f_lin = nt.linearize(f, params)
    if FLAGS.network_width <= 10:
        f_jit = jit(f)
        f_lin_jit = jit(f_lin)
    else:
        # Wider nets are left un-jitted (per the comment above, jit is only
        # feasible for small widths — presumably a memory constraint).
        f_jit = f
        f_lin_jit = f_lin

    # Create a callable function for dynamic learning rates
    # Starts with learning_rate, divided by 10 after learning_decline epochs.
    dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
        (iteration_step // (x_train.shape[0] // FLAGS.batch_size)) //
        FLAGS.learning_decline)

    # Create and initialize an optimizer for both f and f_lin.
    # Use momentum with coefficient 0.9 and jit
    opt_init, opt_apply, get_params = optimizers.momentum(
        dynamic_learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    # Compute the initial states (one optimizer state per model, both
    # seeded from the same initial parameters).
    state = opt_init(params)
    state_lin = opt_init(params)

    # Define the accuracy function
    accuracy = lambda fx, y_hat: np.mean(
        np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

    # Define mean square error loss function
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    # # Create a cross-entropy loss function.
    # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print(
        f'Training with dynamic learning decline after {FLAGS.learning_decline} epochs...'
    )
    print(
        'Epoch\tTime\tAccuracy\tLin. Accuracy\tLoss\tLin. Loss\tAccuracy Train\tLin.Accuracy Train'
    )
    print(
        '----------------------------------------------------------------------------------------------------------'
    )

    # Initialize training
    epoch = 0
    steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

    # Set start time (total and 100 epochs)
    start = time.time()
    start_epoch = time.time()

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):
        # Update the parameters — both models share the step counter `i`,
        # so both see the same learning-rate schedule.
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # Print information after each 100 epochs
        if (i + 1) % (steps_per_epoch * 100) == 0:
            time_point = time.time() - start_epoch
            # Update epoch
            epoch += 100
            # Accuracy in batches
            f_x = util.output_in_batches(x_train, params, f_jit,
                                         FLAGS.batch_count_accuracy)
            f_x_test = util.output_in_batches(x_test, params, f_jit,
                                              FLAGS.batch_count_accuracy)
            f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                             FLAGS.batch_count_accuracy)
            f_x_lin_test = util.output_in_batches(x_test, params_lin,
                                                  f_lin_jit,
                                                  FLAGS.batch_count_accuracy)
            # time_point = time.time() - start_epoch
            # Print information about past 100 epochs
            print(
                '{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
                .format(epoch, time_point,
                        accuracy(f_x, y_train) * 100,
                        accuracy(f_x_lin, y_train) * 100, loss(f_x, y_train),
                        loss(f_x_lin, y_train),
                        accuracy(f_x_test, y_test) * 100,
                        accuracy(f_x_lin_test, y_test) * 100))
            # # Save params if epoch is multiple of learning decline or multiple of fixed value
            # if epoch % FLAGS.learning_decline == 0:
            #     filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
            #     with open(filename, 'wb') as file:
            #         np.save(file, params)
            #     filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
            #     with open(filename_lin, 'wb') as file_lin:
            #         np.save(file_lin, params_lin)
            # Reset timer
            start_epoch = time.time()

    duration = time.time() - start
    print(
        '----------------------------------------------------------------------------------------------------------'
    )
    print(f'Training complete in {duration} seconds.')

    # # Save final params in file
    # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy '
    # with open(filename_final, 'wb') as final:
    #     np.save(final, params)
    # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
    # with open(filename_final_lin, 'wb') as final_lin:
    #     np.save(final_lin, params_lin)

    # Compute output in batches
    f_x = util.output_in_batches(x_train, params, f_jit,
                                 FLAGS.batch_count_accuracy)
    f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                     FLAGS.batch_count_accuracy)
    f_x_test = util.output_in_batches(x_test, params, f_jit,
                                      FLAGS.batch_count_accuracy)
    f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                          FLAGS.batch_count_accuracy)
    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f_x, f_x_lin, loss)
    util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
def __init__(self, learning_rate, mass=0.9):
    """Momentum-optimizer wrapper: record the mass and build the JAX
    (init, update, get_params) triple from `self.lr` set by the base class."""
    super().__init__(learning_rate)
    self.mass = mass
    init_fn, update_fn, params_fn = momentum(step_size=self.lr,
                                             mass=self.mass)
    self.opt_init = init_fn
    self.opt_update = update_fn
    self.get_params = params_fn
pi = jnp.array([1, 1]) / 2 casino = HMMJax(A, B, pi) num_hidden, num_obs = 2, 6 seed = 0 rng_key = PRNGKey(seed) rng_key, rng_sample = split(rng_key) n_obs_seq, max_len = 4, 5000 num_epochs = 400 observations, lens = pad_sequences( *hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample)) optimizer = optimizers.momentum(step_size=1e-3, mass=0.95) # Mini Batch Gradient Descent batch_size = 2 params_mbgd, losses_mbgd = fit(observations, lens, num_hidden, num_obs, batch_size, optimizer, rng_key=None, num_epochs=num_epochs) # Full Batch Gradient Descent batch_size = n_obs_seq params_fbgd, losses_fbgd = fit(observations,
def main(_):
    """Train a CNN or Wide-ResNet classifier, optionally with DP-SGD.

    Loads a dataset split, normalizes with train-set statistics, builds the
    model chosen by `configs.architect`, and runs `configs.train_steps`
    optimizer steps. When FLAGS.dpsgd is set, per-example gradients are
    clipped and noised (differential privacy); checkpoints and metrics are
    written under FLAGS.exp_dir once per epoch.
    """
    logging.info('Starting experiment.')
    configs = FLAGS.config
    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        # Directory already exists (or similar benign FS error) — proceed.
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()
    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    # Train-set mean/std are reused to normalize the test set below.
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')
    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    # Seed from config when given, otherwise from the wall clock.
    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    # count params of JAX model
    def count_parameters(params):
        # Sum of element counts over every leaf array in the pytree.
        return tree_util.tree_reduce(
            operator.add,
            tree_util.tree_map(lambda x: np.prod(x.shape), params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        # Both arguments are one-hot / logit arrays; compare argmax classes.
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        # Quarter-period cosine decay: lr(0)=initial, lr(train_steps)=0.
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')
    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""

        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])
            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            # Global L2 norm over all leaves (norm of per-leaf norms).
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            # Scale down only when the norm exceeds the clip threshold.
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        # vmap over the batch axis to get per-example clipped gradients.
        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}
        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            # Accumulate example-weighted sums, then divide by total count.
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)
        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])
        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))
        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()
    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)
        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            # Note the extra leading axis of 1: reshapes the batch so vmap
            # sees one example per slice inside private_grad.
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            # NOTE(review): 'wr' is not a standard Python file mode, and
            # pickle requires a binary handle ('wb') for regular files —
            # confirm gfile's mode semantics here.
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wr') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)
            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))
            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}'.format(toc - tic))
    stdout_log.close()
def run():
    """ Run the experiment.

    Trains a regularized model with momentum SGD under a piecewise-constant
    learning-rate schedule, supports resuming from a pickled checkpoint,
    and periodically writes timing, metric, parameter, and metadata files.
    """
    ds_train, ds_train_eval, meta = init_data()
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    forward, model = init_model()
    forward_all = model["model"]["forward_all"]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    def lr_schedule(train_itr):
        """ The learning rate schedule.

        Piecewise constant by epoch: 1e-1 (<60), 1e-2 (<100), 1e-3 (<140),
        else 1e-4. NOTE(review): this uses the legacy five-argument
        `lax.cond(pred, true_operand, true_fun, false_operand, false_fun)`
        signature — it will not run on modern JAX; confirm the pinned version.
        """
        _epoch = train_itr // num_batches
        id = lambda x: x  # NOTE(review): shadows the builtin `id` locally
        return lax.cond(
            _epoch < 60, 1e-1, id, 0, lambda _: lax.cond(
                _epoch < 100, 1e-2, id, 0, lambda _: lax.cond(
                    _epoch < 140, 1e-3, id, 1e-4, id)))

    opt_init, opt_update, get_params = optimizers.momentum(
        step_size=lr_schedule, mass=0.9)

    if parse_args.load_ckpt:
        # Resume: load pickled params and recover the iteration counter
        # from the checkpoint filename ("..._<itr>_fargs.pickle").
        file_ = open(parse_args.load_ckpt, 'rb')
        init_params = pickle.load(file_)
        file_.close()
        # parse itr from the checkpoint
        load_itr = int(os.path.basename(parse_args.load_ckpt).split("_")[-2])
    else:
        init_params = model["params"]
        load_itr = 0

    opt_state = opt_init(init_params)

    #@jax.jit
    def update(_itr, _opt_state, _key, _batch):
        """ Update the params based on grad for current batch. """
        images, labels = _batch
        return opt_update(
            _itr, grad_fn(get_params(_opt_state), images, labels, _key),
            _opt_state)

    # @jax.jit
    def sep_losses(_opt_state, _batch, key):
        """ Convenience function for calculating losses separately. """
        params = get_params(_opt_state)
        images, labels = _batch
        logits, r2_regs, fro_regs, kin_regs = forward_all(key, params, images)
        loss_ = _loss_fn(logits, labels)
        r2_reg_ = _reg_loss_fn(r2_regs)
        fro_reg_ = _reg_loss_fn(fro_regs)
        kin_reg_ = _reg_loss_fn(kin_regs)
        # Total objective = data loss + weighted regularization terms.
        total_loss_ = loss_ + lam * r2_reg_ + lam_fro * fro_reg_ + lam_kin * kin_reg_
        acc_ = _acc_fn(logits, labels)
        return acc_, total_loss_, loss_, r2_reg_, fro_reg_, kin_reg_

    def evaluate_loss(opt_state, _key, ds_train_eval):
        """ Convenience function for evaluating loss over train set in smaller batches. """
        sep_acc_, sep_loss_aug_, sep_loss_, \
            sep_loss_r2_reg_, sep_loss_fro_reg_, sep_loss_kin_reg_, nfe = [], [], [], [], [], [], []
        for test_batch_num in range(num_test_batches):
            test_batch = next(ds_train_eval)
            _key, = jax.random.split(_key, num=1)
            test_batch_acc_, test_batch_loss_aug_, test_batch_loss_, \
                test_batch_loss_r2_reg_, test_batch_loss_fro_reg_, test_batch_loss_kin_reg_ = \
                sep_losses(opt_state, test_batch, _key)
            if count_nfe:
                # Number of ODE-solver function evaluations for this batch.
                nfe.append(model["nfe"](get_params(opt_state), *test_batch))
            else:
                nfe.append(0)
            sep_acc_.append(test_batch_acc_)
            sep_loss_aug_.append(test_batch_loss_aug_)
            sep_loss_.append(test_batch_loss_)
            sep_loss_r2_reg_.append(test_batch_loss_r2_reg_)
            sep_loss_fro_reg_.append(test_batch_loss_fro_reg_)
            sep_loss_kin_reg_.append(test_batch_loss_kin_reg_)
        sep_acc_ = jnp.array(sep_acc_)
        sep_loss_aug_ = jnp.array(sep_loss_aug_)
        sep_loss_ = jnp.array(sep_loss_)
        sep_loss_r2_reg_ = jnp.array(sep_loss_r2_reg_)
        sep_loss_fro_reg_ = jnp.array(sep_loss_fro_reg_)
        sep_loss_kin_reg_ = jnp.array(sep_loss_kin_reg_)
        nfe = jnp.array(nfe)
        return jnp.mean(sep_acc_), jnp.mean(sep_loss_aug_), jnp.mean(sep_loss_), \
            jnp.mean(sep_loss_r2_reg_), jnp.mean(sep_loss_fro_reg_), jnp.mean(sep_loss_kin_reg_), jnp.mean(nfe)

    itr = 0
    info = collections.defaultdict(dict)
    key = rng

    # Create the training-data iterator
    iterator = iter(ds_train)
    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            batch = next(iterator)
            key, = jax.random.split(key, num=1)

            itr += 1

            # When resuming, fast-forward past already-trained iterations.
            if parse_args.load_ckpt:
                if itr <= load_itr:
                    continue

            update_start = time.time()
            opt_state = update(itr, opt_state, key, batch)
            # Force async dispatch to finish so the timing below is real.
            tree_flatten(opt_state)[0][0].block_until_ready()
            update_end = time.time()
            time_str = "%d %.18f %d\n" % (itr, update_end - update_start,
                                          load_itr)
            outfile = open(
                "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_time.txt" %
                (dirname, reg, reg_type, lam, lam_fro, lam_kin), "a")
            outfile.write(time_str)
            outfile.close()

            if itr % parse_args.test_freq == 0:
                acc_, loss_aug_, loss_, \
                    loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_ = evaluate_loss(opt_state, key, ds_train_eval)
                print_str = 'Iter {:04d} | Total (Regularized) Loss {:.6f} | Loss {:.6f} | ' \
                            'r {:.6f} | fro {:.6f} | kin {:.6f} | ' \
                            'NFE {:.6f}'.format(itr, loss_aug_, loss_, loss_r2_reg_, loss_fro_reg_,
                                                loss_kin_reg_, nfe_)
                print(print_str)
                outfile = open(
                    "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_info.txt" %
                    (dirname, reg, reg_type, lam, lam_fro, lam_kin), "a")
                outfile.write(print_str + "\n")
                outfile.close()
                info[itr]["acc"] = acc_
                info[itr]["loss_aug"] = loss_aug_
                info[itr]["loss"] = loss_
                info[itr]["loss_r2_reg"] = loss_r2_reg_
                info[itr]["loss_fro_reg"] = loss_fro_reg_
                info[itr]["loss_kin_reg"] = loss_kin_reg_
                info[itr]["nfe"] = nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_%d_fargs.pickle" \
                    % (dirname, reg, reg_type, lam, lam_fro, lam_kin, itr)
                fargs = get_params(opt_state)
                outfile = open(param_filename, "wb")
                pickle.dump(fargs, outfile)
                outfile.close()

            # NOTE(review): this metadata dump runs every iteration (it is
            # inside the inner loop, not gated on save_freq) — confirm that
            # is intentional, as it rewrites the pickle each step.
            meta = {"info": info, "args": parse_args}
            outfile = open(
                "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_%d_meta.pickle" %
                (dirname, reg, reg_type, lam, lam_fro, lam_kin, itr), "wb")
            pickle.dump(meta, outfile)
            outfile.close()
num_complete_batches, leftover = divmod(num_train, config.batch_size) num_batches = num_complete_batches + bool(leftover) def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): batch_idx = perm[i * config.batch_size:(i + 1) * config.batch_size] yield train_images[batch_idx], train_labels[batch_idx] batches = data_stream() opt_init, opt_update, get_params = optimizers.momentum(config.learning_rate, mass=config.momentum_mass) @jit def update(i, opt_state, batch): params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) _, init_params = init_random_params(rng, (-1, 28 * 28)) opt_state = opt_init(init_params) itercount = itertools.count() print("\nStarting training...") for epoch in range(num_epochs): start_time = time.time()