def run_epoch(rng, _opt_state, epoch_idx):
    _rng, dat_keys = utils.keygen(rng, 1)
    _rng, batch_keys = utils.keygen(_rng, num_batches)

    # Randomize epoch data.
    epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)

    def update(batch_idx, __opt_state):
        """Update func for gradients, includes gradient clipping."""
        kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)
        batch_data = lax.dynamic_slice_in_dim(epoch_data, batch_idx * BATCH_SIZE,
                                              BATCH_SIZE, axis=0)
        batch_data = batch_data.astype(np.float32)
        params = get_params(__opt_state)
        grads = grad(loss_fn)(params, batch_data, next(batch_keys), BATCH_SIZE,
                              ic_prior, VAR_MIN, kl_warmup, L2_REG)
        clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)
        return opt_update(batch_idx, clipped_grads, __opt_state)

    return lax.fori_loop(0, num_batches, update, _opt_state)
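# Once the model, loss, and optimizer (defined further below) exist, run_epoch is
# typically driven by a plain Python loop over epochs. A sketch only; it assumes
# opt_state = opt_init(init_params) has been created as in the optimizer section:
# for epoch_idx in range(EPOCHS):
#     rng, epoch_rng = random.split(rng)
#     opt_state = run_epoch(epoch_rng, opt_state, epoch_idx)
# trained_params = get_params(opt_state)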
def gru_params(rng, n, u, ifactor=1.0, hfactor=1.0, hscale=0.0):
    """
    Helper function for GRU parameter initialization.
    Used twice in the BidirectionalGRU (encoder) and once in the FreeEvolveGRU (decoder).
    :param rng: random.PRNGKey for random bits
    :param n: hidden state size
    :param u: input size
    :param ifactor: scaling factor for input weights
    :param hfactor: scaling factor for hidden -> hidden weights
    :param hscale: scale on h0 initial condition
    :return: dictionary of GRU parameters
    """
    rng, keys = utils.keygen(rng, 5)

    ifac = ifactor / np.sqrt(u)
    hfac = hfactor / np.sqrt(n)

    wRUH = random.normal(next(keys), (n + n, n)) * hfac
    wRUX = random.normal(next(keys), (n + n, u)) * ifac
    wRUHX = np.concatenate([wRUH, wRUX], axis=1)

    wCH = random.normal(next(keys), (n, n)) * hfac
    wCX = random.normal(next(keys), (n, u)) * ifac
    wCHX = np.concatenate([wCH, wCX], axis=1)

    return {
        'h0': random.normal(next(keys), (n,)) * hscale,
        'wRUHX': wRUHX,
        'wCHX': wCHX,
        'bRU': np.zeros((n + n,)),
        'bC': np.zeros((n,))
    }
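# run_gru (further below) calls a gru(params, h, x) step function that is defined
# elsewhere in this notebook. For reference only, here is a standard GRU update that is
# consistent with the parameter shapes produced by gru_params (wRUHX: (2n, n+u),
# bRU: (2n,), wCHX: (n, n+u), bC: (n,)); the actual gru implementation may differ.
def gru_sketch(params, h, x):
    hx = np.concatenate([h, x], axis=0)
    # Reset and update gates from the concatenated hidden state and input.
    ru = 1.0 / (1.0 + np.exp(-(np.dot(params['wRUHX'], hx) + params['bRU'])))
    r, u = np.split(ru, 2, axis=0)
    # Candidate hidden state uses the reset-gated hidden state.
    c = np.tanh(np.dot(params['wCHX'], np.concatenate([r * h, x], axis=0)) + params['bC'])
    return u * h + (1.0 - u) * c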
def init_fun(rng, input_shape):
    u = input_shape[-1]
    key, keys = utils.keygen(rng, 1)
    ifac = ifactor / np.sqrt(u)
    params = {'w': random.normal(next(keys), (output_size, u)) * ifac}
    output_shape = input_shape[:-1] + (output_size,)
    return output_shape, params
def run_gru(params, x_t, h0=None, keep_rate=1.0, rng=None):
    """
    Run a GRU module forward in time.

    Arguments:
        params: dictionary of parameters for gru (keys: 'wRUHX', 'bRU', 'wCHX', 'bC')
            and optionally 'h0'
        x_t: np array data for RNN input with leading dim being time
        h0: initial condition for running rnn, which overwrites param h0
        keep_rate: keep probability for dropout applied to the hidden state
        rng: random.PRNGKey for the dropout masks
    Returns:
        np array of rnn applied to time data with leading dim being time
    """
    if rng is None:
        raise ValueError("GRU dropout requires rng key.")
    rng, keys = utils.keygen(rng, len(x_t))

    h = h0 if h0 is not None else params['h0']
    h_t = []
    for x in x_t:
        h = gru(params, h, x)
        # Do dropout on hidden state
        # TODO: Only do dropout during training.
        keep = random.bernoulli(next(keys), keep_rate, h.shape)
        h = np.where(keep, h / keep_rate, 0)
        h_t.append(h)
    return np.array(h_t)
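# Note on the TODO above: at evaluation time the same code path can simply be called
# with keep_rate=1.0, which makes the Bernoulli mask all ones and leaves the hidden
# state unchanged, e.g.:
# h_eval_t = run_gru(params, x_t, keep_rate=1.0, rng=random.PRNGKey(0))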
def apply_fun(params, inputs, rng=None):
    if rng is None:
        raise ValueError("SampleDistrib apply_fun requires rng key.")
    rng, keys = utils.keygen(rng, 1)
    _mean, _logvar = np.split(inputs, 2, axis=0)
    samples = dists.diag_gaussian_sample(next(keys), _mean, _logvar, var_min)
    return samples
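# dists.diag_gaussian_sample is defined elsewhere. For reference, a reparameterized
# diagonal-Gaussian sample consistent with how it is called here; the real helper may
# apply var_min slightly differently.
def diag_gaussian_sample_sketch(key, mean, logvar, varmin=1e-16):
    # z = mean + sigma * eps, with the variance floored at varmin for stability.
    return mean + np.sqrt(np.exp(logvar) + varmin) * random.normal(key, mean.shape)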
def init_fun(rng, input_shape):
    output_shape = (evolve_steps, n_hidden)
    rng, keys = utils.keygen(rng, 1)
    gen_params = gru_params(next(keys), n_hidden, 1)
    # Modify params so x weights are all 0. Not necessary because input is always 0.
    # gen_params['wRUHX'][:, -1] = 0
    # gen_params['wCHX'][:, -1] = 0
    return output_shape, gen_params
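# If zeroing those input columns were ever wanted, note that jax arrays are immutable,
# so the in-place assignments in the commented lines above would raise; the functional
# equivalent would be:
# gen_params['wRUHX'] = gen_params['wRUHX'].at[:, -1].set(0.0)
# gen_params['wCHX'] = gen_params['wCHX'].at[:, -1].set(0.0)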
def apply_fun(params, x_t, rng=None):
    if rng is None:
        raise ValueError("BidirectionalGRU apply_fun requires rng key.")
    rng, keys = utils.keygen(rng, 2)
    fwd_enc_t = run_gru(params['fwd_rnn'], x_t, rng=next(keys))
    bwd_enc_t = np.flipud(
        run_gru(params['bwd_rnn'], np.flipud(x_t), rng=next(keys)))
    # Concatenate the final backward and forward hidden states into a (2 * n_hidden,) vector.
    enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=0)
    return enc_ends
def init_fun(rng, input_shape):
    u = input_shape[-1]
    output_shape = input_shape[:-2] + (2 * n_hidden,)
    rng, keys = utils.keygen(rng, 2)
    ic_enc_params = {
        'fwd_rnn': gru_params(next(keys), n_hidden, u),
        'bwd_rnn': gru_params(next(keys), n_hidden, u)
    }
    return output_shape, ic_enc_params
def lfads_onestep(params, rng, data):
    rng, keys = utils.keygen(rng, 2)
    enc_params, dec_params = params
    latent_vars = encdec[0](enc_params, data, rng=next(keys))
    neuron_log_rates = encdec[1](dec_params, latent_vars, rng=next(keys))
    ic_post_mean, ic_post_logvar = np.split(latent_vars, 2, axis=0)
    return {
        'ic_post_mean': ic_post_mean,
        'ic_post_logvar': ic_post_logvar,
        'neuron_log_rates': neuron_log_rates
    }
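# loss_fn (used by run_epoch above) is not part of this section. For orientation, a
# minimal sketch of an LFADS-style loss consistent with that call signature: batch the
# single-trial forward pass with vmap, score the spikes under a Poisson likelihood of
# the decoded log-rates, and add the warmed-up KL and an L2 penalty. The helper names
# dists.poisson_log_likelihood and dists.kl_gauss_gauss are assumptions (as is vmap
# being imported from jax); the real loss_fn may differ.
def loss_fn_sketch(params, batch_data, rng, batch_size, ic_prior, var_min,
                   kl_scale, l2_reg):
    batch_keys = random.split(rng, batch_size)
    out = vmap(lfads_onestep, in_axes=(None, 0, 0))(params, batch_keys, batch_data)
    # Poisson log-likelihood of the observed spike counts given decoded log-rates.
    log_p_x = np.sum(dists.poisson_log_likelihood(batch_data, out['neuron_log_rates']))
    # KL between the per-trial IC posterior and the Gaussian IC prior.
    kl = np.sum(dists.kl_gauss_gauss(out['ic_post_mean'], out['ic_post_logvar'],
                                     ic_prior, var_min))
    l2 = l2_reg * optimizers.l2_norm(params) ** 2
    return (-log_p_x + kl_scale * kl) / batch_size + l2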
def kl_gauss_ar1(key, z_mean_t, z_logvar_t, ar1_params, varmin=1e-16):
    """KL using samples for multi-dim gaussian (thru time) and AR(1) process.

    To sample KL(q||p), we sample ln q - ln p by drawing samples from q and
    averaging. q is a multi-dim gaussian, p is an AR(1) process.

    Arguments:
        key: random.PRNGKey for random bits
        z_mean_t: np.array of means with leading dim being time
        z_logvar_t: np.array of log vars, leading dim is time
        ar1_params: dictionary of AR(1) parameters: 'mean', log noise variance
            'lognvar', and log autocorrelation time constant 'logatau'
        varmin: minimal variance, useful for numerical stability

    Returns:
        sampled KL divergence between q and p
    """
    ll = diag_gaussian_log_likelihood
    sample = diag_gaussian_sample

    nkeys = z_mean_t.shape[0]
    key, skeys = utils.keygen(key, nkeys)

    # Convert AR(1) parameters.
    # z_t = c + phi z_{t-1} + eps, eps ~ N(0, noise var)
    ar1_mean = ar1_params['mean']
    ar1_lognoisevar = np.log(np.exp(ar1_params['lognvar']) + varmin)
    phi = np.exp(-np.exp(-ar1_params['logatau']))
    # The process variance is a function of the noise variance, so I added varmin above.
    # This affects the log-likelihood functions below, also.
    logprocessvar = ar1_lognoisevar - (np.log(1 - phi) + np.log(1 + phi))

    # Sample the first AR(1) step according to the process variance.
    z0 = sample(next(skeys), z_mean_t[0], z_logvar_t[0], varmin)
    logq = ll(z0, z_mean_t[0], z_logvar_t[0], varmin)
    logp = ll(z0, ar1_mean, logprocessvar, 0.0)
    z_last = z0

    # Sample the remaining time steps with adjusted mean and noise variance.
    for z_mean, z_logvar in zip(z_mean_t[1:], z_logvar_t[1:]):
        z = sample(next(skeys), z_mean, z_logvar, varmin)
        logq += ll(z, z_mean, z_logvar, varmin)
        logp += ll(z, ar1_mean + phi * z_last, ar1_lognoisevar, 0.0)
        z_last = z

    kl = logq - logp
    return kl
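# For reference (standard AR(1) identities, not from the original text): with
# tau = exp(logatau), the autocorrelation is phi = exp(-1/tau), and the stationary
# process variance is noise_var / (1 - phi**2), which is what logprocessvar computes
# in log space: lognoisevar - log(1 - phi) - log(1 + phi).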
num_batches = int(n_trials * EPOCHS / BATCH_SIZE)  # how many batches do we train

# Get the model #
encoder_init, encode = LFADSEncoderModel(P_DROPOUT, ENC_DIM, IC_DIM)
decoder_init, decode = LFADSDecoderModel(VAR_MIN, P_DROPOUT, IC_DIM, n_timesteps,
                                         FACTORS_DIM, n_neurons)
encdec = encode, decode

# Init the model
ic_prior = {
    'mean': 0.0 * np.ones((IC_DIM,)),
    'logvar': np.log(IC_PRIOR_VAR) * np.ones((IC_DIM,))
}
rng, keys = utils.keygen(rng, 2)
latent_shape, init_encoder_params = encoder_init(next(keys), (n_timesteps, n_neurons))
decoded_shape, init_decoder_params = decoder_init(next(keys), latent_shape)
init_params = init_encoder_params, init_decoder_params

# Optimizer #
def kl_warmup_fun(batch_idx):
    progress_frac = ((batch_idx - kl_warmup_start) /
                     (kl_warmup_end - kl_warmup_start))
    _warmup = np.where(batch_idx < kl_warmup_start, kl_min,
                       (kl_max - kl_min) * progress_frac + kl_min)
    return np.where(batch_idx > kl_warmup_end, kl_max, _warmup)
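# The optimizer objects used by run_epoch (opt_init / opt_update / get_params) are not
# created in this section. A minimal sketch with jax's optimizers module; STEP_SIZE is
# a placeholder name, not necessarily the learning rate used elsewhere in this notebook.
STEP_SIZE = 1e-3  # placeholder learning rate
opt_init, opt_update, get_params = optimizers.adam(step_size=STEP_SIZE)
opt_state = opt_init(init_params)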