nreplications)

# Split the data, inputs, labels and times into train vs. validation.
rates_train, rates_valid = \
    split_list_by_inds(rates, train_inds, valid_inds)
noisy_data_train, noisy_data_valid = \
    split_list_by_inds(noisy_data, train_inds, valid_inds)
# NOTE(review): naming is inconsistent here ('input_train' vs
# 'inputs_valid'); both names are used consistently below, so behavior
# is unaffected, but consider unifying in a follow-up.
input_train, inputs_valid = \
    split_list_by_inds(inputs, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = \
    split_list_by_inds(condition_labels, train_inds, valid_inds)
input_times_train, input_times_valid = \
    split_list_by_inds(input_times, train_inds, valid_inds)

# Turn rates, noisy_data, and input into numpy arrays.
# (nparray_and_transpose presumably yields trial-major arrays — TODO
# confirm against its definition elsewhere in this file.)
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
noisy_data_train = nparray_and_transpose(noisy_data_train)
noisy_data_valid = nparray_and_transpose(noisy_data_valid)
input_train = nparray_and_transpose(input_train)
inputs_valid = nparray_and_transpose(inputs_valid)

# Note that we put these 'truth' rates and input into this
# structure, the only data that is used in LFADS are the noisy
# data e.g. spike trains.  The rest is either for printing or posterity.
data = {'train_truth': rates_train,
        'valid_truth': rates_valid,
        'input_train_truth' : input_train,
        'input_valid_truth' : inputs_valid,
        'train_data' : noisy_data_train,
        'valid_data' : noisy_data_valid,
# Build the "ground truth" rates, then a Poisson-spikified copy of the
# same trials; only the spikes are what the model actually trains on.
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)

# Decide which of the E trials land in train vs. validation.
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nspikifications)

# Carve both the true rates and the spike trains along those indices.
data_train_truth, data_valid_truth = split_list_by_inds(
    truth_data_e, train_inds, valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(
    spiking_data_e, train_inds, valid_inds)

# Convert each split to a transposed numpy array in a single pass.
(data_train_truth, data_valid_truth,
 data_train_spiking, data_valid_spiking) = [
    nparray_and_transpose(split)
    for split in (data_train_truth, data_valid_truth,
                  data_train_spiking, data_valid_spiking)]

# Save down the inputs used to generate this data.
train_inputs_u, valid_inputs_u = [
    nparray_and_transpose(u)
    for u in split_list_by_inds(u_e, train_inds, valid_inds)]

# Save down the network outputs (may be useful later).
train_outputs_u, valid_outputs_u = split_list_by_inds(
    outs_e, train_inds, valid_inds)
else:
  # Non-special case: keep this trial's rates and spikes as generated.
  rates.append(rates_b[trial])
  spikes.append(spikes_b[trial])

# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(ext_input,
                                                      train_inds, valid_inds)

# Convert the rate/spike splits to transposed numpy arrays.
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)

# add train_ext_input and valid_ext input
# NOTE(review): ext_input splits are wrapped with np.array directly,
# not nparray_and_transpose like the others — presumably intentional,
# but worth confirming the resulting axis order matches consumers.
data = {
    'train_truth': rates_train,
    'valid_truth': rates_valid,
    'train_data': spikes_train,
    'valid_data': spikes_valid,
    'train_ext_input': np.array(ext_input_train),
    'valid_ext_input': np.array(ext_input_valid),
    'train_percentage': train_percentage,
    'nreplications': nreplications,
    'dt': FLAGS.dt,
# Normalize the generated rates into ground-truth firing rates, then
# sample spike trains from them; the spikes are the model's input.
truth_data_e = normalize_rates(data_e, E, N)
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)

# Assign trial indices to the train and validation sets.
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

# Split truth and spikes by those indices, then turn every split into a
# transposed numpy array.
data_train_truth, data_valid_truth = split_list_by_inds(
    truth_data_e, train_inds, valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(
    spiking_data_e, train_inds, valid_inds)
(data_train_truth, data_valid_truth,
 data_train_spiking, data_valid_spiking) = [
    nparray_and_transpose(part)
    for part in (data_train_truth, data_valid_truth,
                 data_train_spiking, data_valid_spiking)]

# save down the inputs used to generate this data
train_inputs_u, valid_inputs_u = [
    nparray_and_transpose(inputs_part)
    for inputs_part in split_list_by_inds(u_e, train_inds, valid_inds)]

# save down the network outputs (may be useful later)
train_outputs_u, valid_outputs_u = split_list_by_inds(
    outs_e, train_inds, valid_inds)
else:
  # Non-special case: keep this trial's rates and spikes as generated.
  rates.append(rates_b[trial])
  spikes.append(spikes_b[trial])

# split into train and validation sets
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)
rates_train, rates_valid = split_list_by_inds(rates, train_inds, valid_inds)
spikes_train, spikes_valid = split_list_by_inds(spikes, train_inds, valid_inds)
condition_labels_train, condition_labels_valid = split_list_by_inds(
    condition_labels, train_inds, valid_inds)
ext_input_train, ext_input_valid = split_list_by_inds(
    ext_input, train_inds, valid_inds)

# Convert the rate/spike splits to transposed numpy arrays.
rates_train = nparray_and_transpose(rates_train)
rates_valid = nparray_and_transpose(rates_valid)
spikes_train = nparray_and_transpose(spikes_train)
spikes_valid = nparray_and_transpose(spikes_valid)

# add train_ext_input and valid_ext input
data = {'train_truth': rates_train,
        'valid_truth': rates_valid,
        'train_data' : spikes_train,
        'valid_data' : spikes_valid,
        'train_ext_input' : np.array(ext_input_train),
        'valid_ext_input': np.array(ext_input_valid),
        'train_percentage' : train_percentage,
        'nreplications' : nreplications,
        'dt' : FLAGS.dt,
        # NOTE(review): key 'P_sxn' is bound to a variable named P_nxn —
        # possibly a transpose stored under a relabeled axis order, or a
        # naming slip; confirm against where P_nxn is built.
        'P_sxn' : P_nxn,