Example #1
    def run_epoch(rng, _opt_state, epoch_idx):
        _rng, dat_keys = utils.keygen(rng, 1)
        _rng, batch_keys = utils.keygen(_rng, num_batches)

        # Randomize epoch data.
        epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)

        def update(batch_idx, __opt_state):
            """Update func for gradients, includes gradient clipping."""
            kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)

            batch_data = lax.dynamic_slice_in_dim(epoch_data,
                                                  batch_idx * BATCH_SIZE,
                                                  BATCH_SIZE,
                                                  axis=0)
            batch_data = batch_data.astype(np.float32)

            params = get_params(__opt_state)
            grads = grad(loss_fn)(params, batch_data, next(batch_keys),
                                  BATCH_SIZE, ic_prior, VAR_MIN, kl_warmup,
                                  L2_REG)
            clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)

            return opt_update(batch_idx, clipped_grads, __opt_state)

        return lax.fori_loop(0, num_batches, update, _opt_state)
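
All of these examples rely on utils.keygen to turn one PRNG key into many. A minimal sketch of what it is assumed to do (return a fresh key plus a generator of subkeys); the actual helper in the source repo may differ in details:

from jax import random

def keygen(key, nkeys):
    """Split `key` into a new key and a generator yielding `nkeys` subkeys."""
    keys = random.split(key, nkeys + 1)
    return keys[0], (k for k in keys[1:])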
Example #2
def gru_params(rng, n, u, ifactor=1.0, hfactor=1.0, hscale=0.0):
    """
    Helper function for GRU parameter initialization.
    Used twice in the BidirectionalGRU (encoder) and once in the FreeEvolveGRU (decoder).
    :param rng: random.PRNGKey used for weight initialization
    :param n: hidden state size
    :param u: input size
    :param ifactor: scaling factor for input weights
    :param hfactor: scaling factor for hidden -> hidden weights
    :param hscale: scale on h0 initial condition
    :return: dict of GRU parameters: 'h0', 'wRUHX', 'wCHX', 'bRU', 'bC'
    """
    rng, keys = utils.keygen(rng, 5)

    ifac = ifactor / np.sqrt(u)
    hfac = hfactor / np.sqrt(n)

    wRUH = random.normal(next(keys), (n + n, n)) * hfac
    wRUX = random.normal(next(keys), (n + n, u)) * ifac
    wRUHX = np.concatenate([wRUH, wRUX], axis=1)

    wCH = random.normal(next(keys), (n, n)) * hfac
    wCX = random.normal(next(keys), (n, u)) * ifac
    wCHX = np.concatenate([wCH, wCX], axis=1)

    return {
        'h0': random.normal(next(keys), (n, )) * hscale,
        'wRUHX': wRUHX,
        'wCHX': wCHX,
        'bRU': np.zeros((n + n, )),
        'bC': np.zeros((n, ))
    }
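
A hedged usage sketch with hypothetical sizes, showing the shapes gru_params produces (it assumes utils.keygen behaves as sketched above):

from jax import random

key = random.PRNGKey(0)
p = gru_params(key, n=64, u=20)
assert p['wRUHX'].shape == (128, 84)  # reset/update weights act on concat([h, x])
assert p['wCHX'].shape == (64, 84)    # candidate weights act on concat([r * h, x])
assert p['h0'].shape == (64,)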
Example #3
 def init_fun(rng, input_shape):
     u = input_shape[-1]
     key, keys = utils.keygen(rng, 1)
     ifac = ifactor / np.sqrt(u)
     params = {'w': random.normal(next(keys), (output_size, u)) * ifac}
     output_shape = input_shape[:-1] + (output_size, )
     return output_shape, params
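
Example #3 shows only the init half of a stax-style (init_fun, apply_fun) pair. A hypothetical apply counterpart consistent with the (output_size, u) weight shape could be:

import jax.numpy as np

def apply_fun(params, inputs, **kwargs):
    # Map the trailing feature dimension u to output_size.
    return np.dot(inputs, params['w'].T)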
Example #4
def run_gru(params, x_t, h0=None, keep_rate=1.0, rng=None):
    """
    Run a GRU module forward in time.

    Arguments:
    params: dictionary of parameters for gru (keys: 'wRUHX', 'bRU', 'wCHX', 'bC') and optionally 'h0'
    x_t: np array data for RNN input with leading dim being time
    h0: initial condition for running rnn, which overwrites param h0
    keep_rate: dropout keep probability applied to the hidden state
    rng: random.PRNGKey for the dropout masks (required)

    Returns:
    np array of hidden states, with leading dim being time
    """
    if rng is None:
        raise ValueError("GRU dropout requires rng key.")
    rng, keys = utils.keygen(rng, len(x_t))
    h = h0 if h0 is not None else params['h0']
    h_t = []
    for x in x_t:
        h = gru(params, h, x)
        # Do dropout on hidden state
        # TODO: Only do dropout during training.
        keep = random.bernoulli(next(keys), keep_rate, h.shape)
        h = np.where(keep, h / keep_rate, 0)
        h_t.append(h)

    return np.array(h_t)
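
run_gru calls a gru() cell that is not shown in these examples. A standard GRU sketch consistent with the parameter layout from gru_params (gates from concat([h, x]) via wRUHX, candidate state from concat([r * h, x]) via wCHX) might look like this:

import jax.numpy as np
from jax.nn import sigmoid

def gru(params, h, x):
    hx = np.concatenate([h, x], axis=0)
    ru = sigmoid(np.dot(params['wRUHX'], hx) + params['bRU'])
    r, u = np.split(ru, 2, axis=0)
    rhx = np.concatenate([r * h, x], axis=0)
    c = np.tanh(np.dot(params['wCHX'], rhx) + params['bC'])
    return u * h + (1.0 - u) * c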
Example #5
 def apply_fun(params, inputs, rng=None):
     if rng is None:
         raise ValueError("SampleDistrib apply_fun requires rng key.")
     rng, keys = utils.keygen(rng, 1)
     _mean, _logvar = np.split(inputs, 2, axis=0)
     samples = dists.diag_gaussian_sample(next(keys), _mean, _logvar,
                                          var_min)
     return samples
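
dists.diag_gaussian_sample is imported from elsewhere; an assumed reparameterization-trick sketch matching the call signature used above (key, mean, log variance, variance floor):

import jax.numpy as np
from jax import random

def diag_gaussian_sample(key, mean, logvar, varmin=1e-16):
    # z = mean + sigma * eps, with a variance floor for numerical stability.
    sigma = np.sqrt(np.exp(logvar) + varmin)
    return mean + sigma * random.normal(key, mean.shape)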
Example #6
 def init_fun(rng, input_shape):
     output_shape = (evolve_steps, n_hidden)
     rng, keys = utils.keygen(rng, 1)
     gen_params = gru_params(next(keys), n_hidden, 1)
     # Modify params so x weights are all 0. Not necessary because input is always 0.
     # gen_params['wRUHX'][:, -1] = 0
     # gen_params['wCHX'][:, -1] = 0
     return output_shape, gen_params
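
The apply half of this FreeEvolveGRU layer is not shown. A hypothetical sketch consistent with gru_params(..., u=1) and the run_gru helper above, evolving the state from an initial condition under a constant zero input:

def apply_fun(params, h0, rng=None):
    # Drive the GRU with zeros for evolve_steps so the dynamics come from h0 alone.
    x_t = np.zeros((evolve_steps, 1))
    return run_gru(params, x_t, h0=h0, rng=rng)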
Example #7
 def apply_fun(params, x_t, rng=None):
     if rng is None:
         raise ValueError("BidirectionalGRU apply_fun requires rng key.")
     rng, keys = utils.keygen(rng, 2)
     fwd_enc_t = run_gru(params['fwd_rnn'], x_t, rng=next(keys))
     bwd_enc_t = np.flipud(
         run_gru(params['bwd_rnn'], np.flipud(x_t), rng=next(keys)))
     # Hidden states here are 1-D (n_hidden,), so concatenate along axis 0.
     enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=0)
     return enc_ends
Example #8
 def init_fun(rng, input_shape):
     u = input_shape[-1]
     output_shape = input_shape[:-2] + (2 * n_hidden, )
     rng, keys = utils.keygen(rng, 2)
     ic_enc_params = {
         'fwd_rnn': gru_params(next(keys), n_hidden, u),
         'bwd_rnn': gru_params(next(keys), n_hidden, u)
     }
     return output_shape, ic_enc_params
Example #9
def lfads_onestep(params, rng, data):
    rng, keys = utils.keygen(rng, 2)
    enc_params, dec_params = params
    encode, decode = encdec
    latent_vars = encode(enc_params, data, rng=next(keys))
    neuron_log_rates = decode(dec_params, latent_vars, rng=next(keys))
    ic_post_mean, ic_post_logvar = np.split(latent_vars, 2, axis=0)
    return {
        'ic_post_mean': ic_post_mean,
        'ic_post_logvar': ic_post_logvar,
        'neuron_log_rates': neuron_log_rates
    }
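
lfads_onestep handles a single trial. A hedged sketch of batching it over trials with jax.vmap (assumed layout: data_bxtxn carries a leading batch dimension and each trial gets its own key):

from jax import vmap, random

def lfads_batch(params, rng, data_bxtxn):
    keys = random.split(rng, data_bxtxn.shape[0])
    return vmap(lfads_onestep, in_axes=(None, 0, 0))(params, keys, data_bxtxn)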
Example #10
def kl_gauss_ar1(key, z_mean_t, z_logvar_t, ar1_params, varmin=1e-16):
    """KL using samples for multi-dim gaussian (thru time) and AR(1) process.
    To sample KL(q||p), we sample
          ln q - ln p
    by drawing samples from q and averaging. q is multidim gaussian, p
    is AR(1) process.

    Arguments:
      key: random.PRNGKey for random bits
      z_mean_t: np.array of means with leading dim being time
      z_logvar_t: np.array of log vars, leading dim is time
      ar1_params: dictionary of AR(1) parameters: mean, log noise variance, and log autocorrelation tau
      varmin: minimal variance, useful for numerical stability

    Returns:
      sampled KL divergence between q (the posterior through time) and p (the AR(1) prior)
    """
    ll = diag_gaussian_log_likelihood
    sample = diag_gaussian_sample
    nkeys = z_mean_t.shape[0]
    key, skeys = utils.keygen(key, nkeys)

    # Convert AR(1) parameters.
    # z_t = c + phi z_{t-1} + eps, eps \in N(0, noise var)
    ar1_mean = ar1_params['mean']
    ar1_lognoisevar = np.log(np.exp(ar1_params['lognvar']) + varmin)
    phi = np.exp(-np.exp(-ar1_params['logatau']))
    # The process variance is a function of the noise variance, so varmin is added above.
    # This also affects the log-likelihood functions below.
    logprocessvar = ar1_lognoisevar - (np.log(1-phi) + np.log(1+phi))

    # Sample first AR(1) step according to process variance.
    z0 = sample(next(skeys), z_mean_t[0], z_logvar_t[0], varmin)
    logq = ll(z0, z_mean_t[0], z_logvar_t[0], varmin)
    logp = ll(z0, ar1_mean, logprocessvar, 0.0)
    z_last = z0

    # Sample the remaining time steps with adjusted mean and noise variance.
    for z_mean, z_logvar in zip(z_mean_t[1:], z_logvar_t[1:]):
        z = sample(next(skeys), z_mean, z_logvar, varmin)
        logq += ll(z, z_mean, z_logvar, varmin)
        logp += ll(z, ar1_mean + phi * z_last, ar1_lognoisevar, 0.0)
        z_last = z

    kl = logq - logp
    return kl
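
A small sanity-check sketch (hypothetical values) of the AR(1) identities used above: phi = exp(-1/tau) and stationary process variance = noise variance / (1 - phi**2):

import jax.numpy as np

logatau = np.log(10.0)              # autocorrelation time tau = 10 steps
phi = np.exp(-np.exp(-logatau))     # phi = exp(-1/tau), about 0.905
lognoisevar = np.log(0.1)           # noise variance 0.1
logprocessvar = lognoisevar - (np.log(1 - phi) + np.log(1 + phi))
assert np.allclose(np.exp(logprocessvar), 0.1 / (1 - phi ** 2))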
Example #11
    num_batches = int(n_trials * EPOCHS /
                      BATCH_SIZE)  # total number of training batches

    # Get the model #
    encoder_init, encode = LFADSEncoderModel(P_DROPOUT, ENC_DIM, IC_DIM)
    decoder_init, decode = LFADSDecoderModel(VAR_MIN, P_DROPOUT, IC_DIM,
                                             n_timesteps, FACTORS_DIM,
                                             n_neurons)
    encdec = encode, decode

    # Init the model
    ic_prior = {
        'mean': 0.0 * np.ones((IC_DIM, )),
        'logvar': np.log(IC_PRIOR_VAR) * np.ones((IC_DIM, ))
    }
    rng, keys = utils.keygen(rng, 2)
    latent_shape, init_encoder_params = encoder_init(next(keys),
                                                     (n_timesteps, n_neurons))
    decoded_shape, init_decoder_params = decoder_init(next(keys), latent_shape)
    init_params = init_encoder_params, init_decoder_params

    # Optimizer #


    def kl_warmup_fun(batch_idx):
        progress_frac = ((batch_idx - kl_warmup_start) /
                         (kl_warmup_end - kl_warmup_start))
        _warmup = np.where(batch_idx < kl_warmup_start, kl_min,
                           (kl_max - kl_min) * progress_frac + kl_min)
        return np.where(batch_idx > kl_warmup_end, kl_max, _warmup)
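
kl_warmup_fun ramps the KL weight linearly from kl_min to kl_max between kl_warmup_start and kl_warmup_end (measured in global batch steps), using np.where so it stays jit-compatible. A standalone restatement with hypothetical constants, since the closure variables are not defined in this snippet:

import jax.numpy as np

def kl_warmup_example(batch_idx, start=50.0, end=250.0, kl_min=0.01, kl_max=1.0):
    frac = (batch_idx - start) / (end - start)
    warmup = np.where(batch_idx < start, kl_min, (kl_max - kl_min) * frac + kl_min)
    return np.where(batch_idx > end, kl_max, warmup)

print(kl_warmup_example(0.0), kl_warmup_example(150.0), kl_warmup_example(400.0))
# -> 0.01, 0.505, 1.0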