Example 1
def gru_params(key, n, u, ifactor=1.0, hfactor=1.0, hscale=0.0):
    """Generate GRU parameters

  Arguments:
    key: random.PRNGKey for random bits
    n: hidden state size
    u: input size
    ifactor: scaling factor for input weights
    hfactor: scaling factor for hidden -> hidden weights
    hscale: scale on h0 initial condition

  Returns:
    a dictionary of parameters
  """
    key, skeys = utils.keygen(key, 5)
    ifactor = ifactor / np.sqrt(u)
    hfactor = hfactor / np.sqrt(n)

    wRUH = random.normal(next(skeys), (n + n, n)) * hfactor
    wRUX = random.normal(next(skeys), (n + n, u)) * ifactor
    wRUHX = np.concatenate([wRUH, wRUX], axis=1)

    wCH = random.normal(next(skeys), (n, n)) * hfactor
    wCX = random.normal(next(skeys), (n, u)) * ifactor
    wCHX = np.concatenate([wCH, wCX], axis=1)

    return {
        'h0': random.normal(next(skeys), (n, )) * hscale,
        'wRUHX': wRUHX,
        'wCHX': wCHX,
        'bRU': np.zeros((n + n, )),
        'bC': np.zeros((n, ))
    }
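For reference, a GRU cell that consumes this parameter layout might look like the sketch below. This is an assumption-based reconstruction, not the reference implementation: wRUHX maps the concatenated [h; x] to the stacked reset/update gates, and wCHX maps [r*h; x] to the candidate state.

import jax.numpy as np
from jax.nn import sigmoid

def gru(params, h, x):
    """One GRU step using the parameter layout above (a sketch, not the
    reference code)."""
    hx = np.concatenate([h, x], axis=0)
    # Reset and update gates come out stacked as one (n + n)-vector.
    ru = sigmoid(np.dot(params['wRUHX'], hx) + params['bRU'])
    r, u = np.split(ru, 2, axis=0)
    # Candidate state uses the reset-gated hidden state.
    rhx = np.concatenate([r * h, x], axis=0)
    c = np.tanh(np.dot(params['wCHX'], rhx) + params['bC'])
    return u * h + (1.0 - u) * c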
Example 2
def lfads_encode(params, lfads_hps, key, x_t, keep_rate):
    """Run the LFADS network from input to generator initial condition vars.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array input for lfads with leading dimension being time
    keep_rate: dropout keep rate

  Returns:
    3-tuple of np arrays: generator initial condition mean, log variance,
      and the bidirectional encoding of x_t (leading dim is time)
  """
    key, skeys = utils.keygen(key, 3)

    # Encode the input
    x_t = run_dropout(x_t, next(skeys), keep_rate)
    con_ins_t, gen_pre_ics = run_bidirectional_rnn(params['ic_enc'], gru, gru,
                                                   x_t)
    # Push through to posterior mean and variance for initial conditions.
    xenc_t = dropout(con_ins_t, next(skeys), keep_rate)
    gen_pre_ics = dropout(gen_pre_ics, next(skeys), keep_rate)
    ic_gauss_params = affine(params['gen_ic'], gen_pre_ics)
    ic_mean, ic_logvar = np.split(ic_gauss_params, 2, axis=0)
    return ic_mean, ic_logvar, xenc_t
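run_bidirectional_rnn is defined elsewhere in the codebase. A minimal sketch, assuming it scans the two GRUs forward and backward over time and returns both the per-timestep encodings and the concatenated end states (consistent with gen_ic expecting a 2 * enc_dim input in Example 5):

import jax.numpy as np
from jax import lax

def run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t):
    """Sketch of a bidirectional encoder (assumed behavior, not the
    reference code)."""
    def fwd_step(h, x):
        h = fwd_rnn(params['fwd_rnn'], h, x)
        return h, h

    def bwd_step(h, x):
        h = bwd_rnn(params['bwd_rnn'], h, x)
        return h, h

    fwd_end, fwd_enc_t = lax.scan(fwd_step, params['fwd_rnn']['h0'], x_t)
    bwd_end, bwd_enc_t = lax.scan(bwd_step, params['bwd_rnn']['h0'], x_t[::-1])
    enc_t = np.concatenate([fwd_enc_t, bwd_enc_t[::-1]], axis=1)  # time x 2*enc_dim
    enc_ends = np.concatenate([fwd_end, bwd_end], axis=0)         # 2*enc_dim
    return enc_t, enc_ends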
Example 3
def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_bxt: np array of input with leading dims being batch and time
    kl_scale: scale on the KL terms (e.g. for warmup)
    keep_rate: dropout keep rate

  Returns:
    a dictionary of all losses, including the key 'total' used for optimization
  """

    B = lfads_hps['batch_size']
    key, skeys = utils.keygen(key, 2)
    keys_b = random.split(next(skeys), B)
    lfads = batch_lfads(params, lfads_hps, keys_b, x_bxt, keep_rate)

    # Sum over time and state dims, average over batch.
    # KL - g0
    ic_post_mean_b = lfads['ic_mean']
    ic_post_logvar_b = lfads['ic_logvar']
    kl_loss_g0_b = dists.batch_kl_gauss_gauss(ic_post_mean_b, ic_post_logvar_b,
                                              params['ic_prior'],
                                              lfads_hps['var_min'])
    kl_loss_g0_prescale = np.sum(kl_loss_g0_b) / B
    kl_loss_g0 = kl_scale * kl_loss_g0_prescale

    # KL - Inferred input
    ii_post_mean_bxt = lfads['ii_mean_t']
    ii_post_var_bxt = lfads['ii_logvar_t']
    keys_b = random.split(next(skeys), B)
    kl_loss_ii_b = dists.batch_kl_gauss_ar1(keys_b, ii_post_mean_bxt,
                                            ii_post_logvar_bxt,
                                            params['ii_prior'],
                                            lfads_hps['var_min'])
    kl_loss_ii_prescale = np.sum(kl_loss_ii_b) / B
    kl_loss_ii = kl_scale * kl_loss_ii_prescale

    # Log-likelihood of data given latents.
    lograte_bxt = lfads['lograte_t']
    log_p_xgz = np.sum(dists.poisson_log_likelihood(x_bxt, lograte_bxt)) / B

    # L2
    l2reg = lfads_hps['l2reg']
    l2_loss = l2reg * optimizers.l2_norm(params)**2

    loss = -log_p_xgz + kl_loss_g0 + kl_loss_ii + l2_loss
    all_losses = {
        'total': loss,
        'nlog_p_xgz': -log_p_xgz,
        'kl_g0': kl_loss_g0,
        'kl_g0_prescale': kl_loss_g0_prescale,
        'kl_ii': kl_loss_ii,
        'kl_ii_prescale': kl_loss_ii_prescale,
        'l2': l2_loss
    }
    return all_losses
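dists.batch_kl_gauss_gauss is presumably a batched version of the closed-form KL between diagonal Gaussians. A sketch of the per-example computation, assuming the prior dict carries 'mean' and 'logvar' entries:

import jax.numpy as np
from jax import vmap

def kl_gauss_gauss(q_mean, q_logvar, p_params, varmin=1e-16):
    """Closed-form KL(q || p) for diagonal Gaussians, summed over dims
    (a sketch; the prior key names 'mean' and 'logvar' are assumptions)."""
    q_logvar = np.log(np.exp(q_logvar) + varmin)
    p_logvar = np.log(np.exp(p_params['logvar']) + varmin)
    return 0.5 * np.sum(p_logvar - q_logvar - 1.0
                        + np.exp(q_logvar - p_logvar)
                        + (q_mean - p_params['mean'])**2 / np.exp(p_logvar))

# Batched over the leading (batch) dim of the posterior statistics only.
batch_kl_gauss_gauss = vmap(kl_gauss_gauss, in_axes=(0, 0, None, None))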
Example 4
def optimize_lfads_core(key, batch_idx_start, num_batches, update_fun,
                        kl_warmup_fun, opt_state, lfads_hps, lfads_opt_hps,
                        train_data):
    """Make gradient updates to the LFADS model.

  Uses lax.fori_loop instead of a Python loop to reduce JAX overhead. This 
    loop will be jit'd and run on device.

  Arguments:
    key: random.PRNGKey for random bits
    batch_idx_start: where we are in the total number of batches
    num_batches: how many batches to run
    update_fun: the function that changes params based on grad of loss
    kl_warmup_fun: function to compute the kl warmup
    opt_state: the jax optimizer state, containing params and opt state
    lfads_hps: dict of lfads model HPs
    lfads_opt_hps: dict of optimization HPs
    train_data: nexamples x time x ndims np array of data for training

  Returns:
    opt_state: the jax optimizer state, containing params and optimizer state"""

    key, dkeyg = utils.keygen(key, num_batches)  # data
    key, fkeyg = utils.keygen(key, num_batches)  # forward pass

    # Begin optimization loop. Explicitly avoiding a python for-loop
    # so that jax will not trace it for the sake of a gradient we will not use.
    def run_update(batch_idx, opt_state):
        kl_warmup = kl_warmup_fun(batch_idx)
        didxs = random.randint(next(dkeyg), [lfads_hps['batch_size']], 0,
                               train_data.shape[0])
        x_bxt = train_data[didxs].astype(np.float32)
        opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
                               next(fkeyg), x_bxt, kl_warmup)
        return opt_state

    lower = batch_idx_start
    upper = batch_idx_start + num_batches
    return lax.fori_loop(lower, upper, run_update, opt_state)
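kl_warmup_fun is passed in from the caller; a common choice is a linear ramp of the KL scale over training batches. A sketch, with hypothetical hyperparameter names:

import jax.numpy as np

def get_kl_warmup_fun(lfads_opt_hps):
    """Linear KL warmup schedule (a sketch; the names 'kl_warmup_start',
    'kl_warmup_end', 'kl_min' and 'kl_max' are assumptions)."""
    warmup_start = lfads_opt_hps['kl_warmup_start']
    warmup_end = lfads_opt_hps['kl_warmup_end']
    kl_min = lfads_opt_hps['kl_min']
    kl_max = lfads_opt_hps['kl_max']

    def kl_warmup(batch_idx):
        progress = (batch_idx - warmup_start) / (warmup_end - warmup_start)
        return np.clip(kl_min + (kl_max - kl_min) * progress, kl_min, kl_max)

    return kl_warmup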
Example 5
def lfads_params(key, lfads_hps):
    """Instantiate random LFADS parameters.

  Arguments:
    key: random.PRNGKey for random bits
    lfads_hps: a dict of LFADS hyperparameters

  Returns:
    a dictionary of LFADS parameters
  """
    key, skeys = utils.keygen(key, 10)

    data_dim = lfads_hps['data_dim']
    ntimesteps = lfads_hps['ntimesteps']
    enc_dim = lfads_hps['enc_dim']
    con_dim = lfads_hps['con_dim']
    ii_dim = lfads_hps['ii_dim']
    gen_dim = lfads_hps['gen_dim']
    factors_dim = lfads_hps['factors_dim']

    ic_enc_params = {
        'fwd_rnn': gru_params(next(skeys), enc_dim, data_dim),
        'bwd_rnn': gru_params(next(skeys), enc_dim, data_dim)
    }
    gen_ic_params = affine_params(next(skeys), 2 * gen_dim,
                                  2 * enc_dim)  #m,v <- bi
    ic_prior_params = dists.diagonal_gaussian_params(next(skeys), gen_dim, 0.0,
                                                     lfads_hps['ic_prior_var'])
    con_params = gru_params(next(skeys), con_dim, 2 * enc_dim + factors_dim)
    con_out_params = affine_params(next(skeys), 2 * ii_dim, con_dim)  #m,v
    ii_prior_params = dists.ar1_params(next(skeys), ii_dim,
                                       lfads_hps['ar_mean'],
                                       lfads_hps['ar_autocorrelation_tau'],
                                       lfads_hps['ar_noise_variance'])
    gen_params = gru_params(next(skeys), gen_dim, ii_dim)
    factors_params = linear_params(next(skeys), factors_dim, gen_dim)
    lograte_params = affine_params(next(skeys), data_dim, factors_dim)

    return {
        'ic_enc': ic_enc_params,
        'gen_ic': gen_ic_params,
        'ic_prior': ic_prior_params,
        'con': con_params,
        'con_out': con_out_params,
        'ii_prior': ii_prior_params,
        'gen': gen_params,
        'factors': factors_params,
        'logrates': lograte_params
    }
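dists.diagonal_gaussian_params and dists.ar1_params build the prior parameter dicts. Sketches of what they might contain; the Gaussian prior keys are assumptions, while the AR(1) keys 'mean', 'logatau' and 'lognvar' follow how kl_gauss_ar1 reads them in Example 7:

import jax.numpy as np

def diagonal_gaussian_params(key, n, mean, var):
    """Diagonal Gaussian prior of dimension n (a sketch; key is unused here)."""
    return {'mean': mean * np.ones((n,)),
            'logvar': np.log(var) * np.ones((n,))}

def ar1_params(key, n, ar_mean, autocorrelation_tau, noise_variance):
    """Per-dimension AR(1) prior parameters in log space (a sketch)."""
    return {'mean': ar_mean * np.ones((n,)),
            'logatau': np.log(autocorrelation_tau) * np.ones((n,)),
            'lognvar': np.log(noise_variance) * np.ones((n,))}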
Example 6
def linear_params(key, o, u, ifactor=1.0):
    """Params for y = w x

  Arguments:
    key: random.PRNGKey for random bits
    o: output size
    u: input size
    ifactor: scaling factor

  Returns:
    a dictionary of parameters
  """
    key, skeys = utils.keygen(key, 1)
    ifactor = ifactor / np.sqrt(u)
    return {'w': random.normal(next(skeys), (o, u)) * ifactor}
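The companion affine_params (for y = w x + b) used throughout the other examples presumably follows the same pattern with an added bias. A sketch, together with the apply functions it implies:

import jax.numpy as np
from jax import random, vmap

def affine_params(key, o, u, ifactor=1.0):
    """Params for y = w x + b (a sketch mirroring linear_params above)."""
    key, skeys = utils.keygen(key, 1)
    ifactor = ifactor / np.sqrt(u)
    return {'w': random.normal(next(skeys), (o, u)) * ifactor,
            'b': np.zeros((o,))}

def affine(params, x):
    """Apply y = w x + b."""
    return np.dot(params['w'], x) + params['b']

# batch_affine, used in the decoder, is presumably a vmap over the leading
# (time) dimension of x.
batch_affine = vmap(affine, in_axes=(None, 0))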
Example 7
def kl_gauss_ar1(key, z_mean_t, z_logvar_t, ar1_params, varmin=1e-16):
    """KL using samples for multi-dim gaussian (thru time) and AR(1) process.
  To sample KL(q||p), we sample
        ln q - ln p
  by drawing samples from q and averaging. q is multidim gaussian, p
  is AR(1) process.

  Arguments:
    key: random.PRNGKey for random bits
    z_mean_t: np.array of means with leading dim being time
    z_logvar_t: np.array of log vars, leading dim is time
    ar1_params: dictionary of ar1 parameters, log noise var and autocorr tau
    varmin: minimal variance, useful for numerical stability

  Returns:
    sampled KL divergence between the posterior q and the AR(1) prior p
  """
    ll = diag_gaussian_log_likelihood
    sample = diag_gaussian_sample
    nkeys = z_mean_t.shape[0]
    key, skeys = utils.keygen(key, nkeys)

    # Convert AR(1) parameters.
    # z_t = c + phi z_{t-1} + eps, eps \in N(0, noise var)
    ar1_mean = ar1_params['mean']
    ar1_lognoisevar = np.log(np.exp(ar1_params['lognvar']) + varmin)
    phi = np.exp(-np.exp(-ar1_params['logatau']))
    # The process variance is a function of the noise variance, so varmin is
    # added above. This affects the log-likelihood functions below, also.
    logprocessvar = ar1_lognoisevar - (np.log(1 - phi) + np.log(1 + phi))

    # Sample first AR(1) step according to process variance.
    z0 = sample(next(skeys), z_mean_t[0], z_logvar_t[0], varmin)
    logq = ll(z0, z_mean_t[0], z_logvar_t[0], varmin)
    logp = ll(z0, ar1_mean, logprocessvar, 0.0)
    z_last = z0

    # Sample the remaining time steps with adjusted mean and noise variance.
    for z_mean, z_logvar in zip(z_mean_t[1:], z_logvar_t[1:]):
        z = sample(next(skeys), z_mean, z_logvar, varmin)
        logq += ll(z, z_mean, z_logvar, varmin)
        logp += ll(z, ar1_mean + phi * z_last, ar1_lognoisevar, 0.0)
        z_last = z

    kl = logq - logp
    return kl
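diag_gaussian_sample and diag_gaussian_log_likelihood are used above but defined elsewhere. Sketches of what they compute, applying the same varmin floor to the variance:

import jax.numpy as np
from jax import random

def diag_gaussian_sample(key, mean, logvar, varmin=1e-16):
    """Sample from a diagonal Gaussian with a variance floor (a sketch)."""
    var = np.exp(logvar) + varmin
    return mean + np.sqrt(var) * random.normal(key, mean.shape)

def diag_gaussian_log_likelihood(z, mean, logvar, varmin=1e-16):
    """Log density of z under a diagonal Gaussian, summed over dims (a sketch)."""
    var = np.exp(logvar) + varmin
    return np.sum(-0.5 * (np.log(2.0 * np.pi) + np.log(var)
                          + (z - mean)**2 / var))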
Example 8
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
    """Run the LFADS network from latent variables to log rates.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    ic_mean: np array of generator initial condition mean
    ic_logvar: np array of generator initial condition log variance
    xenc_t: np array bidirectional encoding of input (x_t) with leading dim
      being time
    keep_rate: dropout keep rate

  Returns:
    7-tuple of np arrays, all with leading dim being time:
      controller hidden states, inferred inputs (sample, mean and log var),
      generator hidden states, factors and log rates
  """

    ntime = lfads_hps['ntimesteps']
    key, skeys = utils.keygen(key, 2)

    # Since the factors feed back to the controller,
    #    factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
    # is really one big loop and therefore one RNN.
    c0 = params['con']['h0']
    g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                    lfads_hps['var_min'])
    f0 = np.zeros((lfads_hps['factors_dim'], ))

    # Make all the randomness for all T steps at once; it's more efficient.
    # The random keys get passed into scan along with the input, so the input
    # becomes a 2-tuple (keys, actual input).
    T = xenc_t.shape[0]
    keys_t = random.split(next(skeys), T)

    state0 = (c0, g0, f0)
    decoder = partial(lfads_decode_one_step_scan,
                      *(params, lfads_hps, keep_rate))
    _, state_and_returns_t = lax.scan(decoder, state0, (keys_t, xenc_t))
    return state_and_returns_t
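lfads_decode_one_step_scan has to match the (carry, input) -> (carry, outputs) contract of lax.scan. A sketch, assuming a per-step helper like the one sketched after the explicit-loop version of lfads_decode in Example 10:

def lfads_decode_one_step_scan(params, lfads_hps, keep_rate, state, key_n_xenc):
    """Scan-compatible wrapper (a sketch): unpack the (key, input) 2-tuple,
    run one decode step, carry (controller, generator, factors) forward and
    emit everything for this time step."""
    key, xenc = key_n_xenc
    returns = lfads_decode_one_step(params, lfads_hps, keep_rate, key, state,
                                    xenc)
    c, ii_mean, ii_logvar, ii, g, f, lograte = returns
    next_state = (c, g, f)
    return next_state, returns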
Example 9
def lfads(params, lfads_hps, key, x_t, keep_rate):
    """Run the LFADS network from input to output.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    x_t: np array of input with leading dim being time
    keep_rate: dropout keep rate

  Returns:
    A dictionary of np arrays of all LFADS values of interest.
  """

    key, skeys = utils.keygen(key, 2)

    ic_mean, ic_logvar, xenc_t = \
        lfads_encode(params, lfads_hps, next(skeys), x_t, keep_rate)

    c_t, gen_t, factor_t, ii_t, ii_mean_t, ii_logvar_t, lograte_t = \
        lfads_decode(params, lfads_hps, next(skeys), ic_mean, ic_logvar,
                     xenc_t, keep_rate)

    # As this is tutorial code, we're passing everything around.
    return {
        'xenc_t': xenc_t,
        'ic_mean': ic_mean,
        'ic_logvar': ic_logvar,
        'ii_t': ii_t,
        'c_t': c_t,
        'ii_mean_t': ii_mean_t,
        'ii_logvar_t': ii_logvar_t,
        'gen_t': gen_t,
        'factor_t': factor_t,
        'lograte_t': lograte_t
    }
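batch_lfads, used by lfads_losses in Example 3, is presumably this function mapped over the batch with vmap. A sketch of the assumed batching, with axis choices following how it is called with keys_b and x_bxt:

from jax import vmap

# params, lfads_hps and keep_rate are shared; keys and data are batched.
batch_lfads = vmap(lfads, in_axes=(None, None, 0, 0, None))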
Example 10
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
    """Run the LFADS network from latent variables to log rates.

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
    ic_mean: np array of generator initial condition mean
    ic_logvar: np array of generator initial condition log variance
    xenc_t: np array bidirectional encoding of input (x_t) with leading dim
      being time
    keep_rate: dropout keep rate

  Returns:
    7-tuple of np arrays, all with leading dim being time:
      controller hidden state, inferred input mean, inferred input log var,
      inferred input sample, generator hidden state, factors and log rates
  """

    ntime = lfads_hps['ntimesteps']
    key, skeys = utils.keygen(key, 1 + 2 * ntime)

    # Since the factors feed back to the controller,
    #    factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
    # is really one big loop and therefore one RNN.
    c = c0 = params['con']['h0']
    g = g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                        lfads_hps['var_min'])
    f = f0 = np.zeros((lfads_hps['factors_dim'], ))
    c_t = []
    ii_mean_t = []
    ii_logvar_t = []
    ii_t = []
    gen_t = []
    factor_t = []
    for xenc in xenc_t:
        cin = np.concatenate([xenc, f], axis=0)
        c = gru(params['con'], c, cin)
        cout = affine(params['con_out'], c)
        ii_mean, ii_logvar = np.split(cout, 2, axis=0)  # inferred input params
        ii = dists.diag_gaussian_sample(next(skeys), ii_mean, ii_logvar,
                                        lfads_hps['var_min'])
        g = gru(params['gen'], g, ii)
        g = dropout(g, next(skeys), keep_rate)
        f = normed_linear(params['factors'], g)
        # Save everything.
        c_t.append(c)
        ii_t.append(ii)
        gen_t.append(g)
        ii_mean_t.append(ii_mean)
        ii_logvar_t.append(ii_logvar)
        factor_t.append(f)

    c_t = np.array(c_t)
    ii_t = np.array(ii_t)
    gen_t = np.array(gen_t)
    ii_mean_t = np.array(ii_mean_t)
    ii_logvar_t = np.array(ii_logvar_t)
    factor_t = np.array(factor_t)
    lograte_t = batch_affine(params['logrates'], factor_t)

    return c_t, ii_mean_t, ii_logvar_t, ii_t, gen_t, factor_t, lograte_t
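Factoring the loop body above into a per-step function is what the scan-based lfads_decode in Example 8 consumes. A sketch; computing the log rate inside the step is an assumption, since this version applies batch_affine once after the loop:

def lfads_decode_one_step(params, lfads_hps, keep_rate, key, state, xenc):
    """One decode step (a sketch factoring out the loop body above):
    controller -> inferred input sample -> generator -> factors -> log rates."""
    key, skeys = utils.keygen(key, 2)
    c, g, f = state
    cin = np.concatenate([xenc, f], axis=0)
    c = gru(params['con'], c, cin)
    cout = affine(params['con_out'], c)
    ii_mean, ii_logvar = np.split(cout, 2, axis=0)
    ii = dists.diag_gaussian_sample(next(skeys), ii_mean, ii_logvar,
                                    lfads_hps['var_min'])
    g = gru(params['gen'], g, ii)
    g = dropout(g, next(skeys), keep_rate)
    f = normed_linear(params['factors'], g)
    lograte = affine(params['logrates'], f)
    return c, ii_mean, ii_logvar, ii, g, f, lograte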