def gru_params(key, **rnn_hps):
    """Create a randomly initialized GRU parameter dictionary (readout included).

    Arguments:
      key: random.PRNGKey for random bits
      **rnn_hps: hyperparameters, looked up by name:
        u: input size
        n: hidden state size
        o: output (readout) size
        i_factor: scaling factor for input weights (divided by sqrt(u))
        h_factor: scaling factor for hidden -> hidden weights
          (divided by sqrt(n))
        h_scale: scale on h0 initial condition

    Returns:
      a dictionary of parameters with entries
        'h0':    (n,)        initial hidden state, normal draw times h_scale
        'wRUHX': (2n, n + u) hidden and input weights joined along axis 1
        'wCHX':  (n, n + u)  hidden and input weights joined along axis 1
        'bRU':   (2n,)       zeros
        'bC':    (n,)        zeros
        'wO':    (o, n)      readout weights, normal draw scaled by 1/sqrt(n)
        'bO':    (o,)        zeros
    """
    _, subkeys = utils.keygen(key, 6)

    input_dim = rnn_hps['u']
    hidden_dim = rnn_hps['n']
    output_dim = rnn_hps['o']

    input_scale = rnn_hps['i_factor'] / np.sqrt(input_dim)
    hidden_scale = rnn_hps['h_factor'] / np.sqrt(hidden_dim)
    h0_scale = rnn_hps['h_scale']

    def scaled_normal(shape, scale):
        # Every call consumes the next subkey, so the order of the calls
        # below decides which key feeds which parameter.
        return random.normal(next(subkeys), shape) * scale

    gate_rows = hidden_dim + hidden_dim

    w_ru_hidden = scaled_normal((gate_rows, hidden_dim), hidden_scale)
    w_ru_input = scaled_normal((gate_rows, input_dim), input_scale)
    w_c_hidden = scaled_normal((hidden_dim, hidden_dim), hidden_scale)
    w_c_input = scaled_normal((hidden_dim, input_dim), input_scale)

    # The readout parameters live alongside the GRU weights for convenience,
    # even though they are not strictly part of the GRU.
    w_out = scaled_normal((output_dim, hidden_dim), 1.0 / np.sqrt(hidden_dim))

    # The initial state takes the sixth and last subkey.
    h0 = scaled_normal((hidden_dim, ), h0_scale)

    params = {}
    params['h0'] = h0
    params['wRUHX'] = np.concatenate([w_ru_hidden, w_ru_input], axis=1)
    params['wCHX'] = np.concatenate([w_c_hidden, w_c_input], axis=1)
    params['bRU'] = np.zeros((gate_rows, ))
    params['bC'] = np.zeros((hidden_dim, ))
    params['wO'] = w_out
    params['bO'] = np.zeros((output_dim, ))
    return params
def build_input_and_target_binary_decision(input_params, key):
    """Generate one white noise trial and its binary decision target.

    The decision reports whether the running sum of the white noise input
    ends above 0 (+1) or not (-1).

    Arguments:
      input_params: 4-tuple (bias_val, stddev_val, T, ntime) of parameters
        for this decision task
      key: jax random key for making randomness

    Returns:
      3-tuple of inputs, targets, and the target mask:
        inputs: white noise with a trailing axis added, shape (ntime, 1)
        targets: shape (ntime, 1), zero everywhere except the final time
          step, which holds the +1 / -1 decision
        target mask: one-element array holding ntime - 1, the time index
          that has optimization pressure on it
    """
    bias_val, stddev_val, T, ntime = input_params
    dt = T / ntime

    # Two subkeys: the first for the per-trial bias, the second for the noise.
    key, skeys = utils.keygen(key, 2)

    # NOTE(review): the 2 * (x - 0.5) rescaling looks like it was written for
    # a uniform [0, 1) draw, but the sample here is standard normal -- confirm
    # the intent before changing it.
    bias_draw = random.normal(next(skeys), (1, ))[0]
    bias = bias_val * 2.0 * (bias_draw - 0.5)

    noise_scale = stddev_val / np.sqrt(dt)
    white_noise_t = bias + noise_scale * random.normal(next(skeys), (ntime, ))

    # The running sum is intentionally not multiplied by dt, to keep the
    # output scaling in O(1).
    running_sum_t = np.cumsum(white_noise_t)

    # Map the boolean "final sum is positive" onto +1.0 / -1.0.
    ended_positive = running_sum_t[-1] > 0.0
    decision = 2.0 * (ended_positive - 0.5)

    # Zero target at every step but the last, where the decision is placed.
    leading_zeros = np.zeros(running_sum_t.shape[0] - 1)
    final_target = np.array([decision], dtype=float)
    targets_t = np.concatenate([leading_zeros, final_target], axis=0)

    inputs_tx1 = white_noise_t[:, None]
    targets_tx1 = targets_t[:, None]
    target_mask = np.array([ntime - 1])  # When target is defined.
    return inputs_tx1, targets_tx1, target_mask