def testRmspropVectorPiecewiseConstantSchedule(self):
    """Check RMSProp on a 2-vector quadratic with a piecewise-constant step size.

    The step size is 1.0 up to step 25, 0.5 up to step 75 and 0.1 afterwards.
    """

    def squared_norm(x):
        # Objective: dot product of the vector with itself.
        return np.dot(x, x)

    start_point = np.ones(2)
    boundaries = [25, 75]
    step_sizes = [1.0, 0.5, 0.1]
    schedule = optimizers.piecewise_constant(boundaries, step_sizes)
    self._CheckFuns(optimizers.rmsprop, squared_norm, start_point, schedule)
def schedule_maker(schedule_tuple, learn_rate):
    """
    Build a learning-rate schedule from a compact tuple description.

    `schedule_tuple` has the form (sched_name, decay_steps, min_lr), where
    sched_name is one of 'const', 'exp', 'poly' or 'piecewise'. Trailing
    entries may be omitted: 'const' never reads decay_steps, and min_lr is
    only applied when the tuple has more than two entries.

    This is a thin wrapper around the existing JAX schedulers; the returned
    function maps an epoch/step index to a learning rate, clipped from below
    by min_lr when one was given.
    """
    kind = schedule_tuple[0]
    assert learn_rate >= 0
    assert kind in ['const', 'exp', 'poly', 'piecewise']

    if kind == 'piecewise':
        # Piecewise constant: nine boundaries spaced `period` apart, with the
        # rate dropping by a factor of 10 at each one.
        period = schedule_tuple[1]
        assert period > 0
        boundaries = [period * k for k in range(1, 10)]
        levels = [learn_rate * 10**(-k) for k in range(10)]
        base_schedule = jopt.piecewise_constant(boundaries, levels)
    elif kind == 'poly':
        # Stepped (staircase) inverse-time decay with decay rate 5.
        base_schedule = jopt.inverse_time_decay(
            learn_rate, schedule_tuple[1], 5, staircase=True)
    elif kind == 'exp':
        # Exponential decay: the rate halves every `decay_steps`.
        base_schedule = jopt.exponential_decay(
            learn_rate, schedule_tuple[1], 0.5)
    elif kind == 'const':
        # Constant learning rate.
        base_schedule = jopt.constant(learn_rate)

    def my_sched_fun(epoch):
        rate = base_schedule(epoch)
        # Apply the floor only when a third tuple entry (min_lr) exists.
        if len(schedule_tuple) > 2:
            return jnp.maximum(rate, schedule_tuple[2])
        return rate

    return my_sched_fun
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with L2 regularisation, across devices.

    All configuration is read from the module-level `FLAGS`. The function
    builds a log file name from the flags, loads and shards the data, sets up
    a momentum optimizer with either the piecewise-constant ("standard WRN")
    schedule or a constant learning rate, and runs the training loop. Every
    `FLAGS.meas_step` steps it evaluates on the test set and on one training
    batch; every tenth measurement the accumulated metrics are written to a
    CSV under `wrnlogs/`. With `FLAGS.L2_sch` the L2 coefficient is divided
    by `FLAGS.L2dec` when the training loss or accuracy stop improving
    (the "AutoL2" rule below). With `FLAGS.checkpointing` the flattened
    optimizer state is saved periodically through `utils.save_weights`.

    Args:
        unused_argv: ignored; present so the function can be used as an
            absl-style entry point (assumption -- confirm against the caller).
    """
    # NOTE(review): `jax.api` is a legacy import path; confirm the pinned JAX
    # version still provides it.
    from jax.api import grad, jit, vmap, pmap, device_put

    # The following is required to use TPU Driver as JAX's backend.
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1
    # Every pmap below shares the axis name 'i' (used by lax.psum).
    pmap = partial(pmap, axis_name='i')

    # Setup some experiment parameters.
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)
    # In "physical" mode the number of epochs is rescaled by 1 / tmult.
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
    if FLAGS.physicalL2:
        tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)
    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K
    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    # Filename from FLAGS
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
    # NOTE(review): confirm that this else pairs with FLAGS.physical and not
    # with FLAGS.std_wrn_sch above.
    if FLAGS.physical:
        filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    # exist_ok avoids the check-then-create race between concurrent jobs.
    os.makedirs(filedir, exist_ok=True)
    # From here on `filedir` is the full path of the CSV log file.
    filedir = os.path.join(filedir, filename + '.csv')
    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    # Load CIFAR10 data and create a minimal pipeline.
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    # Create a Wide Resnet and replicate its parameters across the devices.
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    # Loss and optimizer definitions
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    # Per-device copy of the current L2 coefficient, passed into pmapped fns.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        """Cross-entropy loss: negative mean of log-softmax times labels."""
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse
    # Full objective: bare loss plus L2 times the summed squared parameters.
    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        """Fraction of examples whose argmax prediction matches the label."""
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images),
                               axis=1) == np.argmax(labels, axis=1),
                     dtype=np.float32))

    # Define optimizer
    if FLAGS.std_wrn_sch:
        # Multiply the rate by 0.2 at 30%, 60% and 90% of the training epochs.
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)
    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        """One optimizer step; gradients are averaged over devices."""
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        """Return (regularised loss, accuracy, bare loss, L2 norm)."""
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    # Initialization and loading
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)
    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)
    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)
    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)
    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0
    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps
    batch_xs, _ = pmap(batch_fn)(batch_state)
    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    # Step at which L2 was last decayed; the refractory period counts from it.
    idel0 = i0
    start = time.time()
    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    # Start training loop
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))
    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )
    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)
            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]
            # NOTE(review): the schedule update and CSV dump are treated as
            # part of the measurement branch; confirm against the upstream
            # file.
            if (FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0
                    and len(train_lm) > 2
                    and ((minloss <= train_lm[-1] and minloss <= train_lm[-2])
                         or (maxacc >= train_accuracy[-1]
                             and maxacc >= train_accuracy[-2]))):
                # If AutoL2 is on and we are beyond the refractory period,
                # decay if the loss or error have increased in the last two
                # measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i
            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except NameError:
                    # First update: maxacc / minloss are not bound yet
                    # (UnboundLocalError is a NameError). Any other error is
                    # a real failure and is no longer swallowed.
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]
            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step *
                                              10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))
                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)
                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step
                # The whole history is rewritten each time.
                df.to_csv(filedir, index=False)
        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:
                os.makedirs('weights/', exist_ok=True)
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    # Keep only the first device's copy of each leaf.
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])
                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)
                utils.save_weights(filename, step0, saveparams, batch_state)
        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)
    print('Training done')
    if FLAGS.TPU:
        # Marker file holding the log path. Create the directory first so a
        # finished run does not fail on a missing 'done/' folder.
        os.makedirs('done', exist_ok=True)
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
def run():
    """
    Run the experiment.

    Builds the model and data iterators, restores the optimizer state from
    `parse_args.ckpt_path` when that file exists, then trains for
    `parse_args.nepochs` epochs of `num_batches` steps. Along the way it
    appends per-step timings to a `_time.txt` file, evaluates and logs the
    losses every `parse_args.test_freq` iterations, pickles the parameters
    every `parse_args.save_freq` iterations and writes a resumable checkpoint
    every `parse_args.ckpt_freq` iterations. After training, the collected
    evaluation info and the parsed arguments are pickled to a `_meta.pickle`
    file.
    """
    # init the model first so that jax gets enough GPU memory before TFDS
    forward, model = init_model(43)  # how do you sleep at night
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    ds_train, ds_test_eval, meta = init_data()
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    lr_schedule = optimizers.piecewise_constant(
        boundaries=[9000, 12750],  # 300 epochs, 425 epochs
        values=[1e-3, 1e-4, 1e-5])
    opt_init, opt_update, get_params = optimizers.adam(step_size=lr_schedule)
    # Function that rebuilds the optimizer-state pytree from a flat vector.
    unravel_opt = ravel_pytree(opt_init(model["params"]))[1]

    if os.path.exists(parse_args.ckpt_path):
        # NOTE: pickle.load can execute arbitrary code; only point ckpt_path
        # at checkpoints written by this script or another trusted source.
        with open(parse_args.ckpt_path, 'rb') as ckpt_file:
            state_dict = pickle.load(ckpt_file)
        opt_state = unravel_opt(state_dict["opt_state"])
        load_itr = state_dict["itr"]
    else:
        init_params = model["params"]
        opt_state = opt_init(init_params)
        load_itr = 0

    @jax.jit
    def update(_itr, _opt_state, _key, _batch):
        """
        Update the params based on grad for current batch.
        """
        return opt_update(_itr, grad_fn(get_params(_opt_state), _batch, _key),
                          _opt_state)

    @jax.jit
    def sep_losses(_opt_state, _batch, _key):
        """
        Convenience function for calculating losses separately.

        Returns (total regularized loss, bare loss, r2 reg, fro reg, kin reg).
        """
        z, delta_logp, r2_regs, fro_regs, kin_regs = model["forward_all"](
            _key, get_params(_opt_state), _batch)
        loss_ = _loss_fn(z, delta_logp)
        r2_reg_ = _reg_loss_fn(r2_regs)
        fro_reg_ = _reg_loss_fn(fro_regs)
        kin_reg_ = _reg_loss_fn(kin_regs)
        total_loss_ = loss_ + lam * r2_reg_ + lam_fro * fro_reg_ + lam_kin * kin_reg_
        return total_loss_, loss_, r2_reg_, fro_reg_, kin_reg_

    def evaluate_loss(opt_state, _key, ds_eval):
        """
        Convenience function for evaluating the losses over `ds_eval` in
        smaller batches.

        Pulls `num_test_batches` batches from `ds_eval` and returns the
        averages of (total loss, bare loss, r2 reg, fro reg, kin reg, NFE),
        each weighted by `len(batch)`.
        """
        sep_loss_aug_, sep_loss_, sep_loss_r2_reg_, sep_loss_fro_reg_, sep_loss_kin_reg_, nfe, bs = \
            [], [], [], [], [], [], []

        for test_batch_num in range(num_test_batches):
            _key, = jax.random.split(_key, num=1)
            test_batch = next(ds_eval)

            test_batch_loss_aug_, test_batch_loss_, \
                test_batch_loss_r2_reg_, test_batch_loss_fro_reg_, test_batch_loss_kin_reg_ = \
                sep_losses(opt_state, test_batch, _key)

            if count_nfe:
                nfe.append(model["nfe"](_key, get_params(opt_state), test_batch))
            else:
                nfe.append(0)

            sep_loss_aug_.append(test_batch_loss_aug_)
            sep_loss_.append(test_batch_loss_)
            sep_loss_r2_reg_.append(test_batch_loss_r2_reg_)
            sep_loss_fro_reg_.append(test_batch_loss_fro_reg_)
            sep_loss_kin_reg_.append(test_batch_loss_kin_reg_)
            # Weight for the averages below; presumably the batch size --
            # TODO confirm what len() of a batch measures here.
            bs.append(len(test_batch))

        sep_loss_aug_ = jnp.array(sep_loss_aug_)
        sep_loss_ = jnp.array(sep_loss_)
        sep_loss_r2_reg_ = jnp.array(sep_loss_r2_reg_)
        sep_loss_fro_reg_ = jnp.array(sep_loss_fro_reg_)
        sep_loss_kin_reg_ = jnp.array(sep_loss_kin_reg_)
        nfe = jnp.array(nfe)
        bs = jnp.array(bs)

        return jnp.average(sep_loss_aug_, weights=bs), \
            jnp.average(sep_loss_, weights=bs), \
            jnp.average(sep_loss_r2_reg_, weights=bs), \
            jnp.average(sep_loss_fro_reg_, weights=bs), \
            jnp.average(sep_loss_kin_reg_, weights=bs), \
            jnp.average(nfe, weights=bs)

    itr = 0
    info = collections.defaultdict(dict)

    key = rng

    # Common prefix of every output file; it does not change during the run,
    # so it is formatted once instead of at every write.
    out_prefix = "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e" \
                 % (dirname, reg, reg_type, lam, lam_fro, lam_kin)

    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            key, = jax.random.split(key, num=1)
            batch = next(ds_train)

            itr += 1

            # When resuming, skip the already-trained iterations. The key and
            # the data iterator are still advanced above.
            if itr <= load_itr:
                continue

            update_start = time.time()
            opt_state = update(itr, opt_state, key, batch)
            # Wait for the first state leaf to be ready so the timing below
            # does not stop before the update has finished.
            tree_flatten(opt_state)[0][0].block_until_ready()
            update_end = time.time()
            time_str = "%d %.18f %d\n" % (itr, update_end - update_start, load_itr)
            with open(out_prefix + "_time.txt", "a") as time_file:
                time_file.write(time_str)

            if itr % parse_args.test_freq == 0:
                loss_aug_, loss_, loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_ = \
                    evaluate_loss(opt_state, key, ds_test_eval)

                print_str = 'Iter {:04d} | Total (Regularized) Loss {:.6f} | Loss {:.6f} | ' \
                            'r {:.6f} | fro {:.6f} | kin {:.6f} | ' \
                            'NFE {:.6f}'.format(itr, loss_aug_, loss_, loss_r2_reg_,
                                                loss_fro_reg_, loss_kin_reg_, nfe_)

                print(print_str)

                with open(out_prefix + "_info.txt", "a") as info_file:
                    info_file.write(print_str + "\n")

                info[itr]["loss_aug"] = loss_aug_
                info[itr]["loss"] = loss_
                info[itr]["loss_r2_reg"] = loss_r2_reg_
                info[itr]["loss_fro_reg"] = loss_fro_reg_
                info[itr]["loss_kin_reg"] = loss_kin_reg_
                info[itr]["nfe"] = nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s_%d_fargs.pickle" % (out_prefix, itr)
                fargs = get_params(opt_state)
                with open(param_filename, "wb") as param_file:
                    pickle.dump(fargs, param_file)

            if itr % parse_args.ckpt_freq == 0:
                state_dict = {
                    "opt_state": ravel_pytree(opt_state)[0],
                    "itr": itr,
                }
                # only save ckpts if a directory has been made for them
                # (allow easy switching between v1 and v2)
                try:
                    # `with` closes the handle even when the dump fails
                    # part-way through.
                    with open(parse_args.ckpt_path, 'wb') as ckpt_file:
                        pickle.dump(state_dict, ckpt_file)
                except IOError:
                    print("Unable to save ck.pt %d" % itr, file=sys.stderr)

    meta = {"info": info, "args": parse_args}
    with open("%s_%d_meta.pickle" % (out_prefix, itr), "wb") as meta_file:
        pickle.dump(meta, meta_file)