def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001,
         decay_rate=.999995, model_path=Path('./pixelcnn.params')):
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(loss.apply, state, batch,
                                                        key=update_key, jit=True)

            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                test_loss = loss.apply(opt.get_parameters(state),
                                       next(test_batches), key=test_key, jit=True)
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f}")

        save(opt.get_parameters(state), model_path)
def testSgdVectorExponentialDecaySchedule(self):
    def loss(x):
        return np.dot(x, x)
    x0 = np.ones(2)
    step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)
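For reference, the schedule these snippets construct is just a function of the step count. Below is a minimal sketch of what `exponential_decay(step_size, decay_steps, decay_rate)` evaluates to, written from memory of the `jax.experimental.optimizers` behavior rather than copied from the library source:

# Sketch only; assumption: mirrors jax.experimental.optimizers, i.e.
#   schedule(i) = step_size * decay_rate ** (i / decay_steps)
def exponential_decay_sketch(step_size, decay_steps, decay_rate):
    def schedule(i):
        return step_size * decay_rate ** (i / decay_steps)
    return schedule

# With decay_rate=2. as in the test above, the schedule grows rather than decays:
sched = exponential_decay_sketch(0.1, 3, 2.)
print(sched(0))  # 0.1
print(sched(3))  # 0.2 -- doubled after decay_steps steps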
def get_optimizer(optim_config):
    """Return an Adam optimizer with the exponential learning-rate decay
    schedule specified in the config."""
    learning_rate = optimizers.exponential_decay(optim_config['base_lr'],
                                                 optim_config['lr_decay_steps'],
                                                 optim_config['lr_decay_rate'])
    opt = optimizers.adam(learning_rate)
    return opt
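A hypothetical `optim_config` for the function above, with the key names taken from its body and purely illustrative values:

# Illustrative values only; the real config comes from the experiment setup.
optim_config = {
    'base_lr': 1e-3,         # initial Adam step size
    'lr_decay_steps': 1000,  # steps per decay period
    'lr_decay_rate': 0.9,    # multiplier applied once per decay period
}
opt_init, opt_update, get_params = get_optimizer(optim_config)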
def get_scheduler(lr, train_steps, name='constant'):
    name = name.lower()
    if name == 'constant':
        scheduler = optimizers.constant(lr)
    elif name == 'inverse_time_decay':
        decay_steps = int(train_steps // 5)
        scheduler = optimizers.inverse_time_decay(lr, decay_steps, 2)
    elif name == 'exponential_decay':
        decay_steps = int(train_steps // 3)
        scheduler = optimizers.exponential_decay(lr, decay_steps, 0.3)
    else:
        raise ValueError(f'Unsupported scheduler {name}. '
                         f'Supported schedulers={supported_schedulers()}')
    print(f'Loaded scheduler {name} - {scheduler}')
    return scheduler
def schedule_maker(schedule_tuple, learn_rate):
    """Return a scheduler function given a tuple of the form:
        (sched_name, decay_steps, min_lr)

    This just wraps existing JAX schedulers with a simplified syntax.
    """
    sched_type = schedule_tuple[0]
    assert learn_rate >= 0
    assert sched_type in ['const', 'exp', 'poly', 'piecewise']

    if sched_type == 'const':
        # Constant learning rate
        sched_fun = jopt.constant(learn_rate)
    elif sched_type == 'exp':
        # Exponentially decaying learning rate
        sched_fun = jopt.exponential_decay(learn_rate, schedule_tuple[1], 0.5)
    elif sched_type == 'poly':
        # Harmonically decaying stepped learning rate
        sched_fun = jopt.inverse_time_decay(learn_rate, schedule_tuple[1], 5,
                                            staircase=True)
    elif sched_type == 'piecewise':
        # Piecewise-constant learning rate, drops by a factor of 10 each time
        step_len = schedule_tuple[1]
        assert step_len > 0
        bounds = [step_len * i for i in range(1, 10)]
        values = [learn_rate * 10**(-i) for i in range(10)]
        sched_fun = jopt.piecewise_constant(bounds, values)

    def my_sched_fun(epoch):
        lr = sched_fun(epoch)
        if len(schedule_tuple) <= 2:
            return lr
        # Clamp the schedule from below at min_lr.
        return jnp.maximum(lr, schedule_tuple[2])

    return my_sched_fun
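Under the tuple convention documented above, a hypothetical call might look like this (the numbers are illustrative, not from any experiment):

# ('exp', decay_steps, min_lr): halve the rate every 1000 epochs,
# clamped from below at 1e-5 by my_sched_fun.
sched = schedule_maker(('exp', 1000, 1e-5), 1e-2)
print(sched(0))     # 0.01
print(sched(5000))  # 0.01 * 0.5 ** 5 = 3.125e-4, still above the floor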
def optimize_lfads(key, init_params, hps, opt_hps,
                   train_data_fun, eval_data_fun):
    """Optimize the LFADS model and print batch-based optimization data.

    This loop is at the cpu, non-JAX numpy level.

    Arguments:
      init_params: a dict of parameters to be trained
      hps: dict of lfads model HPs
      opt_hps: dict of optimization HPs
      train_data_fun: function that takes a key and returns
        nexamples x time x ndims np array of data for training
      eval_data_fun: function that takes a key and returns
        nexamples x time x ndims np array of data for held-out error

    Returns:
      a dictionary of trained parameters"""
    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(opt_hps)
    decay_fun = optimizers.exponential_decay(opt_hps['step_size'],
                                             opt_hps['decay_steps'],
                                             opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                       b1=opt_hps['adam_b1'],
                                                       b2=opt_hps['adam_b2'],
                                                       eps=opt_hps['adam_eps'])
    opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.training_loss_jit)(params, hps, key, x_bxt,
                                              kl_warmup, opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads, opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = hps['batch_size']
    num_batches = opt_hps['num_batches']
    print_every = opt_hps['print_every']
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey1, dtkey2, dekey1, dekey2 = \
            random.split(random.fold_in(key, oidx), 6)
        opt_state = optimize_core_jit(tkey, batch_idx_start, print_every,
                                      update_w_gc_jit, kl_warmup_fun,
                                      opt_state, hps, opt_hps, train_data_fun)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)

        # Training loss
        x_bxt = train_data_fun(dtkey1)
        tlosses = lfads.losses_jit(params, hps, dtkey2, x_bxt, kl_warmup, 1.0)

        # Evaluation loss
        ex_bxt = eval_data_fun(dekey1)
        elosses = lfads.losses_jit(params, hps, dekey2, ex_bxt, kl_warmup, 1.0)

        # Saving, printing.
        resps = softmax(params['prior']['resps'])
        rmin = onp.min(resps)
        rmax = onp.max(resps)
        rmean = onp.mean(resps)
        rstd = onp.std(resps)

        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        s1 = "Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}"
        s2 = ("  Training losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f}"
              " + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f}")
        s3 = ("  Eval losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f}"
              " + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f}")
        s4 = "  Resps: min {:0.4f}, mean {:0.4f}, max {:0.4f}, std {:0.4f}"
        print(s1.format(batch_idx_start + 1, batch_pidx, batch_time,
                        decay_fun(batch_pidx)))
        print(s2.format(tlosses['total'], tlosses['nlog_p_xgz'],
                        tlosses['kl_prescale'], tlosses['kl'], tlosses['l2'],
                        tlosses['ii_l2'], tlosses['ii_tavg']))
        print(s3.format(elosses['total'], elosses['nlog_p_xgz'],
                        elosses['kl_prescale'], elosses['kl'], elosses['l2'],
                        elosses['ii_l2'], elosses['ii_tavg']))
        print(s4.format(rmin, rmean, rmax, rstd))

    tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
    elosses_thru_training = utils.merge_losses_dicts(all_elosses)
    optimizer_details = {'tlosses': tlosses_thru_training,
                         'elosses': elosses_thru_training}

    return params, optimizer_details
    T_train = df.pop("week").values
    E_train = df.pop("arrest").values
    X_train = df.values
    return X_train, T_train, E_train


x_train, t_train, e_train = get_rossi_dataset()

model = Model([Dense(18), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={"step_size": optimizers.exponential_decay(0.01, 10, 0.999)},
    loss=losses.NonParametric(),
)

model.fit(x_train, t_train, e_train, epochs=2, batch_size=32)

print(model.predict_survival_function(x_train[0], np.arange(0, 10)))

dump(model, "testsavefile")
model = load("testsavefile")

print(model.predict_survival_function(x_train[0], np.arange(0, 10)))

model.fit(
    x_train,
def train(data_dict, train_dict, seed_dict, results_dict):
    """Train the HM-NLICA model using a minibatch implementation of the
    algorithm described in the paper.

    Args:
        data_dict (dict.): dictionary of required data in the form of:
            {'x_data': observed signals (array),
             's_data': true latent components, for evaluation (array),
             'state_seq': true latent state sequence (array)}.
        train_dict (dict.): dictionary of variables related to optimization
            of form:
            {'mix_depth': num. layers in mixing/estimator MLP (int),
                 for example mix_depth=1 is linear ICA,
             'hidden_size': num. hidden units per MLP layer (int),
             'learning_rate': step size for optimizer (float),
             'num_epochs': num. training epochs (int),
             'subseq_len': length of time sequences in a minibatch (int),
             'minib_size': num. sub-sequences in a minibatch (int),
             'decay_rate': multiplier for decaying learning rate (float),
             'decay_steps': num. epochs per which to decay lr (int)}.
        seed_dict (dict.): dictionary of seeds for reproducible stochasticity
            of form:
            {'est_mlp_seed': seed to initialize MLP parameters (int),
             'est_distrib_seed': seed to initialize exp. fam. params (int)}.
        results_dict (dict.): stores data to save (see main.py).

    Returns:
        s_est (array): estimated independent components.
        sort_idx (array): best matching indices of components to true indices.
        results_dict (dict): to save all evaluation and training results.
        est_params (list): list of all estimated parameter arrays.
    """
    # unpack data
    x = data_dict['x_data']
    s_true = data_dict['s_data']
    state_seq = data_dict['state_seq']

    # set data dimensions
    N = x.shape[1]
    T = x.shape[0]
    K = len(np.unique(state_seq))

    # unpack training variables
    mix_depth = train_dict['mix_depth']
    hidden_size = train_dict['hidden_size']
    learning_rate = train_dict['learning_rate']
    num_epochs = train_dict['num_epochs']
    subseq_len = train_dict['subseq_len']
    minib_size = train_dict['minib_size']
    decay_rate = train_dict['decay_rate']
    decay_steps = train_dict['decay_steps']

    print("Training with N={n}, T={t}, K={k}\t"
          "mix_depth={md}".format(n=N, t=T, k=K, md=mix_depth))

    # initialize parameters for the MLP function approximator
    key = jrandom.PRNGKey(seed_dict['est_mlp_seed'])
    layer_sizes = [N] + [hidden_size] * (mix_depth - 1) + [N]
    mlp_params = init_mlp_params(key, layer_sizes)

    # initialize parameters of the estimating distribution
    np.random.seed(seed_dict['est_distrib_seed'])
    mu_est = np.random.uniform(-5., 5., size=(K, N))
    var_est = np.random.uniform(1., 2., size=(K, N))
    D_est = np.zeros(shape=(K, N, N))
    for k in range(K):
        D_est[k] = np.diag(var_est[k])

    # initialize transition parameter estimates
    A_est = np.eye(K) + 0.05
    A_est = A_est / A_est.sum(1, keepdims=True)
    pi_est = A_est.sum(0) / A_est.sum()

    # set up optimizer
    schedule = optimizers.exponential_decay(learning_rate,
                                            decay_steps=decay_steps,
                                            decay_rate=decay_rate)
    opt_init, opt_update, get_params = optimizers.adam(schedule)

    # set up loss function and training step
    @jit
    def calc_loss(params, input_data, marginal_posteriors,
                  mu_est, D_est, num_subseqs):
        """Calculates the loss for the gradient M-step of the function
        estimator."""
        lp_x, lp_x_exc_J, lp_J, _ = mbatch_emission_likelihood(
            params, input_data, mu_est, D_est)
        expected_lp_x = jnp.sum(marginal_posteriors * lp_x, -1)
        # note correction for bias below
        return -expected_lp_x.mean() * num_subseqs

    @jit
    def training_step(iter_num, input_data, marginal_posteriors,
                      mu_est, D_est, opt_state, num_subseqs):
        """Performs the gradient M-step on the function estimator MLP
        parameters.
""" params = get_params(opt_state) loss, g = value_and_grad( calc_loss, argnums=0)(params, input_data, lax.stop_gradient(marginal_posteriors), mu_est, D_est, num_subseqs) return loss, opt_update(iter_num, g, opt_state) # function to load subsequence data for minibatches @jit def get_subseq_data(orig_data, subseq_array_to_fill): """Collects all sub-sequences into an array. """ subseq_data = subseq_array_to_fill num_subseqs = subseq_data.shape[0] subseq_len = subseq_data.shape[1] def body_fun(i, subseq_data): """Function to loop over. """ subseq_i = lax.dynamic_slice_in_dim(orig_data, i, subseq_len) subseq_data = ops.index_update(subseq_data, ops.index[i, :, :], subseq_i) return subseq_data return lax.fori_loop(0, num_subseqs, body_fun, subseq_data) # set up minibatch training num_subseqs = T - subseq_len + 1 assert num_subseqs >= minib_size num_full_minibs, remainder = divmod(num_subseqs, minib_size) num_minibs = num_full_minibs + bool(remainder) sub_data_holder = jnp.zeros((num_subseqs, subseq_len, N)) sub_data = get_subseq_data(x, sub_data_holder) print("T: {t}\t" "subseq_len: {slen}\t" "minibatch size: {mbs}\t" "num minibatches: {nbs}".format(t=T, slen=subseq_len, mbs=minib_size, nbs=num_minibs)) # initialize and train best_logl = -np.inf itercount = itertools.count() opt_state = opt_init(mlp_params) all_subseqs_idx = np.arange(num_subseqs) for epoch in range(num_epochs): tic = time.time() # shuffle subseqs for added stochasticity np.random.shuffle(all_subseqs_idx) sub_data = sub_data.copy()[all_subseqs_idx] # train over minibatches for batch in range(num_minibs): # select sub-sequence for current minibatch batch_data = sub_data[batch * minib_size:(batch + 1) * minib_size] # calculate emission likelihood using most recent parameters params = get_params(opt_state) logp_x, logp_x_exc_J, lpj, s_est = mbatch_emission_likelihood( params, batch_data, mu_est, D_est) # forward-backward algorithm marg_posteriors, pw_posteriors, scalers = mbatch_fwd_bwd_algo( logp_x, A_est, pi_est) # exact M-step for mean and variance mu_est, D_est, A_est, pi_est = mbatch_m_step( s_est, marg_posteriors, pw_posteriors) # SGD for mlp parameters loss, opt_state = training_step(next(itercount), batch_data, marg_posteriors, mu_est, D_est, opt_state, num_subseqs) # gather full data after each epoch for evaluation params_latest = get_params(opt_state) logp_x_all, _, _, s_est_all = emission_likelihood( params_latest, x, mu_est, D_est) _, _, scalers = forward_backward_algo(logp_x_all, A_est, pi_est) logl_all = np.log(scalers).sum() # viterbi to estimate state prediction est_seq = viterbi_algo(logp_x_all, A_est, pi_est) cluster_acc = clustering_acc(np.array(est_seq), np.array(state_seq)) # evaluate correlation of estimated and true independent components mean_abs_corr, s_est_sorted, sort_idx = matching_sources_corr( np.array(s_est_all), np.array(s_true)) # save results if logl_all > best_logl: best_logl = logl_all best_logl_corr = mean_abs_corr best_logl_acc = cluster_acc results_dict['results'].append({ 'best_logl': best_logl, 'best_logl_corr': mean_abs_corr, 'best_logl_acc': cluster_acc }) results_dict['results'].append({ 'epoch': epoch, 'logl': logl_all, 'corr': mean_abs_corr, 'acc': cluster_acc }) # print them print("Epoch: [{0}/{1}]\t" "LogL: {logl:.2f}\t" "mean corr between s and s_est {corr:.2f}\t" "acc {acc:.2f}\t" "elapsed {time:.2f}".format(epoch, num_epochs, logl=logl_all, corr=mean_abs_corr, acc=cluster_acc, time=time.time() - tic)) # pack data into tuples results_dict['results'].append({ 'best_logl': best_logl, 
                                    'best_logl_corr': best_logl_corr,
                                    'best_logl_acc': best_logl_acc})

    est_params = (mu_est, D_est, A_est, est_seq)

    return s_est, sort_idx, results_dict, est_params
def testSgdVectorExponentialDecaySchedule(self):
    def loss(x, _):
        return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_schedule)
latent_shape, init_encoder_params = encoder_init(next(keys),
                                                 (n_timesteps, n_neurons))
decoded_shape, init_decoder_params = decoder_init(next(keys), latent_shape)
init_params = init_encoder_params, init_decoder_params


# Optimizer #
def kl_warmup_fun(batch_idx):
    progress_frac = ((batch_idx - kl_warmup_start) /
                     (kl_warmup_end - kl_warmup_start))
    _warmup = np.where(batch_idx < kl_warmup_start, kl_min,
                       (kl_max - kl_min) * progress_frac + kl_min)
    return np.where(batch_idx > kl_warmup_end, kl_max, _warmup)


decay_fun = optimizers.exponential_decay(STEP_SIZE, DECAY_STEPS, DECAY_FACTOR)
# TODO: Check exponential_decay when using epochs / batches.
opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                   b1=0.9,
                                                   b2=0.999,
                                                   eps=1e-1)  # Seems big
opt_state = opt_init(init_params)


@jit
def run_epoch(rng, _opt_state, epoch_idx):
    _rng, dat_keys = utils.keygen(rng, 1)
    _rng, batch_keys = utils.keygen(_rng, num_batches)
    # Randomize epoch data.
    epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)
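As a sanity check of the warmup ramp above, assuming hypothetical globals `kl_warmup_start=0`, `kl_warmup_end=100`, `kl_min=0.0`, `kl_max=1.0` (the real values come from the surrounding script):

# Hypothetical warmup settings, chosen only to illustrate the linear ramp.
kl_warmup_start, kl_warmup_end = 0, 100
kl_min, kl_max = 0.0, 1.0

print(kl_warmup_fun(0))    # 0.0 -- ramp begins at kl_min
print(kl_warmup_fun(50))   # 0.5 -- halfway up the linear ramp
print(kl_warmup_fun(200))  # 1.0 -- clamped at kl_max past kl_warmup_end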
def optimize_lfads(key, init_params, hps, opt_hps,
                   train_data_fun, eval_data_fun,
                   ncompleted_batches=0, opt_state=None,
                   callback_fun=None, do_print=True):
    """Optimize the LFADS model and print batch-based optimization data.

    This loop is at the cpu, non-JAX numpy level.

    Arguments:
      key: random.PRNGKey for randomness
      init_params: a dict of parameters to be trained
      hps: dict of lfads model HPs
      opt_hps: dict of optimization HPs
      train_data_fun: function that takes a key and returns
        nexamples x time x ndims np array of data for training
      eval_data_fun: function that takes a key and returns
        nexamples x time x ndims np array of data for held-out error
      ncompleted_batches: (default 0) use this to restart training in the
        middle of the batch count. Used in tandem with opt_state (below).
      opt_state: (default None) 3-tuple (params, m - 1st moment, v - 2nd
        moment) from jax.experimental.optimizers.adam (a None value starts
        the optimizer anew). The params in opt_state[0] will *override* the
        init_params argument.
      callback_fun: (default None) function that the optimize routine will
        call every print_every loops, in order to do whatever the user wants,
        typically saving, or reporting to a hyperparameter tuner, etc.
        callback_fun parameters are (current_batch_idx:int, hps:dict,
        opt_hps:dict, params:dict, opt_state:tuple, tlosses:dict,
        elosses:dict)
      do_print: (default True) print loss information

    Returns:
      A 3-tuple of (trained_params,
                    opt_details - dictionary of optimization losses through
                      training,
                    opt_state - a 3-tuple of trained params in odd pytree
                      form, m 1st moment, v 2nd moment).
    """
    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(opt_hps)
    decay_fun = optimizers.exponential_decay(opt_hps['step_size'],
                                             opt_hps['decay_steps'],
                                             opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(step_size=decay_fun,
                                                       b1=opt_hps['adam_b1'],
                                                       b2=opt_hps['adam_b2'],
                                                       eps=opt_hps['adam_eps'])

    print_every = opt_hps['print_every']
    if ncompleted_batches > 0:
        print('Starting batch count at %d.' % ncompleted_batches)
        assert ncompleted_batches % print_every == 0
        opt_loop_start_idx = int(ncompleted_batches / print_every)
    else:
        opt_loop_start_idx = 0
    if opt_state is not None:
        print('Received opt_state, ignoring init_params argument.')
    else:
        opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.training_loss_jit)(params, hps, key, x_bxt,
                                              kl_warmup, opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads, opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = hps['batch_size']
    num_batches = opt_hps['num_batches']
    assert num_batches % print_every == 0
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(opt_loop_start_idx, num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey1, dtkey2, dekey1, dekey2 = \
            random.split(random.fold_in(key, oidx), 6)
        opt_state = optimize_core_jit(tkey, batch_idx_start, print_every,
                                      update_w_gc_jit, kl_warmup_fun,
                                      opt_state, hps, opt_hps, train_data_fun)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)

        # Training loss
        x_bxt = train_data_fun(dtkey1)
        tlosses = lfads.losses_jit(params, hps, dtkey2, x_bxt, kl_warmup, 1.0)

        # Evaluation loss
        ex_bxt = eval_data_fun(dekey1)
        elosses = lfads.losses_jit(params, hps, dekey2, ex_bxt, kl_warmup, 1.0)

        # Saving, printing.
        resps = softmax(params['prior']['resps'])
        rmin = onp.min(resps)
        rmax = onp.max(resps)
        rmean = onp.mean(resps)
        rstd = onp.std(resps)

        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        if do_print:
            s1 = "Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}"
            s2 = ("  Training losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f}"
                  " + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f}")
            s3 = ("  Eval losses {:0.0f} = NLL {:0.0f} + KL {:0.1f},{:0.1f}"
                  " + L2 {:0.2f} + II L2 {:0.2f} + <II> {:0.2f}")
            s4 = "  Resps: min {:0.4f}, mean {:0.4f}, max {:0.4f}, std {:0.4f}"
            print(s1.format(batch_idx_start + 1, batch_pidx, batch_time,
                            decay_fun(batch_pidx)))
            print(s2.format(tlosses['total'], tlosses['nlog_p_xgz'],
                            tlosses['kl_prescale'], tlosses['kl'],
                            tlosses['l2'], tlosses['ii_l2'],
                            tlosses['ii_tavg']))
            print(s3.format(elosses['total'], elosses['nlog_p_xgz'],
                            elosses['kl_prescale'], elosses['kl'],
                            elosses['l2'], elosses['ii_l2'],
                            elosses['ii_tavg']))
            print(s4.format(rmin, rmean, rmax, rstd))

        if callback_fun is not None:
            callback_fun(batch_pidx, hps, opt_hps, params, opt_state,
                         tlosses, elosses)

    tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
    elosses_thru_training = utils.merge_losses_dicts(all_elosses)
    optimizer_details = {'tlosses': tlosses_thru_training,
                         'elosses': elosses_thru_training}

    return params, optimizer_details, opt_state
from jax.experimental.stax import Dense, Dropout, Tanh, Relu, randn
from jax.experimental import optimizers

import pandas as pd

import lifelike.losses as losses
from lifelike import Model
from lifelike.callbacks import *
from datasets.loaders import *

x_train, t_train, e_train = get_generated_churn_dataset()

model = Model([Dense(8), Relu, Dense(12), Relu, Dense(16), Relu])

model.compile(
    optimizer=optimizers.adam,
    optimizer_kwargs={"step_size": optimizers.exponential_decay(0.001, 1, 0.9995)},
    weight_l2=0.00,
    smoothing_l2=100.,
    loss=losses.NonParametric(),
)

print(model)

model.fit(
    x_train,
    t_train,
    e_train,
    epochs=10000,
    batch_size=10000,
    validation_split=0.1,
    callbacks=[
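One thing worth noting about the compile call above: with `decay_steps=1`, `exponential_decay(0.001, 1, 0.9995)` multiplies the step size by 0.9995 on every single step. A quick back-of-the-envelope check (my arithmetic, not from the source):

# schedule(i) = 0.001 * 0.9995 ** i when decay_steps == 1
print(0.001 * 0.9995 ** 10000)  # ~6.7e-06 after 10000 steps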
    data_loss = multiclass_xent(logits, batch['labels'])
    reg_loss = l2_pen * renn.norm(params)
    return data_loss + reg_loss


f_df = jax.value_and_grad(xent)


@jax.jit
def accuracy(params, batch):
    logits = apply_fun(params, batch['inputs'])
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == batch['labels'])


learning_rate = optimizers.exponential_decay(2e-3, 1000, 0.8)
init_opt, update_opt, get_params = optimizers.adam(learning_rate)
state = init_opt(initial_params)
losses = []


@jax.jit
def step(k, opt_state, batch):
    params = get_params(opt_state)
    loss, gradients = f_df(params, batch)
    new_state = update_opt(k, gradients, opt_state)
    return new_state, loss


def test_acc(params):
#     return memo[n]

# def reverse_fib(n):
#     """ Return the index of the greatest number from the Fibonacci
#     sequence that is smaller than or equal to n. """
#     i = 0
#     while fib(i+1) <= n:
#         i += 1
#     return i

#-------------------- optimizer and LR schedule --------------------#
step_size = 1e-2
decay_rate = 0.65   # 0.65 ** 10 ~= 0.01 ---> decaying the step size 10 times amounts to dividing by 100
decay_steps = 10
step_fn = optimizers.exponential_decay(step_size=step_size,
                                       decay_rate=decay_rate,
                                       decay_steps=decay_steps)
opt_init, opt_update, get_params = optimizers.nesterov(step_size=step_fn,
                                                       mass=0.9)

#-------------------- params training utilities --------------------#
reg = 3e-5
clip_max_grad = 10.0
init_fun, apply_fun = model_fn()
apply_fun = jax.jit(apply_fun)


@jax.jit
def l2_regularizer(params, reg=reg):
    """ Return the L2 regularization loss. """
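To double-check the comment's arithmetic above: with `decay_steps=10` and `decay_rate=0.65`, ten decay periods (100 steps) scale the step size by 0.65 ** 10 ~= 0.0135, i.e. roughly a factor of 100:

# step_fn(i) = 1e-2 * 0.65 ** (i / 10); after 100 optimization steps:
print(step_fn(100))  # 1e-2 * 0.65 ** 10 ~= 1.35e-4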
def optimize_fps(rnn_fun, fp_candidates, hps, do_print=True):
    """Find fixed points of the RNN via optimization.

    This loop is at the cpu, non-JAX level.

    Arguments:
      rnn_fun: RNN one-step update function for a single hidden state vector,
        h_t -> h_t+1, for which the fixed point candidates are trained to be
        fixed points
      fp_candidates: np array with shape (batch size, state dim) of hidden
        states of the RNN to start training for fixed points
      hps: fixed point hyperparameters
      do_print: print useful information?

    Returns:
      np array of numerically optimized fixed points"""

    total_fp_loss_fun = get_total_fp_loss_fun(rnn_fun)

    def get_update_fun(opt_update):
        """Build the parameter update for gradient descent.

        Arguments:
          opt_update: a function that updates the parameters (from
            jax.experimental.optimizers)

        Returns:
          a function which updates the parameters according to the optimizer
        """
        def update(i, opt_state):
            params = optimizers.get_params(opt_state)
            grads = grad(total_fp_loss_fun)(params)
            return opt_update(i, grads, opt_state)

        return update

    # Build some functions used in optimization.
    decay_fun = optimizers.exponential_decay(hps['step_size'],
                                             hps['decay_steps'],
                                             hps['decay_factor'])
    opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                           b1=hps['adam_b1'],
                                           b2=hps['adam_b2'],
                                           eps=hps['adam_eps'])
    opt_state = opt_init(fp_candidates)
    update_fun = get_update_fun(opt_update)

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = fp_candidates.shape[0]
    num_batches = hps['num_batches']
    print_every = hps['opt_print_every']
    num_opt_loops = int(num_batches / print_every)
    fps = optimizers.get_params(opt_state)
    fp_losses = []
    do_stop = False
    for oidx in range(num_opt_loops):
        if do_stop:
            break
        batch_idx_start = oidx * print_every
        start_time = time.time()
        opt_state = optimize_fp_core_jit(batch_idx_start, print_every,
                                         update_fun, opt_state)
        batch_time = time.time() - start_time

        # Training loss
        fps = optimizers.get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        total_fp_loss = total_fp_loss_fun(fps)
        fp_losses.append(total_fp_loss)

        # Saving, printing.
        if do_print:
            s = ("  Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}, "
                 "Training loss {:0.5f}")
            print(s.format(batch_idx_start + 1, batch_pidx, batch_time,
                           decay_fun(batch_pidx), total_fp_loss))
        if total_fp_loss < hps['fp_opt_stop_tol']:
            do_stop = True
            if do_print:
                print('Stopping as mean training loss {:0.5f} is below '
                      'tolerance {:0.5f}.'.format(total_fp_loss,
                                                  hps['fp_opt_stop_tol']))

    optimizer_details = {'fp_losses': fp_losses}
    return fps, optimizer_details
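The hyperparameter dict consumed by `optimize_fps` above could be assembled like this; the key names are taken from the function body, but every value here is illustrative, not a recommended default:

# Illustrative fixed-point optimization HPs (values are guesses, not defaults);
# rnn_fun and fp_candidates are assumed to exist as described in the docstring.
fp_hps = {
    'step_size': 0.2,
    'decay_steps': 1000,
    'decay_factor': 0.99,
    'adam_b1': 0.9, 'adam_b2': 0.999, 'adam_eps': 1e-5,
    'num_batches': 10000,
    'opt_print_every': 200,
    'fp_opt_stop_tol': 1e-7,  # stop early once the mean loss is this small
}
fps, details = optimize_fps(rnn_fun, fp_candidates, fp_hps)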
def optimize_lfads(init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
    """Optimize the LFADS model and print batch-based optimization data.

    Arguments:
      init_params: a dict of parameters to be trained
      lfads_hps: dict of lfads model HPs
      lfads_opt_hps: dict of optimization HPs
      train_data: nexamples x time x ndims np array of data for training
      eval_data: nexamples x time x ndims np array of data for evaluation

    Returns:
      a dictionary of trained parameters"""
    batch_size = lfads_hps['batch_size']
    num_batches = lfads_opt_hps['num_batches']
    print_every = lfads_opt_hps['print_every']

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
    decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                             lfads_opt_hps['decay_steps'],
                                             lfads_opt_hps['decay_factor'])
    opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                           b1=lfads_opt_hps['adam_b1'],
                                           b2=lfads_opt_hps['adam_b2'],
                                           eps=lfads_opt_hps['adam_eps'])
    update_w_gc = get_update_w_gc_fun(init_params, opt_update)
    update_w_gc_jit = jit(update_w_gc, static_argnums=(2, 3))

    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []
    start_time = time.time()
    opt_state = opt_init(init_params)
    for bidx in range(num_batches):
        kl_warmup = kl_warmup_fun(bidx)
        didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        x_bxt = train_data[didxs].astype(onp.float32)
        key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
        opt_state = update_w_gc_jit(bidx, opt_state, lfads_hps, lfads_opt_hps,
                                    key, x_bxt, kl_warmup)

        if bidx % print_every == 0:
            params = optimizers.get_params(opt_state)

            # Training loss
            didxs = onp.random.randint(0, train_data.shape[0], batch_size)
            x_bxt = train_data[didxs].astype(onp.float32)
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            tlosses = lfads.lfads_losses_jit(params, lfads_hps, key, x_bxt,
                                             kl_warmup, 1.0)

            # Evaluation loss
            key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
            didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
            ex_bxt = eval_data[didxs].astype(onp.float32)
            # Using lfads_losses_jit here; lfads_eval_losses_jit was
            # commented out because it was freezing.
            elosses = lfads.lfads_losses_jit(params, lfads_hps, key, ex_bxt,
                                             kl_warmup, 1.0)

            # Saving, printing.
            all_tlosses.append(tlosses)
            all_elosses.append(elosses)
            batch_time = time.time() - start_time
            s = ("Batch {} in {:0.2f} sec, Step size: {:0.5f}, "
                 "Training loss {:0.0f}, Eval loss {:0.0f}")
            print(s.format(bidx, batch_time, decay_fun(bidx),
                           tlosses['total'], elosses['total']))
            start_time = time.time()

    tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
    elosses_thru_training = utils.merge_losses_dicts(all_elosses)
    optimizer_details = {'tlosses': tlosses_thru_training,
                         'elosses': elosses_thru_training}
    return optimizers.get_params(opt_state), optimizer_details
# Plot a few input/target examples to make sure things look sane.
do_plot = False
if do_plot:
    ntoplot = 10
    key, subkey = random.split(key, 2)
    skeys = random.split(subkey, ntoplot)
    inputs, targets = integrator.build_inputs_and_targets_jit(input_params,
                                                              skeys)
    plot_batch(ntimesteps, inputs, targets)

### TRAINING

# Init some parameters for training.
key, subkey = random.split(key, 2)
init_params = rnn.random_vrnn_params(subkey, u, n, o, g=param_scale)

decay_fun = optimizers.exponential_decay(step_size, decay_steps, decay_factor)
opt_init, opt_update = optimizers.adam(decay_fun, adam_b1, adam_b2, adam_eps)
opt_state = opt_init(init_params)

# Run the optimization loop; the first jit'd calls will take a minute.
start_time = time.time()
for batch in range(num_batchs):
    key, subkey = random.split(key, 2)
    skeys = random.split(subkey, batch_size)
    inputs, targets = integrator.build_inputs_and_targets_jit(input_params,
                                                              skeys)
    opt_state = rnn.update_w_gc_jit(batch, opt_state, opt_update, inputs,
                                    targets, max_grad_norm, l2reg)
    if batch % print_every == 0:
        params = optimizers.get_params(opt_state)
        train_loss = rnn.loss_jit(params, inputs, targets, l2reg)
        batch_time = time.time() - start_time
def optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
    """Optimize the LFADS model and print batch-based optimization data.

    This loop is at the cpu, non-JAX numpy level.

    Arguments:
      key: random.PRNGKey for randomness
      init_params: a dict of parameters to be trained
      lfads_hps: dict of lfads model HPs
      lfads_opt_hps: dict of optimization HPs
      train_data: nexamples x time x ndims np array of data for training
      eval_data: nexamples x time x ndims np array of data for evaluation

    Returns:
      a dictionary of trained parameters"""
    # Begin optimization loop.
    all_tlosses = []
    all_elosses = []

    # Build some functions used in optimization.
    kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
    decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                             lfads_opt_hps['decay_steps'],
                                             lfads_opt_hps['decay_factor'])

    opt_init, opt_update, get_params = optimizers.adam(
        step_size=decay_fun,
        b1=lfads_opt_hps['adam_b1'],
        b2=lfads_opt_hps['adam_b2'],
        eps=lfads_opt_hps['adam_eps'])
    opt_state = opt_init(init_params)

    def update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt,
                    kl_warmup):
        """Update fun for gradients, includes gradient clipping."""
        params = get_params(opt_state)
        grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
                                                kl_warmup,
                                                lfads_opt_hps['keep_rate'])
        clipped_grads = optimizers.clip_grads(grads,
                                              lfads_opt_hps['max_grad_norm'])
        return opt_update(i, clipped_grads, opt_state)

    # Run the optimization, pausing every so often to collect data and
    # print status.
    batch_size = lfads_hps['batch_size']
    num_batches = lfads_opt_hps['num_batches']
    print_every = lfads_opt_hps['print_every']
    num_opt_loops = int(num_batches / print_every)
    params = get_params(opt_state)
    for oidx in range(num_opt_loops):
        batch_idx_start = oidx * print_every
        start_time = time.time()
        key, tkey, dtkey, dekey = random.split(random.fold_in(key, oidx), 4)
        opt_state = optimize_lfads_core_jit(tkey, batch_idx_start,
                                            print_every, update_w_gc,
                                            kl_warmup_fun, opt_state,
                                            lfads_hps, lfads_opt_hps,
                                            train_data)
        batch_time = time.time() - start_time

        # Losses
        params = get_params(opt_state)
        batch_pidx = batch_idx_start + print_every
        kl_warmup = kl_warmup_fun(batch_idx_start)

        # Training loss
        didxs = onp.random.randint(0, train_data.shape[0], batch_size)
        x_bxt = train_data[didxs].astype(onp.float32)
        tlosses = lfads.lfads_losses_jit(params, lfads_hps, dtkey, x_bxt,
                                         kl_warmup, 1.0)

        # Evaluation loss
        didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
        ex_bxt = eval_data[didxs].astype(onp.float32)
        elosses = lfads.lfads_losses_jit(params, lfads_hps, dekey, ex_bxt,
                                         kl_warmup, 1.0)

        # Saving, printing.
        all_tlosses.append(tlosses)
        all_elosses.append(elosses)
        s = ("Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}, "
             "Training loss {:0.0f}, Eval loss {:0.0f}")
        print(s.format(batch_idx_start + 1, batch_pidx, batch_time,
                       decay_fun(batch_pidx), tlosses['total'],
                       elosses['total']))

    tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
    elosses_thru_training = utils.merge_losses_dicts(all_elosses)
    optimizer_details = {'tlosses': tlosses_thru_training,
                         'elosses': elosses_thru_training}
    return params, optimizer_details