コード例 #1
0
def gru_params(key, **rnn_hps):
    """Randomly initialize GRU parameters together with a linear readout.

    Arguments:
      key: random.PRNGKey supplying the random bits (split 6 ways via
        utils.keygen).
      **rnn_hps: hyperparameters, looked up by name:
        u: input size
        n: hidden state size
        o: readout (output) size
        i_factor: scaling factor for input weights (further divided by
          sqrt(u))
        h_factor: scaling factor for hidden -> hidden weights (further
          divided by sqrt(n))
        h_scale: scale on the h0 initial condition

    Returns:
      a dictionary of parameters:
        h0: (n,) initial hidden state, normal * h_scale
        wRUHX: (2n, n + u) weights for the 'RU' rows (presumably the
          reset/update gates), columns laid out as [hidden | input]
        wCHX: (n, n + u) weights for the 'C' rows, columns [hidden | input]
        bRU: (2n,) zeros
        bC: (n,) zeros
        wO: (o, n) readout weights, normal / sqrt(n)
        bO: (o,) zeros
    """
    _, subkeys = utils.keygen(key, 6)
    n_in = rnn_hps['u']
    n_hid = rnn_hps['n']
    n_out = rnn_hps['o']
    in_scale = rnn_hps['i_factor'] / np.sqrt(n_in)
    hid_scale = rnn_hps['h_factor'] / np.sqrt(n_hid)
    h0_scale = rnn_hps['h_scale']

    def draw(shape, scale):
        # Each call consumes one subkey, so the call order below determines
        # which subkey initializes which parameter.
        return random.normal(next(subkeys), shape) * scale

    # Draw order: RU-hidden, RU-input, C-hidden, C-input, readout, h0.
    w_ru = np.concatenate([draw((2 * n_hid, n_hid), hid_scale),
                           draw((2 * n_hid, n_in), in_scale)], axis=1)
    w_c = np.concatenate([draw((n_hid, n_hid), hid_scale),
                          draw((n_hid, n_in), in_scale)], axis=1)

    # The linear readout is stored with the GRU parameters for convenience,
    # even though it is not part of the GRU cell itself.
    w_out = draw((n_out, n_hid), 1.0 / np.sqrt(n_hid))
    h0 = draw((n_hid,), h0_scale)

    return dict(
        h0=h0,
        wRUHX=w_ru,
        wCHX=w_c,
        bRU=np.zeros((2 * n_hid,)),
        bC=np.zeros((n_hid,)),
        wO=w_out,
        bO=np.zeros((n_out,)),
    )
コード例 #2
0
def build_input_and_target_binary_decision(input_params, key):
    """Build white noise input and decision targets.

    The decision is whether the white noise input has a perfect integral
    greater than, or less than, 0. Output a +1 or -1, respectively (an
    integral of exactly 0 yields -1).

    Arguments:
      input_params: 4-tuple (bias_val, stddev_val, T, ntime):
        bias_val: per-trial bias is drawn uniformly from [-bias_val, bias_val)
        stddev_val: noise scale; per-step noise std is stddev_val / sqrt(dt)
        T: total trial duration, with dt = T / ntime
        ntime: number of time steps
      key: jax random key for making randomness

    Returns:
      3-tuple of inputs (ntime x 1), targets (ntime x 1), and the target
        mask: an array of the time indices that have optimization pressure
        on them (only the last step, ntime - 1). Targets are 0 everywhere
        except the last step, which holds the +1/-1 decision.
    """
    bias_val, stddev_val, T, ntime = input_params
    dt = T / ntime

    # Create the white noise input.
    _, skeys = utils.keygen(key, 2)
    # 2 * (u - 0.5) maps u ~ U[0, 1) onto U[-1, 1), so the bias is symmetric
    # around 0. The sample must be uniform: with a standard-normal sample the
    # -0.5 shift would make the mean bias -bias_val and skew decisions to -1.
    random_sample = random.uniform(next(skeys), (1, ))[0]
    bias = bias_val * 2.0 * (random_sample - 0.5)
    stddev = stddev_val / np.sqrt(dt)
    noise_t = stddev * random.normal(next(skeys), (ntime, ))
    white_noise_t = bias + noise_t

    # * dt, intentionally left off to get output scaling in O(1).
    pure_integration_t = np.cumsum(white_noise_t)
    # +1 when the final integral is strictly positive, -1 otherwise.
    decision = np.where(pure_integration_t[-1] > 0.0, 1.0, -1.0)
    # Zero targets everywhere except the final time step.
    targets_t = np.concatenate(
        [np.zeros(pure_integration_t.shape[0] - 1),
         np.array([decision], dtype=float)],
        axis=0)

    inputs_tx1 = np.expand_dims(white_noise_t, axis=1)
    targets_tx1 = np.expand_dims(targets_t, axis=1)
    target_mask = np.array([ntime - 1])  # When target is defined.
    return inputs_tx1, targets_tx1, target_mask