def madgrad(step_size=0.01, momentum=0.9, epsilon=1.0e-6):
    """Implementation of MADGRAD (Defazio and Jelassi, arXiv:2101.11075)."""
    step_size = make_schedule(step_size)
    momentum = make_schedule(momentum)

    def init(x):
        s = jnp.zeros_like(x)
        nu = jnp.zeros_like(x)
        x0 = x
        return x, s, nu, x0

    def update(i, g, state):
        x, s, nu, x0 = state
        lbda = step_size(i) * jnp.sqrt(i + 1)
        s = s + lbda * g
        nu = nu + lbda * g * g
        z = x0 - s / (jnp.power(nu, 1.0 / 3.0) + epsilon)
        x = (1 - momentum(i)) * x + momentum(i) * z
        return x, s, nu, x0

    def get_params(state):
        x, s, nu, x0 = state
        return x

    return Optimizer(init, update, get_params)
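# A minimal usage sketch for the triple above, assuming the `madgrad`
# definition in this file (with `make_schedule` and the `Optimizer` namedtuple
# from jax.example_libraries.optimizers, which unpacks as a triple), a single
# jnp-array parameter vector, and a stand-in quadratic loss:
import jax
import jax.numpy as jnp

def loss_fn(x):
    return jnp.sum((x - 1.0) ** 2)

init, update, get_params = madgrad(step_size=0.01)
state = init(jnp.zeros(3))
for i in range(500):
    g = jax.grad(loss_fn)(get_params(state))
    state = update(i, g, state)
# get_params(state) should now be close to [1., 1., 1.]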
def full_solve_cga(step_size_f, step_size_g, f, g):
    """CGA using a naive implementation which builds the full Hessians."""
    step_size_f = optimizers.make_schedule(step_size_f)
    step_size_g = optimizers.make_schedule(step_size_g)

    def init(inputs):
        return CGAState(
            x=inputs[0],
            y=inputs[1],
            delta_x=np.zeros_like(inputs[0]),
            delta_y=np.zeros_like(inputs[1]),
        )

    def update(i, grads, inputs, *args, **kwargs):
        if len(inputs) < 4:
            x, y = inputs
            delta_x = None
            delta_y = None
        else:
            x, y, delta_x, delta_y = inputs
        grad_xf, grad_yg = grads
        eta_f = step_size_f(i)
        eta_g = step_size_g(i)
        Dxyf = make_mixed_hessian(partial(f, *args, **kwargs), 0, 1)(x, y)
        Dyxg = make_mixed_hessian(partial(g, *args, **kwargs), 1, 0)(x, y)
        # Each b-vector uses the *other* player's step size and the matrix uses
        # eta_f * eta_g, matching the fixed-point iteration in `cga` below.
        bx = grad_xf + eta_g * np.dot(Dxyf, grad_yg)
        delta_x = np.linalg.solve(
            np.eye(x.shape[0]) - eta_f * eta_g * np.dot(Dxyf, Dyxg),
            bx,
        )
        by = grad_yg + eta_f * np.dot(Dyxg, grad_xf)
        delta_y = np.linalg.solve(
            np.eye(y.shape[0]) - eta_f * eta_g * np.dot(Dyxg, Dxyf),
            by,
        )
        x = x + eta_f * delta_x
        y = y + eta_g * delta_y
        return CGAState(x, y, delta_x, delta_y)

    def get_params(state):
        return state[:2]

    return init, update, get_params
def ngd_cg(step_size, b1=0.9, b2=0.999, eps=1e-8, lmda=0.001, decay=0.9):
    """Construct optimizer triple for natural gradient descent.

    The natural gradient is obtained by solving F ng = g with conjugate
    gradient, where F is the Fisher information matrix accessed through a
    Fisher-vector-product function `Fvp_fn`. (b1, b2, eps, lmda, and decay
    are currently unused.)

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return (x0,)

    def update(i, g, state):
        x, = state
        # Solve F ng = g with conjugate gradient to get the natural gradient.
        ng = cg_solve(Fvp_fn, g)
        # Compute the step size based on curvature statistics.
        lr = step_size(i)
        alpha = np.sqrt(np.abs(lr / (np.dot(g, ng) + 1e-20)))
        # Update params.
        x = x - alpha * ng
        return (x,)

    def get_params(state):
        x, = state
        return x

    return init, update, get_params
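# `Fvp_fn` and `cg_solve` are referenced above but not defined in this
# snippet. A hypothetical sketch of both, using jax.scipy.sparse.linalg.cg
# and an empirical Fisher-vector product built from per-example
# log-likelihood gradients (all names here are illustrative assumptions):
import jax
import jax.numpy as np
from jax.scipy.sparse.linalg import cg

def make_fvp(log_prob_fn, params, data, damping=1e-4):
    # F v ~= mean_i g_i (g_i . v) + damping * v, with g_i = grad log p(x_i).
    per_example_grads = jax.vmap(jax.grad(log_prob_fn),
                                 in_axes=(None, 0))(params, data)  # (n, d)

    def Fvp_fn(v):
        gv = per_example_grads @ v  # (n,) inner products g_i . v
        return per_example_grads.T @ gv / data.shape[0] + damping * v

    return Fvp_fn

def cg_solve(Fvp_fn, g, maxiter=10):
    # Solve F ng = g approximately with conjugate gradient.
    ng, _ = cg(Fvp_fn, g, maxiter=maxiter)
    return ng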
def adam_custom(step_size, b1=0.9, b2=0.999, eps=1e-8):
    """Construct optimizer triple for Adam.

    Note: unlike standard Adam, `update` stores the scaled step
    `step_size * mhat / (sqrt(vhat) + eps)` in the state instead of the
    updated parameters; the caller is expected to apply it.

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.
      b1: optional, a positive scalar value for beta_1, the exponential decay
        rate for the first moment estimates (default 0.9).
      b2: optional, a positive scalar value for beta_2, the exponential decay
        rate for the second moment estimates (default 0.999).
      eps: optional, a positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        m0 = np.zeros_like(x0)
        v0 = np.zeros_like(x0)
        return x0, m0, v0

    def update(i, g, state):
        x_step, m, v = state
        m = (1 - b1) * g + b1 * m  # First moment estimate.
        v = (1 - b2) * np.square(g) + b2 * v  # Second moment estimate.
        mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
        vhat = v / (1 - b2 ** (i + 1))
        x_step = step_size(i) * mhat / (np.sqrt(vhat) + eps)
        return x_step, m, v

    def get_params(state):
        x_step, _, _ = state
        return x_step

    return init, update, get_params
def sgd(step_size):
    step_size = jax_opt.make_schedule(step_size)

    def init(x0):
        return copy.deepcopy(x0)

    def update(i, g, x):
        return x - step_size(i) * g

    def get_params(x):
        return x

    return init, update, get_params
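# `make_schedule` accepts either a constant or a callable mapping the
# iteration index to a step size, so all of the following work with the
# triple above (a sketch; `jax_opt` is assumed to be
# jax.example_libraries.optimizers or the older jax.experimental.optimizers):
init, update, get_params = sgd(1e-2)  # constant step size

# Built-in schedule: decay by 0.5 every 1000 steps.
init, update, get_params = sgd(jax_opt.exponential_decay(1e-2, 1000, 0.5))

# Any callable i -> step size also works, e.g. inverse-sqrt decay:
init, update, get_params = sgd(lambda i: 1e-2 / (1.0 + i) ** 0.5)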
def momentum(learning_rate, momentum=0.9):
    """A standard momentum optimizer for testing."""
    learning_rate = opt.make_schedule(learning_rate)

    def init_fun(x0):
        v0 = np.zeros_like(x0)
        return x0, v0

    def update_fun(i, g, x, velocity):
        velocity = momentum * velocity + g
        x = x - learning_rate(i) * velocity
        return x, velocity

    return init_fun, update_fun
def momentum(learning_rate, momentum=0.9):
    """A standard momentum optimizer for testing.

    Different from `jax.experimental.optimizers.momentum` (Nesterov).
    """
    learning_rate = opt.make_schedule(learning_rate)

    def init_fun(x0):
        v0 = np.zeros_like(x0)
        return x0, v0

    def update_fun(i, g, state):
        x, velocity = state
        velocity = momentum * velocity + g
        x = x - learning_rate(i) * velocity
        return x, velocity

    def get_params(state):
        x, _ = state
        return x

    return init_fun, update_fun, get_params
def momentum(step_size, mass, weight_decay=0.):
    step_size = jax_opt.make_schedule(step_size)

    def init(x0):
        v0 = np.zeros_like(x0)
        return x0, v0

    def update(i, g, state):
        x, velocity = state
        if weight_decay != 0.:
            g = g + weight_decay * x  # L2 penalty folded into the gradient.
        velocity = mass * velocity + g
        x = x - step_size(i) * velocity
        return x, velocity

    def get_params(state):
        x, _ = state
        return x

    return init, update, get_params
def adamW(step_size, b1=0.9, b2=0.999, eps=1e-8, w=0.01):
    """Construct optimizer triple for AdamW.

    This docstring is different from the rest because we want to submit this
    to the jax library, so DON'T CHANGE IT TO SPHINX-STYLE!

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.
      b1: optional, a positive scalar value for beta_1, the exponential decay
        rate for the first moment estimates (default 0.9).
      b2: optional, a positive scalar value for beta_2, the exponential decay
        rate for the second moment estimates (default 0.999).
      eps: optional, a positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
      w: optional, weight decay term (default 0.01)

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        m0 = np.zeros_like(x0)
        v0 = np.zeros_like(x0)
        return x0, m0, v0

    def update(i, g, state):
        x, m, v = state
        m = (1 - b1) * g + b1 * m  # First moment estimate.
        v = (1 - b2) * (g ** 2) + b2 * v  # Second moment estimate.
        mhat = m / (1 - b1 ** (i + 1))  # Bias correction.
        vhat = v / (1 - b2 ** (i + 1))
        x = x - step_size(i) * (mhat / (np.sqrt(vhat) + eps) + w * x)
        return x, m, v

    def get_params(state):
        x, m, v = state
        return x

    return init, update, get_params
def adahessian(step_size=1e-1,
               b1=0.9,
               b2=0.999,
               eps=1e-8,
               weight_decay=0.0,
               hessian_power=1):
    """Construct optimizer triple for AdaHessian.

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.
      b1: optional, a positive scalar value for beta_1, the exponential decay
        rate for the first moment estimates (default 0.9).
      b2: optional, a positive scalar value for beta_2, the exponential decay
        rate for the second moment estimates (default 0.999).
      eps: optional, a positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
      weight_decay: optional, weight decay (L2 penalty) (default 0).
      hessian_power: optional, Hessian power (default 1)

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        m0 = jnp.zeros_like(x0)
        v0 = jnp.zeros_like(x0)
        return x0, m0, v0

    def update(i, g, h, state):
        x, m, v = state
        h = average_magnitude(h)
        m = (1 - b1) * g + b1 * m  # First moment estimate.
        v = (1 - b2) * jnp.square(h) + b2 * v  # Second moment estimate for the Hessian.
        mhat = m / (1 - b1**(i + 1))  # Bias correction.
        vhat = v / (1 - b2**(i + 1))
        x = x - step_size(i) * (mhat / (jnp.sqrt(vhat)**hessian_power + eps) +
                                weight_decay * x)
        return x, m, v

    def get_params(state):
        x, _, _ = state
        return x

    return init, update, get_params
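# AdaHessian's `update` takes the Hessian information `h` as an extra
# argument, unlike the other triples in this file. A hedged sketch of
# producing `h` with Hutchinson's estimator (one Hessian-vector product
# against a Rademacher vector gives an unbiased estimate of diag(H));
# `loss_fn` is a stand-in, and `average_magnitude` above is assumed to do any
# further spatial averaging:
import jax
import jax.numpy as jnp

def hutchinson_diag_estimate(loss_fn, x, key):
    """One-sample Hutchinson estimate of diag(H) at x, i.e. z * (H @ z)."""
    z = jax.random.rademacher(key, x.shape, dtype=x.dtype)
    # Hessian-vector product via forward-over-reverse differentiation.
    _, hvp = jax.jvp(jax.grad(loss_fn), (x,), (z,))
    return z * hvp

# Typical step:
#   h = hutchinson_diag_estimate(loss_fn, get_params(state), key)
#   state = update(i, jax.grad(loss_fn)(get_params(state)), h, state)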
def rmomentum(step_size, manifold, mass):
    """Construct optimizer triple for Riemannian gradient descent with momentum.

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.
      manifold: the manifold to perform Riemannian optimization on.
      mass: positive scalar representing the momentum coefficient.

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0, jax.numpy.zeros_like(x0)

    def update(i, grad, state):
        # Euclidean version, for reference:
        #     x, velocity = state
        #     velocity = mass * velocity + g
        #     x = x - step_size(i) * velocity
        #     return x, velocity
        x, velocity = state
        rgrad = manifold.egrad_to_rgrad(x, grad)
        # velocity = mass * velocity + rgrad  # both are in tangent space Tx
        velocity = mass * velocity + (1 - mass) * rgrad  # both are in tangent space Tx
        new_x, velocity = manifold.retraction_transport(
            x, velocity, -step_size(i) * velocity)
        return new_x, velocity

    def get_params(state):
        x, _ = state
        return x

    return init, update, get_params
def gradient_ascent(step_size):
    """Construct optimizer triple for gradient ascent.

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0

    def update(i, g, x):
        return x + step_size(i) * g

    def get_params(x):
        return x

    return init, update, get_params
def rsgd(step_size, manifold):
    """Construct optimizer triple for Riemannian stochastic gradient descent.

    Args:
      step_size: positive scalar, or a callable representing a step size
        schedule that maps the iteration index to positive scalar.
      manifold: the manifold to perform Riemannian optimization on.

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0

    def update(i, grad, x):
        rgrad = manifold.egrad_to_rgrad(x, grad)
        new_x = manifold.retraction(x, -step_size(i) * rgrad)
        return new_x

    def get_params(x):
        return x

    return init, update, get_params
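# Neither `rsgd` above nor `rmomentum` earlier defines a manifold. A
# hypothetical minimal one, illustrating the interface they assume, is the
# unit sphere: the Riemannian gradient is the tangent-space projection of the
# Euclidean gradient, the retraction renormalizes, and `retraction_transport`
# (needed by `rmomentum`) re-projects the velocity at the new point.
import jax.numpy as jnp

class Sphere:
    """Unit sphere S^{n-1} embedded in R^n (illustrative sketch only)."""

    def egrad_to_rgrad(self, x, g):
        # Project the Euclidean gradient onto the tangent space at x.
        return g - jnp.dot(x, g) * x

    def retraction(self, x, v):
        # Step in the (tangent) direction v, then renormalize onto the sphere.
        y = x + v
        return y / jnp.linalg.norm(y)

    def retraction_transport(self, x, velocity, step):
        # Retract along `step`, then transport `velocity` by projecting it
        # onto the tangent space at the new point.
        new_x = self.retraction(x, step)
        return new_x, velocity - jnp.dot(new_x, velocity) * new_x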
def madam(step_size=0.01, b2=0.999, g_bound=10):
    step_size = optimizers.make_schedule(step_size)

    def init(x0):
        s0 = np.sqrt(np.mean(x0 * x0))  # Initial scale.
        v0 = np.zeros_like(x0)  # 2nd moment.
        return x0, s0, v0

    def update(i, g, state):
        x, s, v = state
        v = (1 - b2) * np.square(g) + b2 * v  # Update 2nd moment.
        vhat = v / (1 - b2**(i + 1))  # Bias correction.
        g_norm = np.nan_to_num(g / np.sqrt(vhat))  # Normalise gradient.
        g_norm = np.clip(g_norm, -g_bound, g_bound)  # Bound g.
        x *= np.exp(-step_size(i) * g_norm * np.sign(x))  # Multiplicative update.
        x = np.clip(x, -s, s)  # Bound parameters.
        return x, s, v

    def get_params(state):
        x, s, v = state
        return x

    return init, update, get_params
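# Because Madam's update is purely multiplicative, a parameter initialized at
# exactly zero never moves (sign(0) == 0), and the scalar scale s0 (the RMS
# of x0) bounds every parameter for the rest of training, so initialization
# should be nonzero. A small illustration, assuming the `madam` above:
import jax.numpy as np

init, update, get_params = madam(step_size=0.01)
state = init(np.array([0.0, 1.0]))
state = update(0, np.array([1.0, 1.0]), state)
# get_params(state)[0] is still exactly 0.0; the second coordinate shrank
# multiplicatively and was then clipped to the initial RMS scale.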
def main(unused_argv):
    from jax.api import grad, jit, vmap, pmap, device_put

    "The following is required to use TPU Driver as JAX's backend."
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ['TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    pmap = partial(pmap, axis_name='i')

    """Setup some experiment parameters."""
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
    if FLAGS.physicalL2:
        tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    "Filename from FLAGS"
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)

    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    """Load CIFAR10 data and create a minimal pipeline."""
    train_images, train_labels, test_images, test_labels = utils.load_data('cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    """Create a Wide Resnet and replicate its parameters across the devices."""
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    "Loss and optimizer definitions"
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))

    currL2 = FLAGS.L2
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices,)))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels, axis=1),
                     dtype=np.float32))

    "Define optimizer"
    if FLAGS.std_wrn_sch:
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    "Initialization and loading"
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)

    key = random.PRNGKey(FLAGS.seed)
    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)
    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0

    start = time.time()
    step = pmap(lambda x: x)(i0 * np.ones((ndevices,)))

    "Start training loop"
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))
    print('Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch')

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make measurement.
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]

            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)
            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(train_lm) > 2 and (
                    (minloss <= train_lm[-1] and minloss <= train_lm[-2]) or
                    (maxacc >= train_accuracy[-1] and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period,
                # decay if the loss or error have increased in the last two
                # measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices,)))
                idel0 = i
            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv.
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))
                start = time.time()

                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t

                df = pd.DataFrame(data)
                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step
                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:
                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate([el.reshape(-1) for el in saveparams])
                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)
                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')
    if FLAGS.TPU:
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
def proximal_optimizer(cls, prior=None, step_size=0.75, **kwargs):
    """Return an optimizer triplet, like jax.experimental.optimizers, to
    perform proximal gradient ascent on the likelihood with a penalty on the
    KL divergence between distributions from one iteration to the next. This
    boils down to taking a convex combination of sufficient statistics from
    this data and those that have been accumulated from past data.

    Returns:
        initial_state      :: dictionary of optimizer state (sufficient
                              statistics and number of datapoints)
        update             :: minibatch, itr, state -> state
        get_distribution   :: state -> Distribution object
    """
    initial_state = dict(suff_stats=None, num_datapoints=0.0)
    schedule = make_schedule(step_size)

    @format_dataset
    def update(itr, dataset, state, weights=None, suff_stats=None,
               num_datapoints=0.0, scale_factor=1.0):
        # Compute the sufficient statistics and the number of datapoints.
        if suff_stats is None:
            num_datapoints = 0.0
            for data_dict, these_weights in zip(dataset, weights):
                these_stats = cls.sufficient_statistics(**data_dict, **kwargs)

                # Weight the statistics if weights are given.
                if these_weights is not None:
                    these_stats = tuple(
                        np.tensordot(these_weights, s, axes=(0, 0))
                        for s in these_stats)
                else:
                    these_stats = tuple(s.sum(axis=0) for s in these_stats)

                # Add to our accumulated statistics.
                suff_stats = sum_tuples(suff_stats, these_stats)

                # Update the number of datapoints.
                num_datapoints += these_weights.sum()
        else:
            # Assume suff_stats and num_datapoints are given.
            pass

        # Scale the sufficient statistics by the given scale factor.
        # This is as if the sufficient statistics were accumulated
        # from the entire dataset rather than a batch.
        suff_stats = tuple(scale_factor * ss for ss in suff_stats)
        num_datapoints = scale_factor * num_datapoints

        # Take a convex combination of sufficient statistics from
        # this batch and those accumulated thus far.
        if state["suff_stats"] is not None:
            state["suff_stats"] = convex_combination(
                state["suff_stats"], suff_stats, schedule(itr))
            state["num_datapoints"] = convex_combination(
                state["num_datapoints"], num_datapoints, schedule(itr))
        else:
            state = dict(suff_stats=suff_stats, num_datapoints=num_datapoints)
        return state

    def get_distribution(state):
        # Update parameters with the average stats.
        return cls.fit_with_stats(state["suff_stats"],
                                  state["num_datapoints"],
                                  prior=prior,
                                  **kwargs)

    return initial_state, update, get_distribution
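# `sum_tuples` and `convex_combination` are used above but not shown here.
# Plausible implementations, consistent with how they are called (a sketch,
# not necessarily the accompanying library's actual code):
from jax import tree_util

def sum_tuples(a, b):
    # Elementwise sum of two tuples of arrays; either may be None.
    if a is None:
        return b
    if b is None:
        return a
    return tuple(x + y for x, y in zip(a, b))

def convex_combination(old, new, step):
    # (1 - step) * old + step * new, applied leaf-wise (works for tuples of
    # arrays and for plain scalars like num_datapoints).
    return tree_util.tree_map(lambda o, n: (1 - step) * o + step * n, old, new)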
def adadp(step_size=1e-3,
          tol=1.0,
          stability_check=True,
          alpha_min=0.9,
          alpha_max=1.1):
    """Construct optimizer triple for the adaptive learning rate optimizer of
    Koskela and Honkela.

    Reference:
      A. Koskela, A. Honkela: Learning Rate Adaptation for Federated and
      Differentially Private Learning (https://arxiv.org/abs/1809.03832).

    Args:
      step_size: the initial step size
      tol: error tolerance for the discretized gradient steps
      stability_check: setting to True rejects some updates in favor of a
        more stable algorithm
      alpha_min: lower multiplicative bound of learning rate update per step
      alpha_max: upper multiplicative bound of learning rate update per step

    Returns:
      An (init_fun, update_fun, get_params) triple.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        lr = step_size(0)
        x_stepped = tree_map(lambda n: jnp.zeros_like(n), x0)
        return x0, lr, x_stepped, x0

    def _compute_update_step(x, g, step_size_):
        return tree_multimap(lambda x_, g_: x_ - step_size_ * g_, x, g)

    def _update_even_step(args):
        g, state, new_x = args
        x, lr, x_stepped, x_prev = state
        x_prev = x
        x_stepped = _compute_update_step(x, g, lr)
        return new_x, lr, x_stepped, x_prev

    def _update_odd_step(args):
        g, state, new_x = args
        x, lr, x_stepped, x_prev = state

        norm_partials = tree_multimap(
            lambda x_full, x_halfs: jnp.sum(
                ((x_full - x_halfs) / jnp.maximum(1., x_full))**2),
            x_stepped, new_x)

        err_e = jnp.array(tree_leaves(norm_partials))
        # note(lumip): paper specifies the approximate error function as
        # using absolute values, but since we square anyways, those are
        # not required here; the resulting array is partial squared sums
        # of the l2-norm over all gradient elements (per gradient site)
        err_e = jnp.sqrt(jnp.sum(err_e))  # summing partial gradient norm

        new_lr = lr * jnp.minimum(
            jnp.maximum(jnp.sqrt(tol / err_e), alpha_min), alpha_max)

        new_x = tree_multimap(
            lambda x_prev, new_x: jnp.where(stability_check and err_e > tol,
                                            x_prev, new_x), x_prev, new_x)

        return new_x, new_lr, x_stepped, x_prev

    def update(i, g, state):
        x, lr, x_stepped, x_prev = state
        new_x = _compute_update_step(x, g, 0.5 * lr)
        return lax.cond(i % 2 == 0, (g, state, new_x), _update_even_step,
                        (g, state, new_x), _update_odd_step)

    def get_params(state):
        x = state[0]
        return x

    return init, update, get_params
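# The triple above implements the paper's step-doubling scheme: an even call
# to `update` takes a half step (recording the corresponding full step in
# x_stepped), and the following odd call takes the second half step, compares
# the two-half-step point against the full step, and rescales the learning
# rate by a factor clipped to [alpha_min, alpha_max]. A usage sketch with a
# stand-in objective; the counter must advance by one per call so that
# even/odd calls pair up:
import jax
import jax.numpy as jnp

grad_fn = jax.grad(lambda x: jnp.sum(x ** 2))  # stand-in objective

init, update, get_params = adadp(step_size=1e-1)
state = init(jnp.ones(4))
for i in range(100):  # 50 adaptive steps; each pair of calls is one step
    g = grad_fn(get_params(state))  # recomputed at the half-step point on odd i
    state = update(i, g, state)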
def main(_):
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs.
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape), params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')
    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""

        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params,
                              single_example_batch[0].reshape((-1, 32, 32, 3)),
                              single_example_batch[1])
            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [neg / divisor for neg in nonempty_grads]
            return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads)
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape((-1, 1, 32, 32, 3)), b.Y),
                             key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            # Binary mode, since we pickle the optimizer state.
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}'.format(toc - tic))
    stdout_log.close()
def __init__(self, step_size=0.01):
    self.step_size = experimental.make_schedule(step_size)
def cga(step_size_f,
        step_size_g,
        f,
        g,
        linear_op_solver=None,
        default_max_iter=1000,
        solve_order='alternating'):
    if linear_op_solver is None:

        def default_convergence_test(x_new, x_old):
            min_type = converge.tree_smallest_float_dtype(x_new)
            rtol, atol = converge.adjust_tol_for_dtype(1e-10, 1e-10, min_type)
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        def default_solver(linear_op, bvec, init_x=None):
            if init_x is None:
                init_x = bvec

            def _step_default_solver(i, x):
                del i
                return tree_util.tree_multimap(lax.add, linear_op(x), bvec)

            return loop.fixed_point_iteration(
                init_x=init_x,
                func=_step_default_solver,
                convergence_test=default_convergence_test,
                max_iter=default_max_iter,
            )

        linear_op_solver = default_solver

    step_size_f = optimizers.make_schedule(step_size_f)
    step_size_g = optimizers.make_schedule(step_size_g)

    def init(inputs):
        delta_x, delta_y = tree_util.tree_map(np.zeros_like, inputs)
        return CGAState(
            x=inputs[0],
            y=inputs[1],
            delta_x=delta_x,
            delta_y=delta_y,
        )

    def update(i, grads, inputs, *args, **kwargs):
        if len(inputs) < 4:
            x, y = inputs
            delta_x = None
            delta_y = None
        else:
            x, y, delta_x, delta_y = inputs

        grad_xf, grad_yg = grads
        eta_f = step_size_f(i)
        eta_g = step_size_g(i)
        eta_fg = eta_g * eta_f

        jvp_xyf = make_mixed_jvp(partial(f, *args, **kwargs), x, y)
        jvp_yxg = make_mixed_jvp(partial(g, *args, **kwargs), x, y,
                                 opposite=True)

        def linear_op_x(x):
            return tree_util.tree_map(lambda v: eta_fg * v, jvp_xyf(jvp_yxg(x)))

        def linear_op_y(y):
            return tree_util.tree_map(lambda v: eta_fg * v, jvp_yxg(jvp_xyf(y)))

        def solve_delta_x(init_x):
            bx = tree_util.tree_multimap(
                lambda grad_xf, z: grad_xf + eta_g * z,
                grad_xf,
                jvp_xyf(grad_yg),
            )
            delta_x = linear_op_solver(linear_op=linear_op_x,
                                       bvec=bx,
                                       init_x=init_x).value
            return delta_x

        def solve_delta_y(init_y):
            by = tree_util.tree_multimap(
                lambda z, grad_yg: grad_yg + eta_f * z, jvp_yxg(grad_xf),
                grad_yg)
            delta_y = linear_op_solver(linear_op=linear_op_y,
                                       bvec=by,
                                       init_x=init_y).value
            return delta_y

        def solve_x_update_y(deltas):
            delta_x, _ = deltas
            delta_x = solve_delta_x(delta_x)
            delta_y = tree_util.tree_multimap(lambda g_y, v: (g_y + eta_f * v),
                                              grad_yg, jvp_yxg(delta_x))
            return delta_x, delta_y

        def solve_y_update_x(deltas):
            _, delta_y = deltas
            delta_y = solve_delta_y(delta_y)
            delta_x = tree_util.tree_multimap(lambda g_x, v: (g_x + eta_g * v),
                                              grad_xf, jvp_xyf(delta_y))
            return delta_x, delta_y

        def solve_both(deltas):
            delta_x, delta_y = deltas
            delta_x = solve_delta_x(delta_x)
            delta_y = solve_delta_y(delta_y)
            return delta_x, delta_y

        def solve_alternating(deltas):
            return lax.cond(
                np.mod(i, 2).astype(bool), deltas, solve_x_update_y, deltas,
                solve_y_update_x)

        solver = {
            'simultaneous': solve_both,
            'alternating': solve_alternating,
            'xy': solve_x_update_y,
            'yx': solve_y_update_x
        }
        delta_x, delta_y = solver[solve_order]((delta_x, delta_y))

        x = tree_util.tree_multimap(lambda x, delta_x: x + eta_f * delta_x,
                                    x, delta_x)
        y = tree_util.tree_multimap(lambda y, delta_y: y + eta_g * delta_y,
                                    y, delta_y)
        return CGAState(x, y, delta_x, delta_y)

    def get_params(state):
        return state[:2]

    return init, update, get_params
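# A minimal usage sketch for the CGA triple on a bilinear zero-sum game,
# assuming the `cga` above together with its module's CGAState,
# make_mixed_jvp, loop, and converge helpers. Both players ascend their own
# objective, so the second player gets g = -f:
import jax
import jax.numpy as np  # this module aliases jax.numpy as np

def f(x, y):  # player x ascends f
    return np.dot(x, y)

def g(x, y):  # player y ascends g = -f (zero-sum)
    return -np.dot(x, y)

init, update, get_params = cga(0.1, 0.1, f, g, solve_order='simultaneous')
state = init((np.ones(2), -np.ones(2)))
for i in range(200):
    x, y = get_params(state)
    grads = (jax.grad(f, argnums=0)(x, y), jax.grad(g, argnums=1)(x, y))
    state = update(i, grads, state)
# (x, y) should converge toward the equilibrium at the origin, where plain
# simultaneous gradient ascent would cycle.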