Exemplo n.º 1
0
                                                  nreplications)

  # Split the data, inputs, labels and times into train vs. validation.
  rates_train, rates_valid = \
      split_list_by_inds(rates, train_inds, valid_inds)
  noisy_data_train, noisy_data_valid = \
      split_list_by_inds(noisy_data, train_inds, valid_inds)
  input_train, inputs_valid = \
      split_list_by_inds(inputs, train_inds, valid_inds)
  condition_labels_train, condition_labels_valid = \
      split_list_by_inds(condition_labels, train_inds, valid_inds)
  input_times_train, input_times_valid = \
      split_list_by_inds(input_times, train_inds, valid_inds)

  # Turn rates, noisy_data, and input into numpy arrays.
  rates_train = nparray_and_transpose(rates_train)
  rates_valid = nparray_and_transpose(rates_valid)
  noisy_data_train = nparray_and_transpose(noisy_data_train)
  noisy_data_valid = nparray_and_transpose(noisy_data_valid)
  input_train = nparray_and_transpose(input_train)
  inputs_valid = nparray_and_transpose(inputs_valid)

  # Note that we put these 'truth' rates and input into this
  # structure, the only data that is used in LFADS are the noisy
  # data e.g. spike trains.  The rest is either for printing or posterity.
  data = {'train_truth': rates_train,
          'valid_truth': rates_valid,
          'input_train_truth' : input_train,
          'input_valid_truth' : inputs_valid,
          'train_data' : noisy_data_train,
          'valid_data' : noisy_data_valid,
Exemplo n.º 2
0
  truth_data_e = normalize_rates(data_e, E, N)

# Turn the rates from normalize_rates into spiking data via spikify_data,
# using the time step and firing-rate cap taken from the command-line FLAGS.
spiking_data_e = spikify_data(
    truth_data_e, rng, dt=FLAGS.dt, max_firing_rate=FLAGS.max_firing_rate)

# Choose which trials are used for training and which for validation.
train_inds, valid_inds = get_train_n_valid_inds(
    E, train_percentage, nspikifications)

# The ground-truth rates and the spiking data share one split.
data_train_truth, data_valid_truth = split_list_by_inds(
    truth_data_e, train_inds, valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(
    spiking_data_e, train_inds, valid_inds)

# Convert every partition with nparray_and_transpose. The right-hand tuples
# are evaluated left to right, so the calls happen in the same order as four
# separate assignments would produce.
data_train_truth, data_valid_truth = (
    nparray_and_transpose(data_train_truth),
    nparray_and_transpose(data_valid_truth))
data_train_spiking, data_valid_spiking = (
    nparray_and_transpose(data_train_spiking),
    nparray_and_transpose(data_valid_spiking))

# Keep the inputs u_e that produced this data, split the same way.
train_inputs_u, valid_inputs_u = split_list_by_inds(
    u_e, train_inds, valid_inds)
train_inputs_u, valid_inputs_u = (
    nparray_and_transpose(train_inputs_u),
    nparray_and_transpose(valid_inputs_u))

# Keep the network outputs too, split the same way; they may be useful later.
train_outputs_u, valid_outputs_u = split_list_by_inds(
    outs_e, train_inds, valid_inds)
    else:
        rates.append(rates_b[trial])
        spikes.append(spikes_b[trial])

# Decide which trials go to the training set and which to the validation set.
train_inds, valid_inds = get_train_n_valid_inds(
    E, train_percentage, nreplications)

# Apply that one split to every per-trial list.
rates_train, rates_valid = split_list_by_inds(
    rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(
    spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
    ext_input, train_inds, valid_inds)

# Only the rates and spikes go through nparray_and_transpose here. The
# condition labels and external inputs are left exactly as
# split_list_by_inds returned them. Tuple right-hand sides evaluate left to
# right, so the call order matches one-assignment-per-line code.
rates_train, rates_valid = (nparray_and_transpose(rates_train),
                            nparray_and_transpose(rates_valid))
spikes_train, spikes_valid = (nparray_and_transpose(spikes_train),
                              nparray_and_transpose(spikes_valid))

# add train_ext_input and valid_ext input
data = {
    'train_truth': rates_train,
    'valid_truth': rates_valid,
    'train_data': spikes_train,
    'valid_data': spikes_valid,
    'train_ext_input': np.array(ext_input_train),
    'valid_ext_input': np.array(ext_input_valid),
    'train_percentage': train_percentage,
    'nreplications': nreplications,
    'dt': FLAGS.dt,
Exemplo n.º 4
0
  truth_data_e = normalize_rates(data_e, E, N)

# Turn the rates from normalize_rates into spiking data via spikify_data,
# using the time step and firing-rate cap taken from the command-line FLAGS.
spiking_data_e = spikify_data(
    truth_data_e, rng, dt=FLAGS.dt, max_firing_rate=FLAGS.max_firing_rate)

# Choose which trials are used for training and which for validation.
train_inds, valid_inds = get_train_n_valid_inds(
    E, train_percentage, nreplications)

# The ground-truth rates and the spiking data share one split.
data_train_truth, data_valid_truth = split_list_by_inds(
    truth_data_e, train_inds, valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(
    spiking_data_e, train_inds, valid_inds)

# Convert every partition with nparray_and_transpose. The right-hand tuples
# are evaluated left to right, so the calls happen in the same order as four
# separate assignments would produce.
data_train_truth, data_valid_truth = (
    nparray_and_transpose(data_train_truth),
    nparray_and_transpose(data_valid_truth))
data_train_spiking, data_valid_spiking = (
    nparray_and_transpose(data_train_spiking),
    nparray_and_transpose(data_valid_spiking))

# Keep the inputs u_e that produced this data, split the same way.
train_inputs_u, valid_inputs_u = split_list_by_inds(
    u_e, train_inds, valid_inds)
train_inputs_u, valid_inputs_u = (
    nparray_and_transpose(train_inputs_u),
    nparray_and_transpose(valid_inputs_u))

# Keep the network outputs too, split the same way; they may be useful later.
train_outputs_u, valid_outputs_u = split_list_by_inds(
    outs_e, train_inds, valid_inds)
Exemplo n.º 5
0
  else:
    rates.append(rates_b[trial])
    spikes.append(spikes_b[trial])

# Decide which trials go to the training set and which to the validation set.
train_inds, valid_inds = get_train_n_valid_inds(
    E, train_percentage, nreplications)

# Apply that one split to every per-trial list.
rates_train, rates_valid = split_list_by_inds(
    rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(
    spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
    ext_input, train_inds, valid_inds)

# Only the rates and spikes go through nparray_and_transpose here. The
# condition labels and external inputs are left exactly as
# split_list_by_inds returned them. Tuple right-hand sides evaluate left to
# right, so the call order matches one-assignment-per-line code.
rates_train, rates_valid = (nparray_and_transpose(rates_train),
                            nparray_and_transpose(rates_valid))
spikes_train, spikes_valid = (nparray_and_transpose(spikes_train),
                              nparray_and_transpose(spikes_valid))

# add train_ext_input and valid_ext input
data = {'train_truth': rates_train,
        'valid_truth': rates_valid,
        'train_data' : spikes_train,
        'valid_data' : spikes_valid,
        'train_ext_input' : np.array(ext_input_train),
        'valid_ext_input': np.array(ext_input_valid),
        'train_percentage' : train_percentage,
        'nreplications' : nreplications,
        'dt' : FLAGS.dt,
        'P_sxn' : P_nxn,