Example #1
  for m in range(n):
    P_sxn = np.roll(P_sxn, S, axis=1)

  if input_magnitude > 0.0:
    # time of "hits" randomly chosen between [1/4 and 3/4] of total time
    input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
  else:
    input_times = None

  rates, x0s, inputs = \
      generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
                    input_magnitude=input_magnitude,
                    input_times=input_times)

  if FLAGS.noise_type == "poisson":
    noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  elif FLAGS.noise_type == "gaussian":
    noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  else:
    raise ValueError("Only noise types supported are poisson or gaussian")

  # split into train and validation sets
  train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                  nreplications)

  # Split the data, inputs, labels and times into train vs. validation.
  rates_train, rates_valid = \
      split_list_by_inds(rates, train_inds, valid_inds)
  noisy_data_train, noisy_data_valid = \
      split_list_by_inds(noisy_data, train_inds, valid_inds)
  input_train, inputs_valid = \
      split_list_by_inds(inputs, train_inds, valid_inds)
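
Example #1 switches between Poisson and Gaussian noising of the generated rates. The noising helpers themselves are not reproduced on this page; the snippet below is only a rough sketch of what such functions might look like (the "_sketch" names, the [0, 1] rate normalization, and the variance matching are assumptions, not the repository's actual implementation).

import numpy as np

def spikify_data_sketch(rates, rng, dt, max_firing_rate):
  # Hypothetical sketch: treat each normalized rate in [0, 1] as a Poisson
  # mean (rate * max_firing_rate * dt) and draw integer spike counts.
  return [rng.poisson(r * max_firing_rate * dt) for r in rates]

def gaussify_data_sketch(rates, rng, dt, max_firing_rate):
  # Hypothetical sketch: add Gaussian noise whose variance matches the
  # Poisson mean, so the two noise models are comparable in magnitude.
  noisy = []
  for r in rates:
    lam = r * max_firing_rate * dt
    noisy.append(lam + np.sqrt(lam) * rng.randn(*lam.shape))
  return noisy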
Example #2
      feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))

    states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                           feed_dict=feed_dict)
    states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
    outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
    r_sxt = np.dot(P_nxn, states_nxt)

    for s in range(nspikifications):
      data_e.append(r_sxt)
      u_e.append(u_1xt)
      outs_e.append(outputs_t_bxn)

  truth_data_e = normalize_rates(data_e, E, N)

spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nspikifications)

data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
                                                        train_inds,
                                                        valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
                                                            train_inds,
                                                            valid_inds)

data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)
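
The train/validation split in Example #2 relies on two small helpers that are cut off from this page. Below is a minimal sketch of the behavior implied by their call sites, assuming every replication of a trial must land in the same split; the "_sketch" names and internal details are assumptions, not the actual implementation.

import numpy as np

def get_train_n_valid_inds_sketch(num_trials, train_percentage, nreplications):
  # Hypothetical sketch: keep all `nreplications` copies of a trial together
  # and send each whole block to either the training or the validation set.
  train_inds, valid_inds = [], []
  for block_start in range(0, num_trials, nreplications):
    block = list(range(block_start, block_start + nreplications))
    if np.random.rand() < train_percentage:
      train_inds.extend(block)
    else:
      valid_inds.extend(block)
  return train_inds, valid_inds

def split_list_by_inds_sketch(data, inds1, inds2):
  # Hypothetical sketch: partition a list into two lists by index.
  return [data[i] for i in inds1], [data[i] for i in inds2]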
Example #3
    for ns in range(nreplications):
        condition_labels.append(condition_number)
    condition_number += 1
x0s = np.concatenate(x0s, axis=1)

P_nxn = rng.randn(N, N) / np.sqrt(N)

# generate trials for both RNNs
rates_a, x0s_a, _ = generate_data(rnn_a,
                                  T=T,
                                  E=E,
                                  x0s=x0s,
                                  P_sxn=P_nxn,
                                  input_magnitude=0.0,
                                  input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])

rates_b, x0s_b, _ = generate_data(rnn_b,
                                  T=T,
                                  E=E,
                                  x0s=x0s,
                                  P_sxn=P_nxn,
                                  input_magnitude=0.0,
                                  input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])

# not the best way to do this but E is small enough
rates = []
spikes = []
for trial in range(E):
    if rnn_to_use[trial] == 0:
        rates.append(rates_a[trial])
        spikes.append(spikes_a[trial])
    else:
        rates.append(rates_b[trial])
        spikes.append(spikes_b[trial])
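
Examples #3 and #5 both draw the readout matrix as rng.randn(N, N) / np.sqrt(N). The 1/sqrt(N) factor keeps each readout row at roughly unit norm, so the projected rates stay on the same scale regardless of network size. A quick standalone check of that property (illustrative only):

import numpy as np

rng = np.random.RandomState(0)
for N in (50, 200, 800):
  P_nxn = rng.randn(N, N) / np.sqrt(N)
  # Each row sums N independent terms of variance 1/N, so its squared
  # norm concentrates near 1 regardless of N.
  print(N, np.mean(np.sum(P_nxn ** 2, axis=1)))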
Example #4
      feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:,t], (batch_size,-1))

    states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                           feed_dict=feed_dict)
    states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
    outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
    r_sxt = np.dot(P_nxn, states_nxt)

    for s in range(nreplications):
      data_e.append(r_sxt)
      u_e.append(u_1xt)
      outs_e.append(outputs_t_bxn)

  truth_data_e = normalize_rates(data_e, E, N)

spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
                                                        train_inds,
                                                        valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
                                                            train_inds,
                                                            valid_inds)

data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)
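
Examples #2 and #4 finish by converting the per-trial lists into arrays with nparray_and_transpose. The helper is not shown here; a plausible minimal version, assuming each list element is a neurons-by-time array and downstream code wants trials x time x neurons, is sketched below (name and axis order are assumptions).

import numpy as np

def nparray_and_transpose_sketch(data_list):
  # Hypothetical sketch: stack a list of [neurons x time] arrays into a
  # [trials x neurons x time] array, then swap the last two axes so the
  # result is [trials x time x neurons].
  data = np.asarray(data_list)
  return np.transpose(data, axes=[0, 2, 1])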
Example #5
condition_labels = []
condition_number = 0
for c in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)
  x0s.append(np.tile(x0, nreplications))
  for ns in range(nreplications):
    condition_labels.append(condition_number)
  condition_number += 1
x0s = np.concatenate(x0s, axis=1)

P_nxn = rng.randn(N, N) / np.sqrt(N)

# generate trials for both RNNs
rates_a, x0s_a, _ = generate_data(rnn_a, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_a = spikify_data(rates_a, rng, rnn_a['dt'], rnn_a['max_firing_rate'])

rates_b, x0s_b, _ = generate_data(rnn_b, T=T, E=E, x0s=x0s, P_sxn=P_nxn,
                                  input_magnitude=0.0, input_times=None)
spikes_b = spikify_data(rates_b, rng, rnn_b['dt'], rnn_b['max_firing_rate'])

# not the best way to do this but E is small enough
rates = []
spikes = []
for trial in range(E):
  if rnn_to_use[trial] == 0:
    rates.append(rates_a[trial])
    spikes.append(spikes_a[trial])
  else:
    rates.append(rates_b[trial])
    spikes.append(spikes_b[trial])
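
Example #5 interleaves trials from the two RNNs using a per-trial selector rnn_to_use, which is defined outside this excerpt. Purely as an illustration of the shape that selector needs (one 0/1 entry per trial), it could be built as below; the actual assignment rule in the source may differ.

import numpy as np

rng = np.random.RandomState(0)
E = 100  # number of trials; must match the loop bound used above
# Hypothetical: assign each trial to RNN a (0) or RNN b (1) at random.
rnn_to_use = rng.randint(2, size=E)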