示例#1
0
def getTrainSet(behavior, ephys, config):
    """
    Build the training set from the run epoch.

    Running periods are the stretches where behavior["speed"] exceeds
    config.run_speed and that last longer than config.min_run_duration.
    They are cut into bins of config.bin_size_run. The leading
    config.r_data_not_to_use fraction of bins is skipped. The first
    config.r_train fraction of the remaining bins forms the training set.
    Bins in which no tetrode recorded a spike are then dropped.

    Args:
        behavior - behavior data during run epoch (indexed by "speed", "time")
        ephys - spike data of run epoch; iterable of tetrode keys, each
                mapping to a record with a 'spike_times' entry. The key
                u'TT1' must be present (used for the recording-time report).
        config - configuration object
    returns:
        train - train segments, with spike-less bins removed
        training_time - summed duration of the training bins, measured
                        BEFORE the spike-less bins were removed
    """
    separator = "###################################################################################"
    print(separator)
    print("Train set:")
    print("run speed={}, min run duration={}, binsize={}".format(
        config.run_speed, config.min_run_duration, config.bin_size_run))

    # running periods: speed above threshold, long enough to be kept
    is_running = behavior["speed"] > config.run_speed
    run = seg.fromlogical(is_running, x=behavior["time"])
    run = run[run.duration > config.min_run_duration]

    # cut into independent bins, skip the leading fraction, keep the train share
    bins = run.split(size=config.bin_size_run)
    usable = bins[int(config.r_data_not_to_use * len(bins)):]
    n_train = int(config.r_train * len(usable))
    train = usable[:n_train]

    # recording span reported from tetrode TT1's spike train
    tt1_spikes = ephys[u'TT1']['spike_times']
    start_time = tt1_spikes[0]
    end_time = tt1_spikes[-1]
    total_time = end_time - start_time
    print("total recording time = {} min, start time = {} s, end time = {} s".format(
        total_time / 60, start_time, end_time))

    training_time = np.sum(train.duration)
    print("training_time={} s".format(training_time))

    # number of spikes within each training bin, summed over tetrodes
    # (element [1] of Segment.contains is used as the per-bin count),
    # plus the total per tetrode
    spikes_per_bin = np.zeros(len(train))
    tetrode_totals = []
    for key in ephys:
        counts = train.contains(ephys[key]['spike_times'])[1]
        spikes_per_bin = spikes_per_bin + counts
        tetrode_totals.append(sum(counts))

    # drop bins in which no tetrode fired
    keep = [idx for idx, total in enumerate(spikes_per_bin) if total != 0]
    if len(keep) < len(spikes_per_bin):
        train = train[keep]

    print("number spikes for training:")
    print(np.sum(tetrode_totals))
    print("")
    return train, training_time
示例#2
0
def getTestSet(behavior,
               ephys,
               config,
               event=None,
               replay=False,
               rm_no_spk=True,
               count_bins_each_event=True):
    """
    Get the test set, either from the run epoch or from replay events.

    For run data, the run epoch is selected and binned exactly as for the
    training set. The bins after the training share (config.r_train) become
    the test set. For replay data, the events in event['postNREMevent'] are
    used and binned with config.bin_size_sleep.

    Args:
        behavior - behavior data during run epoch (indexed by "speed", "time"
                   and, for run data with count_bins_each_event,
                   "linear_position")
        ephys - spike data of run/sleep epoch; iterable of tetrode keys, each
                mapping to a record with a 'spike_times' entry
        config - configuration object
        event - replay events; must have a 'postNREMevent' entry. Required
                when replay is True, ignored otherwise.
        replay - whether test replay data or run data
                 run : False
                 replay: True
        rm_no_spk - drop test bins in which no tetrode recorded a spike
        count_bins_each_event - whether to compute event_bins / true_behavior
    returns:
        test_binned - test segments (bins)
        event_bins - replay with count_bins_each_event: list with, per event,
                     the indices into test_binned of the bins it contains;
                     run with count_bins_each_event: [];
                     otherwise: the number of test bins
        n_spikes_all - total spike count (all tetrodes) per bin. NOTE: computed
                       before spike-less bins are removed, so with rm_no_spk it
                       is aligned with the un-filtered bins, not test_binned.
        true_behavior - run data with count_bins_each_event: linear position
                        interpolated at the bin centers; otherwise []
    raises:
        TypeError - if replay is True and no event is given
    """
    if replay:
        if event is None:
            raise TypeError("getTestSet: 'event' is required when replay=True")
        test = seg(event['postNREMevent'])
        bin_size_test = config.bin_size_sleep
    else:
        run = seg.fromlogical(behavior["speed"] > config.run_speed,
                              x=behavior["time"])
        run = run[run.duration > config.min_run_duration]
        bin_size_test = config.bin_size_run
        # split into independent bins, skip the leading r_data_not_to_use
        # fraction and keep everything after the training share
        run_binned = run.split(size=bin_size_test)
        first_bin_idx = int(config.r_data_not_to_use * len(run_binned))
        used_data = run_binned[first_bin_idx:]
        n_train = int(config.r_train * len(used_data))
        test = used_data[n_train:]

    print("###################################################################################")
    print("Test set:")
    testing_time = np.sum(test.duration)
    print("testing_time={} s".format(testing_time))
    test_binned = test.split(size=bin_size_test)
    print("binsize={}, test bins = {}".format(bin_size_test, len(test_binned)))

    # total number of spikes in each test bin, summed over all tetrodes
    n_spikes_all = np.zeros(len(test_binned))
    for key in ephys:
        n_spikes_all = n_spikes_all + test_binned.contains(
            ephys[key]['spike_times'])[1]
    print("get spike count done")

    if rm_no_spk:
        # keep only bins in which at least one spike was recorded; index with
        # a plain list of ints, as before
        spike_bin_idx = np.flatnonzero(n_spikes_all != 0).tolist()
        n_removed = len(n_spikes_all) - len(spike_bin_idx)
        if n_removed > 0:
            test_binned = test_binned[spike_bin_idx]
        print("{} no spike bins removed".format(n_removed))

    testing_time = np.sum(test_binned.duration)
    print("testing_time={} s".format(testing_time))
    print("")

    if count_bins_each_event:
        if replay:
            # assign each bin to the first event that fully contains it
            event_bins = [[] for _ in range(len(test))]
            for j, evnt_bin in enumerate(test_binned):
                for i, env in enumerate(test):
                    if env[0] <= evnt_bin[0] and env[1] >= evnt_bin[1]:
                        event_bins[i].append(j)
                        break
            true_behavior = []
        else:
            event_bins = []
            true_behavior = interpolate.interp1d(
                behavior["time"], behavior["linear_position"],
                kind='linear', axis=0)(test_binned.center)
        print("get #bins in each event done")
    else:
        event_bins = len(test_binned)
        true_behavior = []

    return test_binned, event_bins, n_spikes_all, true_behavior