def generate_data_01():
    batch_size = 8
    input_shape = (batch_size, 4)

    def synth_batches():
        while True:
            images = npr.rand(*input_shape).astype("float32")
            yield images

    batches = synth_batches()
    inputs = next(batches)

    init_func, predict_func = stax.serial(
        HomotopyDense(out_dim=4, W_init=glorot_uniform(), b_init=normal()),
        HomotopyDense(out_dim=1, W_init=glorot_uniform(), b_init=normal()),
        Sigmoid,
    )
    ae_shape, ae_params = init_func(random.PRNGKey(0), input_shape)
    # assert ae_shape == input_shape
    bparam = [np.array([0.0], dtype=np.float64)]
    logits = predict_func(ae_params, inputs, bparam=bparam[0], activation_func=sigmoid)
    # Placeholder loss: the residual term is identically zero, so this only
    # exercises the l2_norm regularization terms; the value is not returned.
    loss = np.mean(np.subtract(logits, logits)) + l2_norm(ae_params) + l2_norm(bparam)
    return inputs, logits, ae_params, bparam, init_func, predict_func
def objective(params, bparam) -> float:
    logits = predict_fun(params, inputs, bparam=bparam[0], activation_func=sigmoid)
    loss = np.mean(np.subtract(logits, outputs))
    loss += l2_norm(params) + l2_norm(bparam)
    return loss
def objective(params, bparam, batch) -> float:
    x, _ = batch
    x = np.reshape(x, (x.shape[0], -1))
    logits = predict_fun(params, x, bparam=bparam[0], rng=key)
    keep = random.bernoulli(key, bparam[0], x.shape)
    inputs_d = np.where(keep, x, 0)
    loss = np.mean(np.square(np.subtract(logits, inputs_d)))
    loss += 1e-8 * (l2_norm(params) + 1 / l2_norm(bparam))
    return loss
def testUtilityClipGrads(self):
    g = (np.ones(2), (np.ones(3), np.ones(4)))
    norm = optimizers.l2_norm(g)

    ans = optimizers.clip_grads(g, 1.1 * norm)
    expected = g
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
    expected = 0.9 * norm
    self.assertAllClose(ans, expected, check_dtypes=False)
def objective(params, bparam) -> float:
    def cross_entropy_loss(logits, labels):
        one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
        return -np.mean(np.sum(one_hot_labels * logits, axis=-1))

    logits = CNN().apply({"params": params[0]}, inputs)
    loss = cross_entropy_loss(logits, outputs)
    loss += l2_norm(params) + l2_norm(bparam)
    # Vectorization over a mini-batch of data: the 3rd argument's 0th axis is
    # vmapped --> inputs (10).
    # batched_predict = vmap(neural_net_predict, in_axes=(None, None, 0))
    return loss
def update(self, params, opt_state, batch: util.Transition):
    """The actual update function."""
    (_, logs), grads = jax.value_and_grad(self._loss, has_aux=True)(params, batch)
    grad_norm_unclipped = optimizers.l2_norm(grads)
    updates, updated_opt_state = self._opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    weight_norm = optimizers.l2_norm(params)
    logs.update({
        'grad_norm_unclipped': grad_norm_unclipped,
        'weight_norm': weight_norm,
    })
    return params, updated_opt_state, logs
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder.

    Arguments:
        params: a dictionary of LFADS parameters
        lfads_hps: a dictionary of LFADS hyperparameters
        key: random.PRNGKey for random bits
        x_bxt: np array of input with leading dims being batch and time
        keep_rate: dropout keep rate
        kl_scale: scale on KL

    Returns:
        a dictionary of all losses, including the key 'total' used for optimization
    """
    B = lfads_hps['batch_size']
    key, skeys = utils.keygen(key, 2)
    keys_b = random.split(next(skeys), B)
    lfads = batch_lfads(params, lfads_hps, keys_b, x_bxt, keep_rate)

    # Sum over time and state dims, average over batch.
    # KL - g0
    ic_post_mean_b = lfads['ic_mean']
    ic_post_logvar_b = lfads['ic_logvar']
    kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                              params['ic_prior'],
                                              lfads_hps['var_min'])
    kl_loss_g0_prescale = np.sum(kl_loss_g0_b) / B
    kl_loss_g0 = kl_scale * kl_loss_g0_prescale

    # KL - Inferred input
    ii_post_mean_bxt = lfads['ii_mean_t']
    ii_post_var_bxt = lfads['ii_logvar_t']
    keys_b = random.split(next(skeys), B)
    kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys_b, ii_post_mean_bxt,
                                            ii_post_var_bxt, params['ii_prior'],
                                            lfads_hps['var_min'])
    kl_loss_ii_prescale = np.sum(kl_loss_ii_b) / B
    kl_loss_ii = kl_scale * kl_loss_ii_prescale

    # Log-likelihood of data given latents.
    lograte_bxt = lfads['lograte_t']
    log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

    # L2
    l2reg = lfads_hps['l2reg']
    l2_loss = l2reg * optimizers.l2_norm(params)**2

    loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
    all_losses = {'total': loss, 'nlog_p_xgz': -log_p_xgz,
                  'kl_g0': kl_loss_g0, 'kl_g0_prescale': kl_loss_g0_prescale,
                  'kl_ii': kl_loss_ii, 'kl_ii_prescale': kl_loss_ii_prescale,
                  'l2': l2_loss}
    return all_losses
def objective(params, bparam, batch) -> float:
    x, targets = batch
    x = np.reshape(x, (x.shape[0], -1))
    logits = predict_fun(params, x, bparam=bparam[0], activation_func=relu)
    loss = -np.mean(np.sum(logits * targets, axis=1))
    loss += 5e-7 * l2_norm(params)  # + l2_norm(bparam)
    return loss
def loss(params, inputs_bxtxu, targets_bxtxo, l2reg):
    """Compute the least squares loss of the output, plus L2 regularization."""
    _, outs_bxtxo = batched_rnn_run(params, inputs_bxtxu)
    l2_loss = l2reg * optimizers.l2_norm(params)**2
    lms_loss = np.mean((outs_bxtxo - targets_bxtxo)**2)
    total_loss = lms_loss + l2_loss
    return {'total': total_loss, 'lms': lms_loss, 'l2': l2_loss}
def __init__(
    self,
    state,
    bparam,
    state_0,
    bparam_0,
    counter,
    objective,
    dual_objective,
    hparams,
):
    # states
    self._state_wrap = StateVariable(state, counter)
    self._bparam_wrap = StateVariable(
        bparam, counter
    )  # TODO: save tree_def; always unflatten before compute_grads
    self._prev_state = state_0
    self._prev_bparam = bparam_0

    # objectives
    self.objective = objective
    self.dual_objective = dual_objective
    self.value_func = jit(self.objective)

    self.hparams = hparams
    self._value_wrap = StateVariable(1.0, counter)  # TODO: fix with a static batch (test/train)
    self._quality_wrap = StateVariable(l2_norm(self._state_wrap.state), counter)

    # optimizer
    self.opt = OptimizerCreator(
        opt_string=hparams["meta"]["optimizer"],
        learning_rate=hparams["natural_lr"],
    ).get_optimizer()
    self.ascent_opt = OptimizerCreator(
        opt_string=hparams["meta"]["ascent_optimizer"],
        learning_rate=hparams["ascent_lr"],
    ).get_optimizer()

    # per-step hyperparameters
    self.continuation_steps = hparams["continuation_steps"]
    self._lagrange_multiplier = hparams["lagrange_init"]
    self._delta_s = hparams["delta_s"]
    self._omega = hparams["omega"]

    # grad functions; should be pure functional
    self.compute_min_grad_fn = jit(grad(self.dual_objective, [0, 1]))
    self.compute_max_grad_fn = jit(grad(self.dual_objective, [2]))
    self.compute_grad_fn = jit(grad(self.objective, [0]))

    # extras
    self.sw = None
    self.state_tree_def = None
    self.bparam_tree_def = None
    self.output_file = hparams["meta"]["output_dir"]
    self.prev_secant_direction = None
def loss(params, batch):
    """The idxes of the batch indicate which nodes are used to compute the loss."""
    inputs, targets, adj, is_training, rng, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training, rng=rng)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)**2  # tf doesn't use sqrt
    return ce_loss + l2_loss
def grouper(iterable, threshold=0.01):
    """Yield runs of consecutive items whose distance to the previous item stays within `threshold`."""
    prev = None
    group = []
    for item in iterable:
        # `prev is None` avoids evaluating the truthiness of an array pytree.
        if prev is None or l2_norm(pytree_sub(item, prev)) <= threshold:
            group.append(item)
        else:
            yield group
            group = [item]
        prev = item
    if group:
        yield group
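# A small usage sketch for `grouper` above (illustrative values; assumes `np`
# is jax.numpy and that `l2_norm` and `pytree_sub` behave as in the helpers in
# this collection). Consecutive pytrees that stay within `threshold` of each
# other land in the same group.
trajectory = [np.array([0.00]), np.array([0.005]), np.array([0.5])]
groups = list(grouper(trajectory, threshold=0.01))
# groups[0] -> the two nearby points, groups[1] -> the distant outlier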
def correction_step(self) -> Tuple:
    """Given the current state, optimize to the corrected state.

    Returns:
        (state: problem parameters, bparam: continuation parameter) Tuple
    """
    quality = 1.0
    ma_loss = []
    stop = False
    print("learn_rate", self.opt.lr)
    for k in range(self.warmup_period):
        for b_j in range(self.num_batches):
            batch = next(self.data_loader)
            grads = self.grad_fn(self._state, self._bparam, batch)
            self._state = self.opt.update_params(self._state, grads[0])
            quality = l2_norm(grads)
            value = self.value_fn(self._state, self._bparam, batch)
            ma_loss.append(value)
        self.opt.lr = exp_decay(k, self.hparams["natural_lr"])

        if self.hparams["local_test_measure"] == "norm_gradients":
            if quality > self.hparams["quality_thresh"]:
                print(f"quality {quality}, {self.opt.lr}, {k}")
            else:
                stop = True
                print(f"quality {quality}, stopping at {k}th step")
        else:
            if len(ma_loss) >= 20:
                tmp_means = running_mean(ma_loss, 10)
                if math.isclose(
                    tmp_means[-1],
                    tmp_means[-2],
                    abs_tol=self.hparams["loss_tol"],
                ):
                    print(f"stopping at {k}th step")
                    stop = True
        if stop:
            print("breaking")
            break

    val_loss = self.value_fn(
        self._state, self._bparam, (self.test_images, self.test_labels)
    )
    val_acc = self.accuracy_fn(
        self._state, self._bparam, (self.test_images, self.test_labels)
    )
    return self._state, self._bparam, quality, value, val_loss, val_acc
def loss(params, inputs_bxtxu, targets_bxtxo, targets_mask_t, l2reg):
    """Compute the least squares loss of the output, plus L2 regularization.

    Args:
        params: dict of RNN parameters
        inputs_bxtxu: np array of inputs, batch x time x input dim
        targets_bxtxo: np array of targets, batch x time x output dim
        targets_mask_t: list of time indices where the target is active
        l2reg: float, hyperparameter controlling strength of L2 regularization

    Returns:
        dict of losses
    """
    _, outs_bxtxo = batched_rnn_run(params, inputs_bxtxu)
    l2_loss = l2reg * optimizers.l2_norm(params)**2
    outs_bxsxo = outs_bxtxo[:, targets_mask_t, :]
    targets_bxsxo = targets_bxtxo[:, targets_mask_t, :]
    lms_loss = np.mean((outs_bxsxo - targets_bxsxo)**2)
    total_loss = lms_loss + l2_loss
    return {'total': total_loss, 'lms': lms_loss, 'l2': l2_loss}
def loss_fn(params, data, rng, batch_size, ic_prior, var_min, kl_scale, l2reg):
    """Compute the LFADS training loss: KL + negative log-likelihood + L2.

    :param params: dictionary of LFADS parameters
    :param data: np array of input data, leading dim is batch
    :param rng: random.PRNGKey for random bits
    :param batch_size: number of trials in the batch
    :param ic_prior: initial-condition prior parameters
    :param var_min: minimum variance floor
    :param kl_scale: scale on the KL term
    :param l2reg: L2 regularization strength
    :return: scalar total loss
    """
    keys = random.split(rng, batch_size)

    # Run the data through the model!
    result = lfads_batch(params, keys, data)

    # Get KL loss.
    kl_loss_g0 = dists.batch_kl_gauss_gauss(result['ic_post_mean'],
                                            result['ic_post_logvar'],
                                            ic_prior, var_min)
    kl_loss_g0 = np.sum(kl_loss_g0) / batch_size
    kl_loss_g0 = kl_scale * kl_loss_g0

    # Log-likelihood of data given neuron_log_rates.
    log_p_xgz = np.sum(
        dists.poisson_log_likelihood(data, result['neuron_log_rates'])) / batch_size

    # L2
    l2_loss = l2reg * optimizers.l2_norm(params)**2

    total_loss = -log_p_xgz + kl_loss_g0 + l2_loss
    return total_loss
def ssvm_loss(params, x, y, lamb=0.01, max_steps=80, step_size=0.1,
              pretrain_global_energy=False):
    prediction = y is None

    x_hat = compute_feature_energy(params, x)
    if pretrain_global_energy:
        x_hat = lax.stop_gradient(x_hat)

    grad_fun = inference_step if prediction else cost_augmented_inference_step

    opt_init, opt_update, get_params = momentum(0.01, 0.95)
    # opt_state = opt_init(np.full(x.shape[:-1] + (LABELS,), 1. / LABELS))
    opt_state = opt_init(np.zeros(x.shape[:-1] + (LABELS,)))

    prev_energy = None
    for step in range(max_steps):
        y_hat = project(get_params(opt_state))
        g, energy = grad_fun(y_hat, y, x_hat, params)
        opt_state = opt_update(step, g, opt_state)
        if step > 0 and check_saddle_point(step, get_params(opt_state), y_hat,
                                           energy, prev_energy):
            break
        prev_energy = energy

    y_hat = lax.stop_gradient(project(get_params(opt_state)))
    if prediction:
        return y_hat

    y = lax.stop_gradient(y)
    pred_energy = compute_global_energy(params, x_hat, y_hat)
    true_energy = compute_global_energy(params, x_hat, y)
    delta = np.square(y_hat - y).sum(axis=1)
    loss = np.mean(np.maximum(delta + true_energy - pred_energy, 0))
    return loss + lamb * l2_norm(params)
def l2(params):
    emb_params, rnn_params, readout_params = params
    return l2_penalty * jnp.power(optimizers.l2_norm(rnn_params), 2)
def pretrain_loss(params, batch, l2=0.05):
    _, lr_params = params
    X, y = batch
    y_hat = predict(params, X)
    return -np.mean(np.sum(y * y_hat, axis=1)) + (l2 / 2) * l2_norm(lr_params)**2.
print(f"Epoch {epoch}, seed={seed}") train_loader = tfds.as_numpy( train_data.shuffle(n_train, seed=seed).batch(args.batch_size_train)) X, Y = None, None for i, batch in enumerate(train_loader): batch = process_data(batch, flatten=args.flatten, centralize_y=args.centralize_y, split='train' if args.data_augment else 'test', seed=seed * args.seed_separator + i, num_classes=args.num_classes) X, Y = batch['image'], batch['label'] params_curr = get_params(state) loss_curr, grad_curr = value_and_grad_loss(params_curr, X, Y) # monitor gradient norm grad_norm = optimizers.l2_norm(grad_curr) writer.add_scalar(f'grad_norm/{tb_flag}', grad_norm.item(), global_step) if np.isnan(loss_curr): sys.exit() running_loss += loss_curr running_count += 1 if args.grad_norm_thresh > 0: grad_curr = optimizers.clip_grads(grad_curr, args.grad_norm_thresh) state = opt_apply(epoch, grad_curr, state) global_step += 1 print( f"Step {global_step}, training loss={loss_curr:.4f}, grad norm={grad_norm:.4f}" ) # Evaluate on the test set
def projection(tree, max_norm=1.):
    """Clip gradients stored as a tree of arrays to maximum norm `max_norm`."""
    norm = l2_norm(tree)
    normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
    return tree_map(normalize, tree)
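# A minimal usage sketch of `projection` above (illustrative leaf shapes;
# assumes `np` is jax.numpy and `l2_norm` is
# jax.example_libraries.optimizers.l2_norm, as in the snippet).
grads = (np.ones((3, 4)), {"b": np.ones(4)})  # hypothetical gradient pytree
clipped = projection(grads, max_norm=1.0)
# The global norm of `grads` is 4.0, so every leaf is scaled by 1/4 and the
# clipped tree has global norm 1.0; trees already inside the ball pass through.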
def testUtilityNorm(self):
    x0 = (np.ones(2), (np.ones(3), np.ones(4)))
    norm = optimizers.l2_norm(x0)
    expected = onp.sqrt(onp.sum(onp.ones(2 + 3 + 4)**2))
    self.assertAllClose(norm, expected, check_dtypes=False)
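# The test above pins down the semantics exercised throughout these snippets:
# l2_norm treats an entire pytree as one flat vector. A minimal sketch of that
# behavior (not the library implementation; names are illustrative).
import jax.numpy as jnp
from jax import tree_util

def global_l2_norm(tree):
    """Square root of the sum of squares over every leaf of a pytree."""
    leaves = tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.vdot(leaf, leaf) for leaf in leaves))

# global_l2_norm((jnp.ones(2), (jnp.ones(3), jnp.ones(4)))) == 3.0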
opt_state = opt_init(params)
temp, rng = random.split(rng)
batches = data_stream(temp, batch_size, X, y)
for i in tqdm(range(iterations)):
    opt_state = update(i, opt_state, next(batches))
fe_params, lr_params = get_params(opt_state)

print('Accuracy (train): {:.4f}'.format(
    accuracy((fe_params, lr_params), predict, X, y)))
print('Accuracy (test): {:.4f}'.format(
    accuracy((fe_params, lr_params), predict, X_test, y_test)))
print(lr_params)
print('L2 norm: {:.4f}'.format(l2_norm(lr_params)))

# Extract features
X_train_proj = onp.asarray(feature_extractor(fe_params, X))
y_train_proj = onp.asarray(y)
X_test_proj = onp.asarray(feature_extractor(fe_params, X_test))
y_test_proj = onp.asarray(y_test)

n = str(X.shape[0])
dim = str(X_train_proj.shape[1])
directory = 'n={}_d={}'.format(n, dim)
if os.path.exists(directory):
    shutil.rmtree(directory)
os.makedirs(directory)

# Dump training and eval script
def loss(params, batch, l2=0.05):
    X, y = batch
    y_hat = predict(params, X).reshape(-1)
    return (-np.mean(np.log(y * y_hat + (1. - y) * (1. - y_hat)))
            + (l2 / 2) * l2_norm(params[1])**2.)
def pytree_normalized(x):
    return tree_util.tree_map(lambda a: a / l2_norm(x), x)
def loss(W, b):
    logits = predict(W, b, inputs)
    preds = logits - logsumexp(logits, axis=1, keepdims=True)
    loss = -jnp.mean(jnp.sum(preds * targets, axis=1))
    loss += 0.001 * (l2_norm(W) + l2_norm(b))
    return loss
def pytree_relative_error(x, y):
    partial_error = tree_util.tree_multimap(
        lambda a, b: l2_norm(pytree_sub(a, b)) / (l2_norm(a) + 1e-5), x, y)
    return tree_util.tree_reduce(lax.add, partial_error)
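# Hypothetical usage: leafwise relative change between two parameter pytrees,
# summed over leaves (illustrative values; assumes `pytree_sub` subtracts
# pytrees leafwise, as used by the continuation code in this collection).
old_params = {"w": np.ones(3), "b": np.zeros(2)}
new_params = {"w": 1.1 * np.ones(3), "b": np.zeros(2)}
rel_err = pytree_relative_error(old_params, new_params)
# ~0.1: the "w" leaf moved by 10% of its norm, the "b" leaf did not move.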
def objective_grad(self, params, bparam):
    # TODO: JIT?
    grad_J = grad(self.objective, [0, 1])
    params_grad, bparam_grad = grad_J(params, bparam)
    result = l2_norm(params_grad) + l2_norm(bparam_grad)
    return result
def cross_entropy_loss(params, x, y, lamb=0.001):
    neglogprob = -np.mean(sigmoid_cross_entropy(-apply_mlp(params, x), y))
    return neglogprob + lamb * l2_norm(params)
def run(self):
    """Runs the continuation strategy.

    A continuation strategy that defines how predictor and corrector components
    of the algorithm interact with the states of the mathematical system.
    """
    self.sw = StateWriter(f"{self.output_file}/version_{self.key_state}.json")

    for i in range(self.continuation_steps):
        print(self._value_wrap.get_record(), self._bparam_wrap.get_record())
        self._state_wrap.counter = i
        self._bparam_wrap.counter = i
        self._value_wrap.counter = i
        self.sw.write(
            [
                self._state_wrap.get_record(),
                self._bparam_wrap.get_record(),
                self._value_wrap.get_record(),
            ]
        )

        concat_states = [
            (self._state_wrap.state, self._bparam_wrap.state),
            (self._prev_state, self._prev_bparam),
            self.prev_secant_direction,
        ]

        predictor = SecantPredictor(
            concat_states=concat_states,
            delta_s=self._delta_s,
            omega=self._omega,
            net_spacing_param=self.hparams["net_spacing_param"],
            net_spacing_bparam=self.hparams["net_spacing_bparam"],
            hparams=self.hparams,
        )
        predictor.prediction_step()

        self.prev_secant_direction = predictor.secant_direction
        self.hparams["sphere_radius"] = (
            0.005 * self.hparams["omega"] * l2_norm(predictor.secant_direction)
        )

        concat_states = [
            predictor.state,
            predictor.bparam,
            predictor.secant_direction,
            predictor.get_secant_concat(),
        ]
        del predictor
        gc.collect()

        corrector = PerturbedCorrecter(
            optimizer=self.opt,
            objective=self.objective,
            dual_objective=self.dual_objective,
            lagrange_multiplier=self._lagrange_multiplier,
            concat_states=concat_states,
            delta_s=self._delta_s,
            ascent_opt=self.ascent_opt,
            key_state=self.key_state,
            compute_min_grad_fn=self.compute_min_grad_fn,
            compute_max_grad_fn=self.compute_max_grad_fn,
            compute_grad_fn=self.compute_grad_fn,
            hparams=self.hparams,
            pred_state=[self._state_wrap.state, self._bparam_wrap.state],
            pred_prev_state=[self._state_wrap.state, self._bparam_wrap.state],
            counter=self.continuation_steps,
        )
        self._prev_state = copy.deepcopy(self._state_wrap.state)
        self._prev_bparam = copy.deepcopy(self._bparam_wrap.state)

        state, bparam, quality = corrector.correction_step()
        value = self.value_func(state, bparam)
        print(
            "How far ....",
            pytree_relative_error(self._bparam_wrap.state, bparam),
        )
        self._state_wrap.state = state
        self._bparam_wrap.state = bparam
        self._value_wrap.state = value
        del corrector
        gc.collect()
def losses(params, hps, key, x_bxt, class_id_b, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder.

    Args:
        params: a dictionary of LFADS parameters
        hps: a dictionary of LFADS hyperparameters
        key: random.PRNGKey for random bits
        x_bxt: np array of input with leading dims being batch and time
        class_id_b: class ids, np array of integers for what classes x_bxt are in
        kl_scale: scale on KL
        keep_rate: dropout keep rate

    Returns:
        a dictionary of all losses, including the key 'total' used for optimization
    """
    B = hps['batch_size']
    I = hps['ii_dim']
    T = hps['ntimesteps']

    keys = random.split(key, 2)
    keys_b = random.split(keys[0], B)
    use_mean = False
    lfads = batch_forward_pass(params, hps, keys_b, x_bxt, class_id_b,
                               keep_rate, use_mean)

    post_mean_bxz = batch_compose_latent(hps, lfads['ib_post_mean'],
                                         lfads['ic_post_mean'],
                                         lfads['ii_post_mean_t'])
    post_logvar_bxz = batch_compose_latent(hps, lfads['ib_post_logvar'],
                                           lfads['ic_post_logvar'],
                                           lfads['ii_post_logvar_t'])
    prior_mean_bxz = params['prior']['means'][class_id_b]
    prior_logvar_bxz = params['prior']['logvars'][class_id_b]

    # Sum over time and state dims, average over batch.
    # KL - g0
    kl_loss_bxz = dists.batch_kl_gauss_b_gauss_b(post_mean_bxz, post_logvar_bxz,
                                                 prior_mean_bxz, prior_logvar_bxz,
                                                 hps['var_min'])
    kl_loss_prescale = np.mean(np.sum(kl_loss_bxz, axis=1))
    kl_loss = kl_scale * kl_loss_prescale

    # Log-likelihood of data given latents.
    if hps['train_on'] == 'spike_counts':
        spikes = x_bxt
        log_p_xgz = (np.sum(dists.poisson_log_likelihood(spikes,
                                                         lfads['lograte_t']))
                     / float(B))
    elif hps['train_on'] == 'continuous':
        continuous = x_bxt
        mean = lfads['lograte_t']
        logvar = np.zeros(mean.shape)  # TODO(sussillo): hyperparameter
        log_p_xgz = (np.sum(dists.diag_gaussian_log_likelihood(continuous,
                                                               mean, logvar))
                     / float(B))
    else:
        raise NotImplementedError

    # Implements the idea that inputs to the generator should be minimal, in the
    # sense of attempting to interpret the inferred inputs as actual inputs to a
    # recurrent system, under an assumption of minimal intervention to that
    # system.
    _, _, ii_post_mean_bxtxi = batch_decompose_latent(hps, post_mean_bxz)
    _, _, ii_prior_mean_bxtxi = batch_decompose_latent(hps, prior_mean_bxz)
    ii_l2_loss = hps['ii_l2_reg'] * (np.sum(ii_prior_mean_bxtxi**2) / float(B) +
                                     np.sum(ii_post_mean_bxtxi**2) / float(B))

    # Implements the idea that the average inferred input should be zero.
    if ii_post_mean_bxtxi.shape[2] > 0:
        ii_tavg_loss = (hps['ii_tavg_reg'] *
                        (np.mean(np.mean(ii_prior_mean_bxtxi, axis=1)**2) +
                         np.mean(np.mean(ii_post_mean_bxtxi, axis=1)**2)))
    else:
        ii_tavg_loss = 0.0

    # L2 - TODO(sussillo): exclusion method is not general to pytrees
    l2reg = hps['l2reg']
    l2_ignore = ['prior']
    l2_params = [p for k, p in params.items() if k not in l2_ignore]
    l2_loss = l2reg * optimizers.l2_norm(l2_params)**2

    loss = -log_p_xgz + kl_loss + l2_loss + ii_l2_loss + ii_tavg_loss
    all_losses = {'total': loss, 'nlog_p_xgz': -log_p_xgz, 'kl': kl_loss,
                  'kl_prescale': kl_loss_prescale, 'ii_l2': ii_l2_loss,
                  'ii_tavg': ii_tavg_loss, 'l2': l2_loss}
    return all_losses