def load(self):
    query = Query(self._connection)
    query.SELECT("gid", self.name)
    query.SELECT("the_geom", self.name, "AsText")
    query.SELECT(self._primary, self.name)
    for key, gen in self._fields:
        query.SELECT(key, self.name)
    whereList = []
    try:
        for entry in self._subset:
            item = self.name + "." + self._primary + "='" + entry + "'"
            whereList.append(item)
        query.where = " OR ".join(whereList)
    except TypeError:
        pass
    polyDict = PolygonDictionary()
    pairs = []
    for entry in query:
        d = []
        for key, gen in self._fields:
            d.append((gen, entry[key]))
        data = Dictionary()
        data.update(d)
        p = GeneralizedPolygon(enum.ABSTRACT, entry["gid"], entry["the_geom"],
                               data, self)
        polyKey = keygen(entry[self._primary])
        pairs.append((polyKey, p))
    polyDict.update(pairs)
    return polyDict
def gru_params(key, n, u, ifactor=1.0, hfactor=1.0, hscale=0.0):
  """Generate GRU parameters

  Arguments:
    key: random.PRNGKey for random bits
    n: hidden state size
    u: input size
    ifactor: scaling factor for input weights
    hfactor: scaling factor for hidden -> hidden weights
    hscale: scale on h0 initial condition

  Returns:
    a dictionary of parameters
  """
  key, skeys = utils.keygen(key, 5)

  ifactor = ifactor / np.sqrt(u)
  hfactor = hfactor / np.sqrt(n)

  wRUH = random.normal(next(skeys), (n+n, n)) * hfactor
  wRUX = random.normal(next(skeys), (n+n, u)) * ifactor
  wRUHX = np.concatenate([wRUH, wRUX], axis=1)

  wCH = random.normal(next(skeys), (n, n)) * hfactor
  wCX = random.normal(next(skeys), (n, u)) * ifactor
  wCHX = np.concatenate([wCH, wCX], axis=1)

  return {'h0'    : random.normal(next(skeys), (n,)) * hscale,
          'wRUHX' : wRUHX,
          'wCHX'  : wCHX,
          'bRU'   : np.zeros((n+n,)),
          'bC'    : np.zeros((n,))}
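# gru_params() builds weights for a gru() cell that this section references
# but does not define. A minimal sketch of a compatible step function,
# assuming the wRUHX / wCHX / bRU / bC layout produced above:
def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru(params, h, x):
  """Sketch of one GRU step; h is the hidden state, x is the input."""
  hx = np.concatenate([h, x], axis=0)
  ru = sigmoid(np.dot(params['wRUHX'], hx) + params['bRU'])
  r, u = np.split(ru, 2, axis=0)                           # reset, update gates
  rhx = np.concatenate([r * h, x], axis=0)
  c = np.tanh(np.dot(params['wCHX'], rhx) + params['bC'])  # candidate state
  return u * h + (1.0 - u) * c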
def lfads(params, lfads_hps, key, x_t, keep_rate):
  """Run the LFADS network from input to output.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array of input with leading dim being time
    keep_rate: dropout keep rate

  Returns:
    A dictionary of np arrays of all LFADS values of interest.
  """
  key, skeys = utils.keygen(key, 2)

  ic_mean, ic_logvar, xenc_t = \
      lfads_encode(params, lfads_hps, next(skeys), x_t, keep_rate)
  c_t, ii_mean_t, ii_logvar_t, ii_t, gen_t, factor_t, lograte_t = \
      lfads_decode(params, lfads_hps, next(skeys), ic_mean, ic_logvar,
                   xenc_t, keep_rate)

  # As this is tutorial code, we're passing everything around.
  return {'xenc_t' : xenc_t, 'ic_mean' : ic_mean, 'ic_logvar' : ic_logvar,
          'ii_t' : ii_t, 'c_t' : c_t, 'ii_mean_t' : ii_mean_t,
          'ii_logvar_t' : ii_logvar_t, 'gen_t' : gen_t,
          'factor_t' : factor_t, 'lograte_t' : lograte_t}
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
  """Run the LFADS network from latent variables to log rates.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    ic_mean: np array of generator initial condition mean
    ic_logvar: np array of generator initial condition log variance
    xenc_t: np array bidirectional encoding of input (x_t) with leading dim
      being time
    keep_rate: dropout keep rate

  Returns:
    7-tuple of np arrays, all with leading dim being time: controller hidden
      state, inferred input mean, inferred input log var, inferred input
      sample, generator hidden state, factors, and log rates
  """
  ntime = lfads_hps['ntimesteps']
  key, skeys = utils.keygen(key, 1 + 2*ntime)

  # Since the factors feed back to the controller,
  #   factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
  # is really one big loop and therefore one RNN.
  c = c0 = params['con']['h0']
  g = g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar)
  f = f0 = np.zeros((lfads_hps['factors_dim'],))
  c_t = []
  ii_mean_t = []
  ii_logvar_t = []
  ii_t = []
  gen_t = []
  factor_t = []
  for xenc in xenc_t:
    cin = np.concatenate([xenc, f], axis=0)
    c = gru(params['con'], c, cin)
    cout = affine(params['con_out'], c)
    ii_mean, ii_logvar = np.split(cout, 2, axis=0)  # inferred input params
    ii = dists.diag_gaussian_sample(next(skeys), ii_mean, ii_logvar)
    g = gru(params['gen'], g, ii)
    g = dropout(g, next(skeys), keep_rate)
    f = normed_linear(params['factors'], g)

    # Save everything.
    c_t.append(c)
    ii_t.append(ii)
    gen_t.append(g)
    ii_mean_t.append(ii_mean)
    ii_logvar_t.append(ii_logvar)
    factor_t.append(f)

  c_t = np.array(c_t)
  ii_t = np.array(ii_t)
  gen_t = np.array(gen_t)
  ii_mean_t = np.array(ii_mean_t)
  ii_logvar_t = np.array(ii_logvar_t)
  factor_t = np.array(factor_t)
  lograte_t = batch_affine(params['logrates'], factor_t)
  return c_t, ii_mean_t, ii_logvar_t, ii_t, gen_t, factor_t, lograte_t
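# lfads_decode() calls normed_linear() for the factor readout, which is not
# defined in this section. A plausible sketch, assuming a row-normalized
# linear map so the factor dimensions stay on a comparable scale:
def normed_linear(params, x):
  """Sketch: linear readout with each row of w normalized to unit length."""
  w = params['w']
  w_row_norms = np.sqrt(np.sum(w**2, axis=1, keepdims=True))
  return np.dot(w / w_row_norms, x)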
def lfads_encode(params, lfads_hps, key, x_t, keep_rate):
  """Run the LFADS network from input to generator initial condition vars.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array input for lfads with leading dimension being time
    keep_rate: dropout keep rate

  Returns:
    3-tuple of np arrays: generator initial condition mean, log variance,
      and also bidirectional encoding of x_t, with leading dim being time
  """
  key, skeys = utils.keygen(key, 3)

  # Encode the input.
  x_t = run_dropout(x_t, next(skeys), keep_rate)
  con_ins_t, gen_pre_ics = run_bidirectional_rnn(params['ic_enc'], gru, gru,
                                                 x_t)

  # Push through to posterior mean and variance for initial conditions.
  xenc_t = dropout(con_ins_t, next(skeys), keep_rate)
  gen_pre_ics = dropout(gen_pre_ics, next(skeys), keep_rate)
  ic_gauss_params = affine(params['gen_ic'], gen_pre_ics)
  ic_mean, ic_logvar = np.split(ic_gauss_params, 2, axis=0)
  return ic_mean, ic_logvar, xenc_t
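# run_bidirectional_rnn() is referenced above but not shown here. A minimal
# sketch under the assumption that it scans one RNN forward and one backward
# over time, concatenating the per-step states and the two end states
# (run_rnn, a plain forward scan from h0, is also an assumed helper):
def run_rnn(params, rnn, x_t):
  """Sketch: scan rnn over the leading (time) dim of x_t from params['h0']."""
  h = params['h0']
  h_t = []
  for x in x_t:
    h = rnn(params, h, x)
    h_t.append(h)
  return np.array(h_t)

def run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t):
  """Sketch: bidirectional encoding, time-major in and out."""
  fwd_enc_t = run_rnn(params['fwd_rnn'], fwd_rnn, x_t)
  bwd_enc_t = np.flipud(run_rnn(params['bwd_rnn'], bwd_rnn, np.flipud(x_t)))
  full_enc_t = np.concatenate([fwd_enc_t, bwd_enc_t], axis=1)
  enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=0)
  return full_enc_t, enc_ends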
def optimize_lfads_core(key, batch_idx_start, num_batches, update_fun,
                        kl_warmup_fun, opt_state, lfads_hps, lfads_opt_hps,
                        train_data):
  """Make gradient updates to the LFADS model.

  Uses lax.fori_loop instead of a Python loop to reduce JAX overhead. This
  loop will be jit'd and run on device.

  Arguments:
    key: random.PRNGKey for random bits
    batch_idx_start: where we are in the total number of batches
    num_batches: how many batches to run
    update_fun: the function that changes params based on grad of loss
    kl_warmup_fun: function to compute the kl warmup
    opt_state: the jax optimizer state, containing params and opt state
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training

  Returns:
    opt_state: the jax optimizer state, containing params and optimizer state
  """
  key, dkeyg = utils.keygen(key, num_batches)  # data
  key, fkeyg = utils.keygen(key, num_batches)  # forward pass

  # Begin optimization loop. Explicitly avoiding a python for-loop
  # so that jax will not trace it for the sake of a gradient we will not use.
  # Note that lax.fori_loop traces run_update once, so the Python-level
  # next() calls on these key generators execute at trace time.
  def run_update(batch_idx, opt_state):
    kl_warmup = kl_warmup_fun(batch_idx)
    didxs = random.randint(next(dkeyg), [lfads_hps['batch_size']], 0,
                           train_data.shape[0])
    x_bxt = train_data[didxs].astype(np.float32)
    opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
                           next(fkeyg), x_bxt, kl_warmup)
    return opt_state

  lower = batch_idx_start
  upper = batch_idx_start + num_batches
  return lax.fori_loop(lower, upper, run_update, opt_state)
def random_vrnn_params(key, u, n, o, g=1.0):
  """Generate random vanilla RNN parameters (u inputs, n hidden units,
  o outputs, with gain g on the recurrent weights)."""
  key, skeys = utils.keygen(key, 4)

  hscale = 0.25
  ifactor = 1.0 / np.sqrt(u)
  hfactor = g / np.sqrt(n)
  pfactor = 1.0 / np.sqrt(n)
  return {'h0' : random.normal(next(skeys), (n,)) * hscale,
          'wI' : random.normal(next(skeys), (n, u)) * ifactor,
          'wR' : random.normal(next(skeys), (n, n)) * hfactor,
          'wO' : random.normal(next(skeys), (o, n)) * pfactor,
          'bR' : np.zeros([n]),
          'bO' : np.zeros([o])}
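# A vanilla RNN step that would consume these parameters is not shown in this
# section. A minimal sketch, assuming the usual tanh update with an affine
# readout:
def vrnn(params, h, x):
  """Sketch of one vanilla RNN step: h_t = tanh(wR h + wI x + bR)."""
  return np.tanh(np.dot(params['wR'], h) + np.dot(params['wI'], x)
                 + params['bR'])

def vrnn_readout(params, h):
  """Sketch of the affine readout: o_t = wO h + bO."""
  return np.dot(params['wO'], h) + params['bO']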
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
  """Compute the training loss of the LFADS autoencoder.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    kl_scale: scale on KL
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for
      optimization
  """
  B = lfads_hps['batch_size']
  key, skeys = utils.keygen(key, 2)
  keys = random.split(next(skeys), B)
  lfads = batch_lfads(params, lfads_hps, keys, x_bxt, keep_rate)

  # Sum over time and state dims, average over batch.
  # KL - g0
  ic_post_mean_b = lfads['ic_mean']
  ic_post_logvar_b = lfads['ic_logvar']
  kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                            params['ic_prior'])
  kl_loss_g0 = kl_scale * np.sum(kl_loss_g0_b) / B

  # KL - Inferred input
  ii_post_mean_bxt = lfads['ii_mean_t']
  ii_post_var_bxt = lfads['ii_logvar_t']
  keys = random.split(next(skeys), B)
  kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys, ii_post_mean_bxt,
                                          ii_post_var_bxt, params['ii_prior'])
  kl_loss_ii = kl_scale * np.sum(kl_loss_ii_b) / B

  # Log-likelihood of data given latents.
  lograte_bxt = lfads['lograte_t']
  log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

  # L2
  l2reg = lfads_hps['l2reg']
  flatten_lfads = lambda params: flatten_util.ravel_pytree(params)[0]
  l2_loss = l2reg * np.sum(flatten_lfads(params)**2)

  loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
  all_losses = {'total' : loss, 'nlog_p_xgz' : -log_p_xgz,
                'kl_g0' : kl_loss_g0, 'kl_ii' : kl_loss_ii, 'l2' : l2_loss}
  return all_losses
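# batch_lfads() above is assumed to be a vmap of the single-example lfads()
# over the batch: params and hyperparameters broadcast, while keys and data
# map over the leading batch dim. A one-line sketch of that wiring:
batch_lfads = vmap(lfads, in_axes=(None, None, 0, 0, None))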
def linear_params(key, o, u, ifactor=1.0):
  """Params for y = w x

  Arguments:
    key: random.PRNGKey for random bits
    o: output size
    u: input size
    ifactor: scaling factor

  Returns:
    a dictionary of parameters
  """
  key, skeys = utils.keygen(key, 1)
  ifactor = ifactor / np.sqrt(u)
  return {'w' : random.normal(next(skeys), (o, u)) * ifactor}
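# affine() and affine_params() (y = w x + b) are used throughout this section
# but not defined in it. A minimal sketch consistent with linear_params above;
# batch_affine is assumed to be the vmap over a leading batch/time dim:
def affine_params(key, o, u, ifactor=1.0):
  """Sketch: params for y = w x + b."""
  key, skeys = utils.keygen(key, 1)
  ifactor = ifactor / np.sqrt(u)
  return {'w' : random.normal(next(skeys), (o, u)) * ifactor,
          'b' : np.zeros((o,))}

def affine(params, x):
  """Sketch: affine transformation y = w x + b."""
  return np.dot(params['w'], x) + params['b']

batch_affine = vmap(affine, in_axes=(None, 0))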
def gru_params(key, **rnn_hps):
  """Generate GRU parameters

  Arguments:
    key: random.PRNGKey for random bits
    rnn_hps: dictionary of RNN hyperparameters with keys
      n: hidden state size
      u: input size
      o: output size
      i_factor: scaling factor for input weights
      h_factor: scaling factor for hidden -> hidden weights
      h_scale: scale on h0 initial condition

  Returns:
    a dictionary of parameters
  """
  key, skeys = utils.keygen(key, 6)

  u = rnn_hps['u']  # input
  n = rnn_hps['n']  # hidden
  o = rnn_hps['o']  # output
  ifactor = rnn_hps['i_factor'] / np.sqrt(u)
  hfactor = rnn_hps['h_factor'] / np.sqrt(n)
  hscale = rnn_hps['h_scale']

  wRUH = random.normal(next(skeys), (n + n, n)) * hfactor
  wRUX = random.normal(next(skeys), (n + n, u)) * ifactor
  wRUHX = np.concatenate([wRUH, wRUX], axis=1)

  wCH = random.normal(next(skeys), (n, n)) * hfactor
  wCX = random.normal(next(skeys), (n, u)) * ifactor
  wCHX = np.concatenate([wCH, wCX], axis=1)

  # Include the readout params in the GRU, though technically
  # they are not a part of the GRU.
  pfactor = 1.0 / np.sqrt(n)
  wO = random.normal(next(skeys), (o, n)) * pfactor
  bO = np.zeros((o,))

  return {'h0'    : random.normal(next(skeys), (n,)) * hscale,
          'wRUHX' : wRUHX,
          'wCHX'  : wCHX,
          'bRU'   : np.zeros((n + n,)),
          'bC'    : np.zeros((n,)),
          'wO'    : wO,
          'bO'    : bO}
def kl_gauss_ar1(key, z_mean_t, z_logvar_t, ar1_params):
  """KL using samples for multi-dim gaussian (thru time) and AR(1) process.

  To sample KL(q||p), we sample ln q - ln p by drawing samples from q
  and averaging. q is multidim gaussian, p is AR(1) process.

  Arguments:
    key: random.PRNGKey for random bits
    z_mean_t: np.array of means with leading dim being time
    z_logvar_t: np.array of log vars, leading dim is time
    ar1_params: dictionary of ar1 parameters, log noise var and autocorr tau

  Returns:
    sampled KL divergence between q and p
  """
  ll = diag_gaussian_log_likelihood
  sample = diag_gaussian_sample

  nkeys = z_mean_t.shape[0]
  key, skeys = utils.keygen(key, nkeys)

  # Convert AR(1) parameters.
  #   z_t = c + phi z_{t-1} + eps, eps \in N(0, noise var)
  ar1_mean = ar1_params['mean']
  ar1_lognoisevar = ar1_params['lognvar']
  phi = np.exp(-np.exp(-ar1_params['logatau']))
  logprocessvar = ar1_lognoisevar - (np.log(1 - phi) + np.log(1 + phi))

  # Sample the first AR(1) step according to the process variance.
  z0 = sample(next(skeys), z_mean_t[0], z_logvar_t[0])
  logq = ll(z0, z_mean_t[0], z_logvar_t[0])
  logp = ll(z0, ar1_mean, logprocessvar)
  z_last = z0

  # Sample the remaining time steps with adjusted mean and noise variance.
  for z_mean, z_logvar in zip(z_mean_t[1:], z_logvar_t[1:]):
    z = sample(next(skeys), z_mean, z_logvar)
    logq += ll(z, z_mean, z_logvar)
    logp += ll(z, ar1_mean + phi * z_last, ar1_lognoisevar)
    z_last = z
  kl = logq - logp
  return kl
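# Why logprocessvar takes that form: for a stationary AR(1) process
#   z_t = c + phi z_{t-1} + eps,  Var(eps) = sigma^2,
# the stationary (process) variance is sigma^2 / (1 - phi^2), and since
#   log(1 - phi^2) = log(1 - phi) + log(1 + phi),
# in log space that is exactly lognoisevar - (log(1 - phi) + log(1 + phi)),
# as computed above.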
def build_input_and_target_pure_integration(input_params, key):
  """Build white noise input and integration targets."""
  bias_val, stddev_val, T, ntime = input_params
  dt = T / ntime

  # Create the white noise input.
  key, skeys = utils.keygen(key, 2)
  random_sample = random.normal(next(skeys), (1,))[0]
  bias = bias_val * 2.0 * (random_sample - 0.5)
  stddev = stddev_val / np.sqrt(dt)
  random_samples = random.normal(next(skeys), (ntime,))
  noise_t = stddev * random_samples
  white_noise_t = bias + noise_t

  # * dt, intentionally left off to get output scaling in O(1).
  targets_t = np.cumsum(white_noise_t)
  inputs_tx1 = np.expand_dims(white_noise_t, axis=1)
  targets_tx1 = np.expand_dims(targets_t, axis=1)
  return inputs_tx1, targets_tx1
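# A batch of trials can be built by vmapping this function over keys. This
# wiring is a sketch (the batched builder itself is not shown in this
# section); input_params broadcasts while keys map over the batch:
build_inputs_and_targets = vmap(build_input_and_target_pure_integration,
                                in_axes=(None, 0))
# Usage sketch:
#   inputs_bxtx1, targets_bxtx1 = build_inputs_and_targets(
#       input_params, random.split(key, batch_size))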
def lfads_params(key, lfads_hps):
  """Instantiate random LFADS parameters.

  Arguments:
    key: random.PRNGKey for random bits
    lfads_hps: a dict of LFADS hyperparameters

  Returns:
    a dictionary of LFADS parameters
  """
  key, skeys = utils.keygen(key, 10)

  data_dim = lfads_hps['data_dim']
  ntimesteps = lfads_hps['ntimesteps']
  enc_dim = lfads_hps['enc_dim']
  con_dim = lfads_hps['con_dim']
  ii_dim = lfads_hps['ii_dim']
  gen_dim = lfads_hps['gen_dim']
  factors_dim = lfads_hps['factors_dim']

  ic_enc_params = {'fwd_rnn' : gru_params(next(skeys), enc_dim, data_dim),
                   'bwd_rnn' : gru_params(next(skeys), enc_dim, data_dim)}
  gen_ic_params = affine_params(next(skeys), 2*gen_dim, 2*enc_dim)  # m,v <- bi
  ic_prior_params = dists.diagonal_gaussian_params(next(skeys), gen_dim, 0.0,
                                                   lfads_hps['ic_prior_var'])
  con_params = gru_params(next(skeys), con_dim, 2*enc_dim + factors_dim)
  con_out_params = affine_params(next(skeys), 2*ii_dim, con_dim)  # m,v
  ii_prior_params = dists.ar1_params(next(skeys), ii_dim,
                                     lfads_hps['ar_mean'],
                                     lfads_hps['ar_autocorrelation_tau'],
                                     lfads_hps['ar_noise_variance'])
  gen_params = gru_params(next(skeys), gen_dim, ii_dim)
  factors_params = linear_params(next(skeys), factors_dim, gen_dim)
  lograte_params = affine_params(next(skeys), data_dim, factors_dim)

  return {'ic_enc' : ic_enc_params,
          'gen_ic' : gen_ic_params,
          'ic_prior' : ic_prior_params,
          'con' : con_params,
          'con_out' : con_out_params,
          'ii_prior' : ii_prior_params,
          'gen' : gen_params,
          'factors' : factors_params,
          'logrates' : lograte_params}
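# Usage sketch: build the hyperparameter dict, then instantiate parameters
# from a fresh key. The keys shown are the ones lfads_params reads; the
# prior values here are illustrative placeholders, not values from this code:
example_lfads_hps = {'data_dim' : 100, 'ntimesteps' : 25, 'enc_dim' : 32,
                     'con_dim' : 32, 'ii_dim' : 1, 'gen_dim' : 40,
                     'factors_dim' : 10, 'ic_prior_var' : 0.1,
                     'ar_mean' : 0.0, 'ar_autocorrelation_tau' : 1.0,
                     'ar_noise_variance' : 0.1}
example_params = lfads_params(random.PRNGKey(0), example_lfads_hps)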
def build_input_and_target_binary_decision(input_params, key):
  """Build white noise input and decision targets.

  The decision is whether the white noise input has a perfect integral
  greater than, or less than, 0. Output a +1 or -1, respectively.

  Arguments:
    input_params: tuple of parameters for this decision task
    key: jax random key for making randomness

  Returns:
    3-tuple of inputs, targets, and the target mask, indicating which time
      points have optimization pressure on them
  """
  bias_val, stddev_val, T, ntime = input_params
  dt = T / ntime

  # Create the white noise input.
  key, skeys = utils.keygen(key, 2)
  random_sample = random.normal(next(skeys), (1,))[0]
  bias = bias_val * 2.0 * (random_sample - 0.5)
  stddev = stddev_val / np.sqrt(dt)
  random_samples = random.normal(next(skeys), (ntime,))
  noise_t = stddev * random_samples
  white_noise_t = bias + noise_t

  # * dt, intentionally left off to get output scaling in O(1).
  pure_integration_t = np.cumsum(white_noise_t)
  decision = 2.0 * ((pure_integration_t[-1] > 0.0) - 0.5)
  targets_t = np.zeros(pure_integration_t.shape[0] - 1)
  targets_t = np.concatenate([targets_t,
                              np.array([decision], dtype=float)], axis=0)
  inputs_tx1 = np.expand_dims(white_noise_t, axis=1)
  targets_tx1 = np.expand_dims(targets_t, axis=1)
  target_mask = np.array([ntime - 1])  # When target is defined.
  return inputs_tx1, targets_tx1, target_mask
def optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps, train_data,
                   eval_data):
  """Optimize the LFADS model and print batch based optimization data.

  This loop is at the cpu nonjax-numpy level.

  Arguments:
    key: random.PRNGKey for random bits
    init_params: a dict of parameters to be trained
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training
    eval_data: nexamples x time x ndims np array of data for evaluation

  Returns:
    a 2-tuple of trained parameters and a dict of losses through training
  """
  # Begin optimization loop.
  all_tlosses = []
  all_elosses = []

  # Build some functions used in optimization.
  kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
  decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                           lfads_opt_hps['decay_steps'],
                                           lfads_opt_hps['decay_factor'])

  opt_init, opt_update = optimizers.adam(step_size=decay_fun,
                                         b1=lfads_opt_hps['adam_b1'],
                                         b2=lfads_opt_hps['adam_b2'],
                                         eps=lfads_opt_hps['adam_eps'])
  opt_state = opt_init(init_params)
  update_fun = get_update_w_gc_fun(init_params, opt_update)

  # Run the optimization, pausing every so often to collect data and
  # print status.
  batch_size = lfads_hps['batch_size']
  num_batches = lfads_opt_hps['num_batches']
  print_every = lfads_opt_hps['print_every']
  num_opt_loops = int(num_batches / print_every)
  key, dtkeyg = utils.keygen(key, num_opt_loops)  # data, train
  key, dekeyg = utils.keygen(key, num_opt_loops)  # data, eval
  key, tkeyg = utils.keygen(key, num_opt_loops)   # training
  params = optimizers.get_params(opt_state)
  for oidx in range(num_opt_loops):
    batch_idx_start = oidx * print_every
    start_time = time.time()
    opt_state = optimize_lfads_core_jit(next(tkeyg), batch_idx_start,
                                        print_every, update_fun,
                                        kl_warmup_fun, opt_state, lfads_hps,
                                        lfads_opt_hps, train_data)
    batch_time = time.time() - start_time

    # Losses
    params = optimizers.get_params(opt_state)
    batch_pidx = batch_idx_start + print_every
    kl_warmup = kl_warmup_fun(batch_idx_start)

    # Training loss
    didxs = onp.random.randint(0, train_data.shape[0], batch_size)
    x_bxt = train_data[didxs].astype(onp.float32)
    tlosses = lfads.lfads_losses_jit(params, lfads_hps, next(dtkeyg), x_bxt,
                                     kl_warmup, 1.0)

    # Evaluation loss
    didxs = onp.random.randint(0, eval_data.shape[0], batch_size)
    ex_bxt = eval_data[didxs].astype(onp.float32)
    elosses = lfads.lfads_losses_jit(params, lfads_hps, next(dekeyg), ex_bxt,
                                     kl_warmup, 1.0)

    # Saving, printing.
    all_tlosses.append(tlosses)
    all_elosses.append(elosses)
    s = ("Batches {}-{} in {:0.2f} sec, Step size: {:0.5f}, "
         "Training loss {:0.0f}, Eval loss {:0.0f}")
    print(s.format(batch_idx_start + 1, batch_pidx, batch_time,
                   decay_fun(batch_pidx), tlosses['total'],
                   elosses['total']))

  tlosses_thru_training = utils.merge_losses_dicts(all_tlosses)
  elosses_thru_training = utils.merge_losses_dicts(all_elosses)
  optimizer_details = {'tlosses' : tlosses_thru_training,
                       'elosses' : elosses_thru_training}
  return params, optimizer_details
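# optimize_lfads_core_jit above is assumed to be a jit of optimize_lfads_core
# with the non-array arguments marked static (num_batches, the two function
# arguments, and the two hyperparameter dicts); a sketch of that wiring:
optimize_lfads_core_jit = jit(optimize_lfads_core,
                              static_argnums=(2, 3, 4, 6, 7))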
def addField(self, fieldName):
    key = keygen(fieldName)
    self._fields.update({fieldName: key})
    return key
###
import jax.numpy as np
from jax import grad, jit, random, vmap
import jax.flatten_util as flatten_util
from jax.config import config

import lfads
import numpy as onp
import os
import time
import utils

key = random.PRNGKey(0)
key, skeys = utils.keygen(key, 10)

### LFADS hyperparameters
data_dim = 100
ntimesteps = 25
batch_size = 128   # batch size during optimization

# LFADS architecture
enc_dim = 32       #64  # encoder dim
con_dim = 32       #64  # controller dim
ii_dim = 1         # inferred input dim
gen_dim = 40       #75  # generator dim
factors_dim = 10   #20  # factors dim

# Optimization HPs that percolate into model
l2reg = 0.000002   # amount of l2 on weights
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from utils import keygen

R = keygen('R')
DB = keygen('DB')
SHP = keygen('SHP')
PASS = keygen('PASS')
MEAN = keygen('MEAN')
SUM = keygen('SUM')
COUNT = keygen('COUNT')
MAX = keygen('MAX')
MIN = keygen('MIN')
SD = keygen('SD')
BAR = keygen('BAR')
BOX = keygen('BOX')
SCATTER = keygen('SCATTER')