def gru_params(key, n, u, ifactor=1.0, hfactor=1.0, hscale=0.0):
  """Generate GRU parameters.

  Arguments:
    key: random.PRNGKey for random bits
    n: hidden state size
    u: input size
    ifactor: scaling factor for input weights
    hfactor: scaling factor for hidden -> hidden weights
    hscale: scale on h0 initial condition

  Returns:
    a dictionary of parameters
  """
  key, skeys = utils.keygen(key, 5)

  ifactor = ifactor / np.sqrt(u)
  hfactor = hfactor / np.sqrt(n)

  wRUH = random.normal(next(skeys), (n + n, n)) * hfactor
  wRUX = random.normal(next(skeys), (n + n, u)) * ifactor
  wRUHX = np.concatenate([wRUH, wRUX], axis=1)

  wCH = random.normal(next(skeys), (n, n)) * hfactor
  wCX = random.normal(next(skeys), (n, u)) * ifactor
  wCHX = np.concatenate([wCH, wCX], axis=1)

  return {'h0': random.normal(next(skeys), (n,)) * hscale,
          'wRUHX': wRUHX,
          'wCHX': wCHX,
          'bRU': np.zeros((n + n,)),
          'bC': np.zeros((n,))}
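# The `gru` step function used throughout this section is not defined here.
# Below is a minimal sketch of a GRU update consistent with the parameter
# shapes produced by gru_params (wRUHX: (2n, n+u), wCHX: (n, n+u)); the
# actual implementation may differ in details such as a forget-gate bias.
def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru(params, h, x):
  """Sketch: one GRU update; h is the (n,) hidden state, x the (u,) input."""
  hx = np.concatenate([h, x], axis=0)
  # Reset and update gates from the concatenated [hidden, input] vector.
  ru = sigmoid(np.dot(params['wRUHX'], hx) + params['bRU'])
  r, u = np.split(ru, 2, axis=0)
  # Candidate state uses the reset-gated hidden state.
  rhx = np.concatenate([r * h, x], axis=0)
  c = np.tanh(np.dot(params['wCHX'], rhx) + params['bC'])
  return u * h + (1.0 - u) * c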
def lfads_encode(params, lfads_hps, key, x_t, keep_rate):
  """Run the LFADS network from input to generator initial condition vars.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array input for lfads with leading dimension being time
    keep_rate: dropout keep rate

  Returns:
    3-tuple of np arrays: generator initial condition mean, log variance,
    and bidirectional encoding of x_t, with leading dim being time
  """
  key, skeys = utils.keygen(key, 3)

  # Encode the input.
  x_t = dropout(x_t, next(skeys), keep_rate)
  con_ins_t, gen_pre_ics = run_bidirectional_rnn(params['ic_enc'], gru, gru,
                                                 x_t)

  # Push through to posterior mean and variance for initial conditions.
  xenc_t = dropout(con_ins_t, next(skeys), keep_rate)
  gen_pre_ics = dropout(gen_pre_ics, next(skeys), keep_rate)
  ic_gauss_params = affine(params['gen_ic'], gen_pre_ics)
  ic_mean, ic_logvar = np.split(ic_gauss_params, 2, axis=0)
  return ic_mean, ic_logvar, xenc_t
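# run_bidirectional_rnn is used above but not defined in this section. A
# minimal sketch consistent with its usage, where the inner helper `run_rnn`
# is a hypothetical scan of a step function over time: it returns the
# per-step concatenated forward/backward encodings and the concatenated
# final state of each pass.
def run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t):
  """Sketch: run fwd and bwd RNNs over x_t (time is the leading dim)."""
  def run_rnn(rnn, rnn_params, x_t, h0):
    # Assumed helper: scan the RNN step over time, collecting hidden states.
    def step(h, x):
      h = rnn(rnn_params, h, x)
      return h, h
    _, h_t = lax.scan(step, h0, x_t)
    return h_t
  fwd_enc_t = run_rnn(fwd_rnn, params['fwd_rnn'], x_t, params['fwd_rnn']['h0'])
  bwd_enc_t = np.flipud(run_rnn(bwd_rnn, params['bwd_rnn'], np.flipud(x_t),
                                params['bwd_rnn']['h0']))
  full_enc_t = np.concatenate([fwd_enc_t, bwd_enc_t], axis=1)
  enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=0)
  return full_enc_t, enc_ends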
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
  """Compute the training loss of the LFADS autoencoder.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    kl_scale: scale on KL
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for
    optimization
  """
  B = lfads_hps['batch_size']
  key, skeys = utils.keygen(key, 2)
  keys_b = random.split(next(skeys), B)
  lfads = batch_lfads(params, lfads_hps, keys_b, x_bxt, keep_rate)

  # Sum over time and state dims, average over batch.
  # KL - g0
  ic_post_mean_b = lfads['ic_mean']
  ic_post_logvar_b = lfads['ic_logvar']
  kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                            params['ic_prior'],
                                            lfads_hps['var_min'])
  kl_loss_g0_prescale = np.sum(kl_loss_g0_b) / B
  kl_loss_g0 = kl_scale * kl_loss_g0_prescale

  # KL - Inferred input
  ii_post_mean_bxt = lfads['ii_mean_t']
  ii_post_logvar_bxt = lfads['ii_logvar_t']
  keys_b = random.split(next(skeys), B)
  kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys_b, ii_post_mean_bxt,
                                          ii_post_logvar_bxt,
                                          params['ii_prior'],
                                          lfads_hps['var_min'])
  kl_loss_ii_prescale = np.sum(kl_loss_ii_b) / B
  kl_loss_ii = kl_scale * kl_loss_ii_prescale

  # Log-likelihood of data given latents.
  lograte_bxt = lfads['lograte_t']
  log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

  # L2
  l2reg = lfads_hps['l2reg']
  l2_loss = l2reg * optimizers.l2_norm(params)**2

  loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
  all_losses = {'total': loss, 'nlog_p_xgz': -log_p_xgz,
                'kl_g0': kl_loss_g0, 'kl_g0_prescale': kl_loss_g0_prescale,
                'kl_ii': kl_loss_ii, 'kl_ii_prescale': kl_loss_ii_prescale,
                'l2': l2_loss}
  return all_losses
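# batch_lfads, used above, is presumably the single-example lfads function
# vectorized over the batch with jax.vmap (assuming `from jax import vmap`):
# parameters, hyperparameters, and keep_rate are held fixed, while the keys
# and the data are mapped over their leading (batch) axis.
batch_lfads = vmap(lfads, in_axes=(None, None, 0, 0, None))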
def optimize_lfads_core(key, batch_idx_start, num_batches, update_fun,
                        kl_warmup_fun, opt_state, lfads_hps, lfads_opt_hps,
                        train_data):
  """Make gradient updates to the LFADS model.

  Uses lax.fori_loop instead of a Python loop to reduce JAX overhead. This
  loop will be jit'd and run on device.

  Arguments:
    key: random.PRNGKey for random bits
    batch_idx_start: where we are in the total number of batches
    num_batches: how many batches to run
    update_fun: the function that changes params based on grad of loss
    kl_warmup_fun: function to compute the kl warmup
    opt_state: the jax optimizer state, containing params and opt state
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training

  Returns:
    opt_state: the jax optimizer state, containing params and optimizer state
  """
  key, dkeyg = utils.keygen(key, num_batches)  # data
  key, fkeyg = utils.keygen(key, num_batches)  # forward pass

  # Begin optimization loop. Explicitly avoiding a python for-loop so that
  # jax will not trace it for the sake of a gradient we will not use.
  def run_update(batch_idx, opt_state):
    kl_warmup = kl_warmup_fun(batch_idx)
    didxs = random.randint(next(dkeyg), [lfads_hps['batch_size']], 0,
                           train_data.shape[0])
    x_bxt = train_data[didxs].astype(np.float32)
    opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
                           next(fkeyg), x_bxt, kl_warmup)
    return opt_state

  lower = batch_idx_start
  upper = batch_idx_start + num_batches
  return lax.fori_loop(lower, upper, run_update, opt_state)
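# kl_warmup_fun is supplied by the caller. A sketch of a typical choice, a
# linear ramp of the KL scale over a window of batches; the hyperparameter
# names kl_warmup_start/end and kl_min/max are assumptions, not taken from
# this section.
def get_kl_warmup_fun(lfads_opt_hps):
  """Sketch: return a linear KL warmup as a function of the batch index."""
  kl_warmup_start = lfads_opt_hps['kl_warmup_start']
  kl_warmup_end = lfads_opt_hps['kl_warmup_end']
  kl_min = lfads_opt_hps['kl_min']
  kl_max = lfads_opt_hps['kl_max']
  def kl_warmup(batch_idx):
    progress = ((batch_idx - kl_warmup_start) /
                (kl_warmup_end - kl_warmup_start))
    # Clip so the scale is kl_min before the window and kl_max after it.
    return np.clip(kl_min + (kl_max - kl_min) * progress, kl_min, kl_max)
  return kl_warmup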
def lfads_params(key, lfads_hps):
  """Instantiate random LFADS parameters.

  Arguments:
    key: random.PRNGKey for random bits
    lfads_hps: a dict of LFADS hyperparameters

  Returns:
    a dictionary of LFADS parameters
  """
  key, skeys = utils.keygen(key, 10)

  data_dim = lfads_hps['data_dim']
  ntimesteps = lfads_hps['ntimesteps']
  enc_dim = lfads_hps['enc_dim']
  con_dim = lfads_hps['con_dim']
  ii_dim = lfads_hps['ii_dim']
  gen_dim = lfads_hps['gen_dim']
  factors_dim = lfads_hps['factors_dim']

  ic_enc_params = {'fwd_rnn': gru_params(next(skeys), enc_dim, data_dim),
                   'bwd_rnn': gru_params(next(skeys), enc_dim, data_dim)}
  gen_ic_params = affine_params(next(skeys), 2 * gen_dim,
                                2 * enc_dim)  # m, v <- bi
  ic_prior_params = dists.diagonal_gaussian_params(next(skeys), gen_dim, 0.0,
                                                   lfads_hps['ic_prior_var'])
  con_params = gru_params(next(skeys), con_dim, 2 * enc_dim + factors_dim)
  con_out_params = affine_params(next(skeys), 2 * ii_dim, con_dim)  # m, v
  ii_prior_params = dists.ar1_params(next(skeys), ii_dim,
                                     lfads_hps['ar_mean'],
                                     lfads_hps['ar_autocorrelation_tau'],
                                     lfads_hps['ar_noise_variance'])
  gen_params = gru_params(next(skeys), gen_dim, ii_dim)
  factors_params = linear_params(next(skeys), factors_dim, gen_dim)
  lograte_params = affine_params(next(skeys), data_dim, factors_dim)

  return {'ic_enc': ic_enc_params,
          'gen_ic': gen_ic_params,
          'ic_prior': ic_prior_params,
          'con': con_params,
          'con_out': con_out_params,
          'ii_prior': ii_prior_params,
          'gen': gen_params,
          'factors': factors_params,
          'logrates': lograte_params}
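# Example usage (dimensions are illustrative, not from the source): build the
# hyperparameter dict that lfads_params reads and instantiate parameters.
# The keys below are exactly those read by lfads_params, plus 'var_min',
# which other functions in this section require; the loss additionally needs
# 'batch_size' and 'l2reg'.
lfads_hps_example = {
    'data_dim': 100, 'ntimesteps': 25, 'enc_dim': 64, 'con_dim': 64,
    'ii_dim': 1, 'gen_dim': 75, 'factors_dim': 20,
    'ic_prior_var': 0.1, 'ar_mean': 0.0, 'ar_autocorrelation_tau': 1.0,
    'ar_noise_variance': 0.1, 'var_min': 1e-16,
}
params_example = lfads_params(random.PRNGKey(0), lfads_hps_example)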
def linear_params(key, o, u, ifactor=1.0):
  """Params for y = w x.

  Arguments:
    key: random.PRNGKey for random bits
    o: output size
    u: input size
    ifactor: scaling factor

  Returns:
    a dictionary of parameters
  """
  key, skeys = utils.keygen(key, 1)
  ifactor = ifactor / np.sqrt(u)
  return {'w': random.normal(next(skeys), (o, u)) * ifactor}
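# affine_params, affine, and normed_linear are used elsewhere in this section
# but not defined here. Minimal sketches consistent with their usage; the
# row normalization in normed_linear is an assumption (LFADS conventionally
# reads factors out through a unit-row-norm linear map).
def affine_params(key, o, u, ifactor=1.0):
  """Sketch: params for y = w x + b."""
  key, skeys = utils.keygen(key, 1)
  ifactor = ifactor / np.sqrt(u)
  return {'w': random.normal(next(skeys), (o, u)) * ifactor,
          'b': np.zeros((o,))}

def affine(params, x):
  """Sketch: affine map y = w x + b."""
  return np.dot(params['w'], x) + params['b']

def normed_linear(params, x):
  """Sketch: linear map with rows of w normalized to unit length."""
  w = params['w']
  w = w / np.sqrt(np.sum(w**2, axis=1, keepdims=True))
  return np.dot(w, x)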
def kl_gauss_ar1(key, z_mean_t, z_logvar_t, ar1_params, varmin=1e-16):
  """KL using samples for multi-dim gaussian (thru time) and AR(1) process.

  To estimate KL(q||p), we sample ln q - ln p by drawing samples from q.
  q is a multidim gaussian, p is an AR(1) process.

  Arguments:
    key: random.PRNGKey for random bits
    z_mean_t: np.array of means with leading dim being time
    z_logvar_t: np.array of log vars, leading dim is time
    ar1_params: dictionary of ar1 parameters, log noise var and autocorr tau
    varmin: minimal variance, useful for numerical stability

  Returns:
    sampled KL divergence between the gaussian posterior and the AR(1) prior
  """
  ll = diag_gaussian_log_likelihood
  sample = diag_gaussian_sample
  nkeys = z_mean_t.shape[0]
  key, skeys = utils.keygen(key, nkeys)

  # Convert AR(1) parameters.
  #   z_t = c + phi z_{t-1} + eps, eps \in N(0, noise var)
  ar1_mean = ar1_params['mean']
  ar1_lognoisevar = np.log(np.exp(ar1_params['lognvar']) + varmin)
  phi = np.exp(-np.exp(-ar1_params['logatau']))
  # The process variance is a function of the noise variance, so varmin is
  # added above. This affects the log-likelihood functions below, also.
  logprocessvar = ar1_lognoisevar - (np.log(1 - phi) + np.log(1 + phi))

  # Sample the first AR(1) step according to the process variance.
  z0 = sample(next(skeys), z_mean_t[0], z_logvar_t[0], varmin)
  logq = ll(z0, z_mean_t[0], z_logvar_t[0], varmin)
  logp = ll(z0, ar1_mean, logprocessvar, 0.0)
  z_last = z0

  # Sample the remaining time steps with adjusted mean and noise variance.
  for z_mean, z_logvar in zip(z_mean_t[1:], z_logvar_t[1:]):
    z = sample(next(skeys), z_mean, z_logvar, varmin)
    logq += ll(z, z_mean, z_logvar, varmin)
    logp += ll(z, ar1_mean + phi * z_last, ar1_lognoisevar, 0.0)
    z_last = z

  kl = logq - logp
  return kl
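# diag_gaussian_sample and diag_gaussian_log_likelihood are assumed to live
# in dists; sketches consistent with their use here, where varmin is added
# to the variance for numerical stability.
def diag_gaussian_sample(key, mean, logvar, varmin):
  """Sketch: reparameterized sample from N(mean, exp(logvar) + varmin)."""
  eps = random.normal(key, mean.shape)
  return mean + np.sqrt(np.exp(logvar) + varmin) * eps

def diag_gaussian_log_likelihood(z, mean, logvar, varmin):
  """Sketch: log N(z; mean, exp(logvar) + varmin), summed over dims."""
  logvar_wm = np.log(np.exp(logvar) + varmin)
  return np.sum(-0.5 * (np.log(2.0 * np.pi) + logvar_wm
                        + (z - mean)**2 / np.exp(logvar_wm)))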
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
  """Run the LFADS network from latent variables to log rates.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    ic_mean: np array of generator initial condition mean
    ic_logvar: np array of generator initial condition log variance
    xenc_t: np array bidirectional encoding of input (x_t) with leading dim
      being time
    keep_rate: dropout keep rate

  Returns:
    7-tuple of np arrays, all with leading dim being time: controller hidden
    state, generator hidden state, factors, inferred input, inferred input
    mean, inferred input log var, and log rates
  """
  key, skeys = utils.keygen(key, 2)

  # Since the factors feed back to the controller,
  #   factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
  # is really one big loop and therefore one RNN.
  c0 = params['con']['h0']
  g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                  lfads_hps['var_min'])
  f0 = np.zeros((lfads_hps['factors_dim'],))

  # Make all the randomness for all T steps at once; it's more efficient.
  # The random keys get passed into scan along with the input, so the input
  # becomes a 2-tuple (keys, actual input).
  T = xenc_t.shape[0]
  keys_t = random.split(next(skeys), T)
  state0 = (c0, g0, f0)
  decoder = partial(lfads_decode_one_step_scan,
                    *(params, lfads_hps, keep_rate))
  _, state_and_returns_t = lax.scan(decoder, state0, (keys_t, xenc_t))
  return state_and_returns_t
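# lfads_decode_one_step_scan is not defined in this section. A sketch derived
# from the loop body of the python-loop lfads_decode below: lax.scan needs
# f(state, input) -> (state, output), with the input here a (key, xenc) pair.
def lfads_decode_one_step_scan(params, lfads_hps, keep_rate, state,
                               key_n_xenc):
  """Sketch: one decoder step in the form lax.scan expects."""
  key, xenc = key_n_xenc
  keys = random.split(key, 2)
  c, g, f = state
  # Controller sees the encoding and the previous factors.
  cin = np.concatenate([xenc, f], axis=0)
  c = gru(params['con'], c, cin)
  cout = affine(params['con_out'], c)
  ii_mean, ii_logvar = np.split(cout, 2, axis=0)
  ii = dists.diag_gaussian_sample(keys[0], ii_mean, ii_logvar,
                                  lfads_hps['var_min'])
  # Generator consumes the sampled inferred input; factors read out from it.
  g = gru(params['gen'], g, ii)
  g = dropout(g, keys[1], keep_rate)
  f = normed_linear(params['factors'], g)
  lograte = affine(params['logrates'], f)
  return (c, g, f), (c, g, f, ii, ii_mean, ii_logvar, lograte)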
def lfads(params, lfads_hps, key, x_t, keep_rate):
  """Run the LFADS network from input to output.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array of input with leading dim being time
    keep_rate: dropout keep rate

  Returns:
    A dictionary of np arrays of all LFADS values of interest.
  """
  key, skeys = utils.keygen(key, 2)

  ic_mean, ic_logvar, xenc_t = \
      lfads_encode(params, lfads_hps, next(skeys), x_t, keep_rate)

  c_t, gen_t, factor_t, ii_t, ii_mean_t, ii_logvar_t, lograte_t = \
      lfads_decode(params, lfads_hps, next(skeys), ic_mean, ic_logvar,
                   xenc_t, keep_rate)

  # As this is tutorial code, we're passing everything around.
  return {'xenc_t': xenc_t, 'ic_mean': ic_mean, 'ic_logvar': ic_logvar,
          'ii_t': ii_t, 'c_t': c_t, 'ii_mean_t': ii_mean_t,
          'ii_logvar_t': ii_logvar_t, 'gen_t': gen_t, 'factor_t': factor_t,
          'lograte_t': lograte_t}
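# Example single-trial forward pass, reusing the illustrative hps and params
# from the lfads_params example above (shapes are illustrative); in training,
# batch_lfads vmaps this over a batch of trials and keys.
key = random.PRNGKey(42)
x_t_example = np.ones((lfads_hps_example['ntimesteps'],
                       lfads_hps_example['data_dim']))
out = lfads(params_example, lfads_hps_example, key, x_t_example,
            keep_rate=0.98)
print(out['lograte_t'].shape)  # (ntimesteps, data_dim)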
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
  """Run the LFADS network from latent variables to log rates.

  NOTE: a python-loop version of the scan-based lfads_decode above; the
  return order matches the scan version.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    ic_mean: np array of generator initial condition mean
    ic_logvar: np array of generator initial condition log variance
    xenc_t: np array bidirectional encoding of input (x_t) with leading dim
      being time
    keep_rate: dropout keep rate

  Returns:
    7-tuple of np arrays, all with leading dim being time: controller hidden
    state, generator hidden state, factors, inferred input, inferred input
    mean, inferred input log var, and log rates
  """
  ntime = lfads_hps['ntimesteps']
  key, skeys = utils.keygen(key, 1 + 2 * ntime)

  # Since the factors feed back to the controller,
  #   factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
  # is really one big loop and therefore one RNN.
  c = c0 = params['con']['h0']
  g = g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                      lfads_hps['var_min'])
  f = f0 = np.zeros((lfads_hps['factors_dim'],))

  c_t = []
  ii_mean_t = []
  ii_logvar_t = []
  ii_t = []
  gen_t = []
  factor_t = []
  for xenc in xenc_t:
    cin = np.concatenate([xenc, f], axis=0)
    c = gru(params['con'], c, cin)
    cout = affine(params['con_out'], c)
    ii_mean, ii_logvar = np.split(cout, 2, axis=0)  # inferred input params
    ii = dists.diag_gaussian_sample(next(skeys), ii_mean, ii_logvar,
                                    lfads_hps['var_min'])
    g = gru(params['gen'], g, ii)
    g = dropout(g, next(skeys), keep_rate)
    f = normed_linear(params['factors'], g)

    # Save everything.
    c_t.append(c)
    ii_t.append(ii)
    gen_t.append(g)
    ii_mean_t.append(ii_mean)
    ii_logvar_t.append(ii_logvar)
    factor_t.append(f)

  c_t = np.array(c_t)
  ii_t = np.array(ii_t)
  gen_t = np.array(gen_t)
  ii_mean_t = np.array(ii_mean_t)
  ii_logvar_t = np.array(ii_logvar_t)
  factor_t = np.array(factor_t)
  lograte_t = batch_affine(params['logrates'], factor_t)
  return c_t, gen_t, factor_t, ii_t, ii_mean_t, ii_logvar_t, lograte_t
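# batch_affine, used just above to map factors to log rates over all time
# steps at once, is presumably the affine helper vmapped over the leading
# (time) axis of its input, with the parameters held fixed:
batch_affine = vmap(affine, in_axes=(None, 0))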