Exemplo n.º 1
0
def lfads_decode_one_step(params, lfads_hps, key, keep_rate, c, f, g, xenc):
    """Advance the LFADS decoder by a single time step.

    One step chains controller -> inferred input sample -> generator ->
    factors -> log rates.

    Arguments:
      params: a dictionary of LFADS parameters; this step reads the 'con',
        'con_out', 'gen', 'factors' and 'logrates' entries
      lfads_hps: a dictionary of LFADS hyperparameters; only 'var_min' is
        read here
      key: random.PRNGKey for random bits; split in two, one key for the
        inferred input sample and one for dropout
      keep_rate: dropout keep rate
      c: controller state at time step t-1
      f: factors at time step t-1
      g: generator state at time step t-1
      xenc: np array bidirectional encoding at time t of input (x_t);
        it is concatenated with f along axis 0, so both are presumably
        unbatched 1-D vectors -- TODO confirm

    Returns:
      7-tuple for this single time step, in order: controller hidden state,
        generator hidden state (after dropout), factors, inferred input (ii)
        sample, ii mean, ii log var, log rates
    """
    subkeys = random.split(key, 2)
    sample_key = subkeys[0]
    dropout_key = subkeys[1]

    # Controller: driven by the current encoding and the previous factors.
    con_input = np.concatenate([xenc, f], axis=0)
    con_state = gru(params['con'], c, con_input)

    # The controller readout is halved into the mean and log variance of the
    # inferred input distribution, from which one sample is drawn.
    ii_params = affine(params['con_out'], con_state)
    ii_mean, ii_logvar = np.split(ii_params, 2, axis=0)
    ii = dists.diag_gaussian_sample(sample_key, ii_mean, ii_logvar,
                                    lfads_hps['var_min'])

    # Generator: driven by the sampled inferred input; dropout is applied to
    # the new state before it is read out and returned.
    gen_state = dropout(gru(params['gen'], g, ii), dropout_key, keep_rate)

    factors = normed_linear(params['factors'], gen_state)
    lograte = affine(params['logrates'], factors)
    return con_state, gen_state, factors, ii, ii_mean, ii_logvar, lograte
Exemplo n.º 2
0
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
    """Run the LFADS network from latent variables to log rates.

    The whole decode is expressed as a single lax.scan over time. The number
    of steps is taken from xenc_t.shape[0]; lfads_hps['ntimesteps'] is not
    consulted here.

    Arguments:
      params: a dictionary of LFADS parameters
      lfads_hps: a dictionary of LFADS hyperparameters; 'var_min' and
        'factors_dim' are read here
      key: random.PRNGKey for random bits
      ic_mean: np array of generator initial condition mean
      ic_logvar: np array of generator initial condition log variance
      xenc_t: np array bidirectional encoding of input (x_t) with leading dim
        being time
      keep_rate: dropout keep rate

    Returns:
      The per-step outputs of lfads_decode_one_step_scan as stacked by
        lax.scan, i.e. with leading dim being time.
        NOTE(review): presumably the 7-tuple (controller hidden state,
        generator hidden state, factors, inferred input sample, inferred
        input mean, inferred input log var, log rates) -- confirm against
        lfads_decode_one_step_scan.
    """
    # Two keys are needed up front: one for the initial-condition sample and
    # one that is split into the per-step keys below.
    _, skeys = utils.keygen(key, 2)

    # Since the factors feed back to the controller,
    #    factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
    # is really one big loop and therefore one RNN.
    c0 = params['con']['h0']
    g0 = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                    lfads_hps['var_min'])
    f0 = np.zeros((lfads_hps['factors_dim'], ))

    # Make all the randomness for all T steps at once, it's more efficient.
    # The random keys get passed into scan along with the input, so the input
    # becomes a 2-tuple (keys, actual input).
    num_steps = xenc_t.shape[0]
    keys_t = random.split(next(skeys), num_steps)

    state0 = (c0, g0, f0)
    decoder = partial(lfads_decode_one_step_scan, params, lfads_hps,
                      keep_rate)
    # The final carry is discarded; only the stacked per-step outputs are
    # returned.
    _, state_and_returns_t = lax.scan(decoder, state0, (keys_t, xenc_t))
    return state_and_returns_t
Exemplo n.º 3
0
def lfads_decode(params, lfads_hps, key, ic_mean, ic_logvar, xenc_t,
                 keep_rate):
    """Run the LFADS network from latent variables to log rates.

    This variant unrolls the decoder with an explicit Python loop over the
    leading (time) axis of xenc_t.

    Arguments:
      params: a dictionary of LFADS parameters
      lfads_hps: a dictionary of LFADS hyperparameters; 'ntimesteps',
        'var_min' and 'factors_dim' are read here
      key: random.PRNGKey for random bits
      ic_mean: np array of generator initial condition mean
      ic_logvar: np array of generator initial condition log variance
      xenc_t: np array bidirectional encoding of input (x_t) with leading dim
        being time
      keep_rate: dropout keep rate

    Returns:
      7-tuple of np arrays, in order: controller hidden state, inferred
        input mean, inferred input log var, inferred input sample, generator
        hidden state (after dropout), factors, log rates. The first six are
        stacked from per-step values, so their leading dim is time; log rates
        come from batch_affine applied to the stacked factors.
    """
    ntime = lfads_hps['ntimesteps']
    # Key budget: one for the initial-condition sample, then two per time
    # step (inferred input sample + dropout).
    # NOTE(review): the budget is sized from 'ntimesteps' while the loop runs
    # over xenc_t; this assumes xenc_t has no more than ntime rows, otherwise
    # skeys would presumably run out -- confirm against utils.keygen.
    _, skeys = utils.keygen(key, 1 + 2 * ntime)

    # Since the factors feed back to the controller,
    #    factors_{t-1} -> controller_t -> sample_t -> generator_t -> factors_t
    # is really one big loop and therefore one RNN.
    con_state = params['con']['h0']
    gen_state = dists.diag_gaussian_sample(next(skeys), ic_mean, ic_logvar,
                                           lfads_hps['var_min'])
    factors = np.zeros((lfads_hps['factors_dim'], ))

    # One tuple per time step, laid out in the order the arrays are returned.
    steps = []
    for xenc in xenc_t:
        con_input = np.concatenate([xenc, factors], axis=0)
        con_state = gru(params['con'], con_state, con_input)
        # The controller readout is halved into the inferred input params.
        con_readout = affine(params['con_out'], con_state)
        ii_mean, ii_logvar = np.split(con_readout, 2, axis=0)
        ii = dists.diag_gaussian_sample(next(skeys), ii_mean, ii_logvar,
                                        lfads_hps['var_min'])
        gen_state = dropout(gru(params['gen'], gen_state, ii), next(skeys),
                            keep_rate)
        factors = normed_linear(params['factors'], gen_state)
        steps.append((con_state, ii_mean, ii_logvar, ii, gen_state, factors))

    def stacked(field):
        # Gather one field across all steps into a single array.
        return np.array([step[field] for step in steps])

    c_t, ii_mean_t, ii_logvar_t, ii_t, gen_t, factor_t = (
        stacked(field) for field in range(6))
    lograte_t = batch_affine(params['logrates'], factor_t)

    return c_t, ii_mean_t, ii_logvar_t, ii_t, gen_t, factor_t, lograte_t