def getTrainSet(behavior, ephys, config):
    """
    get train set from run epoch based on configuration

    Run periods are the stretches where behavior["speed"] exceeds
    config.run_speed and whose duration exceeds config.min_run_duration.
    They are split into bins of config.bin_size_run; the leading
    config.r_data_not_to_use fraction of the bins is skipped and the first
    config.r_train fraction of what is left becomes the train set.  Bins
    whose summed spike count over all tetrodes is zero are dropped from the
    returned segments.

    Args:
        behavior - behavior data during run epoch; the "speed" and "time"
                   entries are read
        ephys - spike data of run epoch: mapping from tetrode name to a
                record with a 'spike_times' entry.  The 'TT1' entry must
                exist, it is used to report the recording span.
        config - configuration object (run_speed, min_run_duration,
                 bin_size_run, r_data_not_to_use, r_train)

    returns:
        train - train segments, without the bins that contain no spike
        training_time - total duration of the train bins, computed BEFORE
                        the empty bins are removed
    """
    print("###################################################################################")
    print("Train set:")

    # select run segments
    print("run speed={}, min run duration={}, binsize={}".format(
        config.run_speed, config.min_run_duration, config.bin_size_run))
    run = seg.fromlogical(behavior["speed"] > config.run_speed,
                          x=behavior["time"])
    run = run[run.duration > config.min_run_duration]

    # split segments into independent bins
    run_binned = run.split(size=config.bin_size_run)

    # skip the leading fraction of bins, then keep the first r_train
    # fraction of the remainder for training
    first_bin_idx = int(config.r_data_not_to_use * len(run_binned))
    used_data = run_binned[first_bin_idx:]
    n_train = int(config.r_train * len(used_data))
    train = used_data[0:n_train]

    # the recording span is reported from tetrode TT1 only
    # NOTE(review): raises KeyError for recordings without a 'TT1' entry
    reference_times = ephys[u'TT1']['spike_times']
    start_time = reference_times[0]
    end_time = reference_times[-1]
    total_time = end_time - start_time
    # divide by a float so integer time stamps are not truncated on Python 2
    print("total recording time = {} min, start time = {} s, end time = {} s".format(
        total_time / 60.0, start_time, end_time))

    training_time = np.sum(train.duration)
    print("training_time={} s".format(training_time))

    # find number of spikes within train datasets
    # (element [1] of train.contains(...) is used as the spike count per bin)
    spikes_per_tetrode = []
    n_spikes_all = np.zeros(len(train))
    for key in ephys:
        counts = train.contains(ephys[key]['spike_times'])[1]
        n_spikes_all = n_spikes_all + counts
        spikes_per_tetrode.append(sum(counts))

    # remove bins without spike; only re-index when something is dropped
    spike_bin_idx = [j for j, n in enumerate(n_spikes_all) if n != 0]
    if len(spike_bin_idx) < len(n_spikes_all):
        train = train[spike_bin_idx]

    print("number spikes for training:")
    print(np.sum(spikes_per_tetrode))
    print("")
    return train, training_time
def getTestSet(behavior,
               ephys,
               config,
               event=None,
               replay=False,
               rm_no_spk=True,
               count_bins_each_event=True):
    """
    get test set from run or replay data based on configuration

    For run data (replay=False) the run bins are selected exactly as for the
    train set and the bins AFTER the first config.r_train fraction are used.
    For replay data (replay=True) the segments in event['postNREMevent'] are
    split into bins of config.bin_size_sleep.

    Args:
        behavior - behavior data during run epoch; "speed", "time" and
                   "linear_position" entries are read for run data
        ephys - spike data of run/sleep epoch: mapping from tetrode name to
                a record with a 'spike_times' entry
        config - configuration object
        event - replay events, indexed with 'postNREMevent'; required when
                replay is True, ignored otherwise
        replay - whether test replay data or run data
                    run   : False
                    replay: True
        rm_no_spk - drop the bins whose summed spike count is zero
        count_bins_each_event - assign bins to events (replay data) or
                compute the true position (run data)

    returns:
        test_binned - test segments split into bins (empty bins removed
                      when rm_no_spk is True)
        event_bins - count_bins_each_event and replay: for every event the
                     list of indices into test_binned of the bins lying
                     inside that event;
                     count_bins_each_event and run data: empty list;
                     otherwise: the number of test bins (an int)
        n_spikes_all - number of spikes in each bin, summed over tetrodes,
                       for ALL bins, i.e. computed before empty bins are
                       removed
        true_behavior - linear position interpolated at the bin centers
                        (run data with count_bins_each_event), else []

    raises:
        TypeError - replay is True but no event was given
    """
    if replay:
        if event is None:
            raise TypeError("event must be provided when replay=True")
        test = seg(event['postNREMevent'])
        bin_size_test = config.bin_size_sleep
    else:
        run = seg.fromlogical(behavior["speed"] > config.run_speed,
                              x=behavior["time"])
        run = run[run.duration > config.min_run_duration]

        # split segments into independent bins
        bin_size_test = config.bin_size_run
        run_binned = run.split(size=bin_size_test)

        # same selection as the train set, but keep what follows the
        # training fraction
        first_bin_idx = int(config.r_data_not_to_use * len(run_binned))
        used_data = run_binned[first_bin_idx:]
        n_train = int(config.r_train * len(used_data))
        test = used_data[n_train:]

    print("###################################################################################")
    print("Test set:")
    testing_time = np.sum(test.duration)
    print("testing_time={} s".format(testing_time))
    test_binned = test.split(size=bin_size_test)
    print("binsize={}, test bins = {}".format(bin_size_test,
                                              len(test_binned)))

    # find number of spikes within test datasets
    # (element [1] of test_binned.contains(...) is used as the count per bin)
    n_spikes_all = np.zeros(len(test_binned))
    for key in ephys:
        counts = test_binned.contains(ephys[key]['spike_times'])[1]
        n_spikes_all = n_spikes_all + counts
    print("get spike count done")
    # progress message kept so the log output stays the same; the per-bin
    # table that used to be built here was never used
    print("get spike count per bin done")

    if rm_no_spk:
        spike_bin_idx = [j for j, n in enumerate(n_spikes_all) if n != 0]
        n_removed = len(n_spikes_all) - len(spike_bin_idx)
        if n_removed > 0:
            test_binned = test_binned[spike_bin_idx]
        print("{} no spike bins removed".format(n_removed))
    # NOTE(review): n_spikes_all is returned unfiltered, so after a removal
    # it has more entries than test_binned - confirm callers expect that
    testing_time = np.sum(test_binned.duration)
    print("testing_time={} s".format(testing_time))
    print("")

    true_behavior = []
    if count_bins_each_event:
        # get number of bins in each event
        if replay:
            events = list(test)
            event_bins = [[] for _ in range(len(test))]
            for j, evnt_bin in enumerate(test_binned):
                for i, env in enumerate(events):
                    # a bin goes to the first event that fully contains it
                    if env[0] <= evnt_bin[0] and env[1] >= evnt_bin[1]:
                        event_bins[i].append(j)
                        break
        else:
            event_bins = []
            position_at = interpolate.interp1d(behavior["time"],
                                               behavior["linear_position"],
                                               kind='linear',
                                               axis=0)
            true_behavior = position_at(test_binned.center)
        print("get #bins in each event done")
    else:
        event_bins = len(test_binned)

    return test_binned, event_bins, n_spikes_all, true_behavior