示例#1
0
def test_delete(save_location, compare_to):
    """Delete ``save_location`` and check whether the group still exists."""
    handler = DataHandler('tests')
    handler.delete(save_location=save_location)

    still_there = handler.check_group_exists(location=save_location,
                                             create=False)
    assert still_there == compare_to
示例#2
0
def test_load_no_group():
    """Loading from a save_location that does not exist raises NameError."""
    dat = DataHandler('tests')
    save_location = 'not_a_location'
    # Defined locally: referencing an undefined ``parameters`` would itself
    # raise NameError and let the test pass without ever calling dat.load.
    parameters = ['test_data']

    with pytest.raises(NameError):
        dat.load(parameters=parameters, save_location=save_location)
示例#3
0
def load_and_process(db_name,
                     save_location,
                     parameters,
                     interpolated_samples=100):
    """
    Loads the parameters from the save_location,
    returns a dictionary of the interpolated and sampled data

    NOTE: if interpolated_samples is set to None, the raw data will be
    returned without interpolation and sampling

    Parameters
    ----------
    db_name: string
        the database where the data is saved
    save_location: string
        points to the location in the hdf5 database to read from
    parameters: list of strings
        the parameters to load and interpolate
    interpolated_samples: positive int, Optional (Default=100)
        the number of samples to take (evenly) from the interpolated data
        if set to None, no interpolation or sampling will be done, the raw
        data will be returned

    Returns
    -------
    data: dict
        the loaded parameters (interpolated when interpolated_samples is not
        None), plus 'time' (filled with ones when it was not requested),
        'cumulative_time' (interpolated_samples evenly spaced points from 0
        to sum(time)) and 'read_location' (the save_location read from)

    Raises
    ------
    ValueError
        if nothing was loaded, so there is no parameter to take the data
        length from
    """
    # TODO: move interpolated samples is None check out of interpolation
    # function and add it here, no sense in having it check in the function
    # and have it do nothing, should only call interpolate if interpolating

    # load data from hdf5 database
    dat = DataHandler(db_name=db_name)
    data = dat.load(parameters=parameters, save_location=save_location)

    if not data:
        raise ValueError('No data loaded from %s for parameters %s'
                         % (save_location, parameters))

    # Length of the first loaded parameter. This assumes that any data
    # passed in at once will be of the same length.
    data_len = len(next(iter(data.values())))

    # If time is not passed in, use a time step of 1 for every sample so the
    # total time equals the number of samples
    if 'time' not in parameters:
        data['time'] = np.ones(data_len)

    total_time = np.sum(data['time'])

    # interpolate for even sampling and save to our dictionary
    if interpolated_samples is not None:
        for key in data:
            if key != 'time':
                data[key] = interpolate_data(
                    data=data[key],
                    time_intervals=data['time'],
                    interpolated_samples=interpolated_samples)
    else:
        interpolated_samples = data_len

    # since we are interpolating over time, we are not interpolating
    # the time data, instead evenly sample interpolated_samples from
    # 0 to the sum(time)
    data['cumulative_time'] = np.linspace(0, total_time, interpolated_samples)

    data['read_location'] = save_location
    return data
示例#4
0
def test_group_exists(location, create, compare_to):
    """check_group_exists reports ``compare_to`` for ``location``.

    ``create`` is forwarded so the group can optionally be created first.
    """
    handler = DataHandler('tests')
    result = handler.check_group_exists(location=location, create=create)
    assert result == compare_to
示例#5
0
def test_statistical_error(ideal,
                           save_data,
                           regen,
                           save_location=save_location):
    """Run TrajectoryError.statistical_error on the saved fake trajectory.

    NOTE(review): ``manual_error`` is computed but never compared against
    the returned ``data``, so this currently only checks that
    statistical_error runs without raising.
    """
    handler = DataHandler("test")
    # calculate_error substitutes "ideal_trajectory" for None; do the same so
    # the manual comparison loads the same data
    ideal = "ideal_trajectory" if ideal is None else ideal
    fake_data = handler.load(parameters=[ideal, "ee_xyz"],
                             save_location=save_location)
    offset = fake_data["ee_xyz"] - fake_data[ideal]
    manual_error = np.linalg.norm(offset, axis=1)

    traj = TrajectoryError(db_name="test",
                           time_derivative=0,
                           interpolated_samples=None)
    # statistical_error takes the top-level group, not the run location
    top_level = save_location.split("/")[0]
    data = traj.statistical_error(save_location=top_level,
                                  ideal=ideal,
                                  sessions=1,
                                  runs=1,
                                  save_data=save_data,
                                  regen=regen)
示例#6
0
def test_rename_old_save_location_does_not_exist():
    """Renaming from a save_location that does not exist raises KeyError."""
    old_save_location = 'test_old_does_not_exist'
    new_save_location = 'test_new'
    dat = DataHandler('tests')
    # make sure the source really is missing, even after an earlier failed run
    dat.delete(save_location=old_save_location)

    # only the rename belongs inside pytest.raises, so a KeyError raised
    # during setup cannot make the test pass
    with pytest.raises(KeyError):
        # try to read from key that doesn't exist
        dat.rename(old_save_location=old_save_location,
                   new_save_location=new_save_location,
                   delete_old=False)
示例#7
0
def test_load(parameters, compare_to, key):
    """Save known data, load ``parameters`` back and compare entry ``key``."""
    handler = DataHandler('tests')
    save_location = 'test_loading'

    handler.save(data={'test_data': np.ones(3)},
                 save_location=save_location,
                 overwrite=True)
    loaded = handler.load(parameters=parameters, save_location=save_location)

    assert np.all(loaded[key] == compare_to)
示例#8
0
def test_calculate_error(ideal, save_location=save_location):
    """calculate_error matches the Euclidean error computed by hand."""
    handler = DataHandler('test')
    ideal = 'ideal_trajectory' if ideal is None else ideal
    loaded = handler.load(parameters=[ideal, 'ee_xyz'],
                          save_location=save_location)
    expected = np.linalg.norm(loaded['ee_xyz'] - loaded[ideal], axis=1)

    traj = TrajectoryError(db_name='test',
                           time_derivative=0,
                           interpolated_samples=None)
    result = traj.calculate_error(save_location=save_location, ideal=ideal)

    assert np.array_equal(expected, result['error'])
示例#9
0
 def __init__(self, db_name, save_location):
     """
     Create the figure/axes and load the stored ideal curve.

     PARAMETERS
     ----------
     db_name: string
         the name of the database that holds the intercept scans
     save_location: string
         the location in the database the data is saved
     """
     self.save_location = save_location
     self.dat = DataHandler(db_name)
     self.fig = Figure()
     self.a = self.fig.add_subplot(111)
     stored = self.dat.load(parameters=["ideal"],
                            save_location=save_location)
     self.ideal = stored["ideal"]
     # tests the user chose to keep on the plot
     self.test_que = []
示例#10
0
def test_rename(old_save_location, new_save_location, delete_old, compare_to):
    """Rename a group and check which of the old/new locations exist."""
    handler = DataHandler('tests')

    # data that will be renamed / moved
    handler.save(data={'float': 3.14},
                 save_location=old_save_location,
                 overwrite=True)
    # clear the destination so the rename has a free key to move to
    handler.delete(save_location=new_save_location)
    handler.rename(old_save_location=old_save_location,
                   new_save_location=new_save_location,
                   delete_old=delete_old)

    old_exists = handler.check_group_exists(location=old_save_location,
                                            create=False)
    assert old_exists == compare_to

    new_exists = handler.check_group_exists(location=new_save_location,
                                            create=False)
    assert new_exists is True
 def __init__(self, db_name, time_derivative=0, interpolated_samples=100):
     '''
     Store the settings used when loading and processing trajectories.

     PARAMETERS
     ----------
     db_name: string
         the name of the database to load data from
     time_derivative: int, Optional (Default: 0)
         0: position
         1: velocity
         2: acceleration
         3: jerk
     interpolated_samples: positive int, Optional (Default=100)
         the number of samples to take (evenly) from the interpolated data
         if set to None, no interpolated or sampling will be done, the raw
         data will be returned
     '''
     self.db_name = db_name
     self.time_derivative = time_derivative
     self.interpolated_samples = interpolated_samples
     # data handler used to read from db_name
     self.dat = DataHandler(db_name)
示例#12
0
def load_and_process(interpolated_samples, parameters):
    """Save a random trajectory, run proc.load_and_process on it and check
    that every requested parameter has the expected number of samples.

    NOTE(review): the name lacks a ``test_`` prefix, so pytest will not
    collect this function — confirm whether that is intended.
    """
    dat = DataHandler("tests")
    loc = "fake_trajectory"
    steps = 147
    fake_traj_data = random_trajectories.generate(steps=steps, plot=False)
    dat.save(data=fake_traj_data, save_location=loc, overwrite=True)

    data = proc.load_and_process(db_name="tests",
                                 save_location=loc,
                                 parameters=parameters,
                                 interpolated_samples=interpolated_samples)

    # without interpolation the raw data keeps its original length
    if interpolated_samples is None:
        expected_len = steps
    else:
        expected_len = interpolated_samples

    for name in parameters:
        # time itself is not resampled; its evenly sampled counterpart is
        # stored under 'cumulative_time'
        lookup = "cumulative_time" if name == "time" else name
        assert len(data[lookup]) == expected_len
示例#13
0
def convert(npz_loc, db_name, save_location, overwrite=False):
    """
    accepts a npz file location and saves its data to the database at the
    specified save_location

    PARAMETERS
    ----------
    npz_loc: string
        location and name of the npz file
    db_name: string
        database to save data to
    save_location: string
        save location in the database
    overwrite: boolean, Optional (Default: False)
        True to overwrite previous data
        NOTE this will be triggered if data is being saved to the same hdf5
        group (folder), not necessarily the same key. In this case you will
        need to set it to True. Other data will not be erased, only data with
        the same keys will be overwritten

    RETURNS
    -------
    data: dict
        every key stored at save_location, read back after saving
    """
    dat = DataHandler(db_name)
    # the NpzFile returned by np.load keeps the archive open; the context
    # manager closes it once every array has been copied into memory
    with np.load(npz_loc) as npz:
        new_data = {key: npz[key] for key in npz.keys()}
    dat.save(data=new_data, save_location=save_location, overwrite=overwrite)

    # read everything stored in the group back (previously computed and
    # discarded; returning it lets callers inspect what was written)
    keys = dat.get_keys(save_location)
    return dat.load(parameters=keys, save_location=save_location)
示例#14
0
def test_rename_dataset():
    """Renaming a single dataset with delete_old=False leaves the old group
    in place and creates the new one."""
    old_save_location = 'test_saving'
    new_save_location = 'test_saving_moved'

    dat = DataHandler('tests')
    # save data to rename / move
    dat.save(data={'float': 3.14},
             save_location=old_save_location,
             overwrite=True)
    # make sure the new entry key is available
    dat.delete(save_location=new_save_location)

    # rename the dataset saved just above ('float') so the test does not
    # depend on another test having written some other key to this group
    dat.rename(old_save_location=old_save_location + '/float',
               new_save_location=new_save_location + '/float',
               delete_old=False)

    # check if the old location exists
    exists = dat.check_group_exists(location=old_save_location, create=False)
    assert exists is True

    # check if the new location exists
    exists = dat.check_group_exists(location=new_save_location, create=False)
    assert exists is True
示例#15
0
def test_rename_new_save_location_exists():
    """Renaming onto a save_location that already exists raises."""
    old_save_location = 'test_rename'
    new_save_location = 'test_already_exists'
    dat = DataHandler('tests')

    # Setup happens outside pytest.raises and uses overwrite=True, so data
    # left over from a previous run cannot raise here and let the test pass
    # without the rename ever being exercised.
    # save data to rename / move
    dat.save(data={'float': 3.14},
             save_location=old_save_location,
             overwrite=True)
    # create data at new location
    dat.save(data={'float': 3.14},
             save_location=new_save_location,
             overwrite=True)

    with pytest.raises(Exception):
        # try to rename data onto existing key
        dat.rename(old_save_location=old_save_location,
                   new_save_location=new_save_location,
                   delete_old=False)
示例#16
0
def test_get_keys():
    """get_keys succeeds for an existing group and raises KeyError otherwise."""
    handler = DataHandler('tests')

    # existing location: must not raise
    handler.get_keys(save_location='test_loading')

    # ensure the location is gone before checking the failure path
    handler.delete(save_location='fake_location')
    with pytest.raises(KeyError):
        handler.get_keys(save_location='fake_location')
示例#17
0
def test_calc_cartesion_points():
    """calc_cartesian_points returns joint, link and end-effector positions
    with the expected shapes for a stub robot config."""

    class fake_robot_config:
        """Minimal stand-in exposing only N_JOINTS, N_LINKS and Tx."""

        def __init__(self):
            self.N_JOINTS = 3
            self.N_LINKS = 2

        def Tx(self, name, q):
            assert len(q) == self.N_JOINTS
            return [1, 2, 3]

    robot_config = fake_robot_config()

    # number of time steps
    steps = 10

    # passing in the right dimensions of joint angles
    q = np.zeros((steps, robot_config.N_JOINTS))
    expected_shapes = [
        (steps, robot_config.N_JOINTS, 3),
        (steps, robot_config.N_LINKS, 3),
        (steps, 3),
    ]

    data = proc.calc_cartesian_points(robot_config=robot_config, q=q)

    # Plain `assert cond, msg`: wrapping both in parentheses asserts a
    # two-element tuple, which is always truthy and can never fail.
    for ii, expected in enumerate(expected_shapes):
        received = np.asarray(data[ii]).shape
        assert received == expected, (
            "Expected %s Received %s" % (expected, received))
def review(save_name, ideal_function, num_to_plot=10):
    '''
    loads the data from save name and gets num_to_plot tests that most
    closely match the ideal function that was passed in during the scan

    PARAMETERS
    ----------
    save_name: string
        the location in the intercepts_scan database to load from
    ideal_function: callable
        called on each x value of a run to get the desired profile to
        compare against. The review function will use this to find the
        closest matching results
    num_to_plot: int, Optional (Default: 10)
        the number of tests to find that most closely match the ideal
    '''
    dat = DataHandler('intercepts_scan')

    # the stored 'ideal' is loaded alongside the header, but runs are scored
    # against ideal_function evaluated on each run's own x values
    ideal_data = dat.load(parameters=['ideal', 'total_intercepts'],
                          save_location='%s' % save_name)
    num = ideal_data['total_intercepts']

    if num_to_plot > num:
        print('Only %i runs to plot' % num)
        num_to_plot = num

    run_data = []
    errors = []
    n_bins = 30
    for ii in range(0, num):
        data = dat.load(
            parameters=['intercept_bounds', 'intercept_mode',
                        'y', 'error', 'num_active',
                        'num_inactive', 'title'],
            save_location='%s/%05d' % (save_name, ii))

        if data['title'] == 'proportion_time_neurons_active':
            y, bins_out = np.histogram(np.squeeze(data['y']),
                                       bins=np.linspace(0, 1, n_bins))
            # plot the histogram counts at the bin centres
            data['x'] = 0.5*(bins_out[1:]+bins_out[:-1])
            data['y'] = y
        else:
            data['x'] = np.cumsum(np.ones(len(data['y'])))

        ideal = np.array([ideal_function(x) for x in data['x']])
        # squeeze so a column vector y of shape (n, 1) is not broadcast
        # against the (n,) ideal into an (n, n) difference matrix
        error = np.sum(np.abs(ideal - np.squeeze(data['y'])))

        run_data.append(data)
        errors.append(error)

    if not run_data:
        print('No runs to plot')
        return

    indices = np.array(errors).argsort()[:num_to_plot]
    print('Plotting...')
    plt.figure()
    for ind in indices:
        data = run_data[ind]
        label = ('%i: err:%.2f \n%s: %s' %
                 (ind, errors[ind], data['intercept_bounds'],
                  data['intercept_mode']))
        if data['title'] == 'proportion_time_neurons_active':
            plt.bar(data['x'], data['y'], width=1/(2*n_bins),
                    edgecolor='white', alpha=0.5, label=label)
        else:
            plt.plot(np.squeeze(data['x']), np.squeeze(data['y']),
                     label=label)

    plt.title(data['title'])
    # evaluate the ideal on the x values it is drawn against so both always
    # have the same length
    ideal = [ideal_function(x) for x in data['x']]
    plt.plot(np.squeeze(data['x']), ideal, c='k', lw=3, linestyle='--',
             label='ideal')
    plt.legend()
示例#19
0
def review(save_name, ideal_function, num_to_plot=10, db_name="intercepts_scan"):
    """
    loads the data from save name and gets num_to_plot tests that most
    closely match the ideal function that was passed in during the scan

    PARAMETERS
    ----------
    save_name: string
        the location in the db_name database to load from
    ideal_function: callable
        called on each x value of a run to get the desired profile to
        compare against. The review function will use this to find the
        closest matching results
    num_to_plot: int, Optional (Default: 10)
        the number of tests to find that most closely match the ideal
    db_name: string, Optional (Default: "intercepts_scan")
        the database holding the scan
    """
    dat = DataHandler(db_name)

    # the stored "ideal" is loaded alongside the header, but runs are scored
    # against ideal_function evaluated on each run's own x values
    ideal_data = dat.load(
        parameters=["ideal", "total_intercepts"], save_location="%s" % save_name
    )
    num = ideal_data["total_intercepts"]

    if num_to_plot > num:
        print("Only %i runs to plot" % num)
        num_to_plot = num

    run_data = []
    errors = []
    n_bins = 30
    for ii in range(0, num):
        data = dat.load(
            parameters=[
                "intercept_bounds",
                "intercept_mode",
                "y",
                "error",
                "num_active",
                "num_inactive",
                "title",
            ],
            save_location="%s/%05d" % (save_name, ii),
        )

        if data["title"] == "proportion_time_neurons_active":
            y, bins_out = np.histogram(
                np.squeeze(data["y"]), bins=np.linspace(0, 1, n_bins)
            )
            # plot the histogram counts at the bin centres
            data["x"] = 0.5 * (bins_out[1:] + bins_out[:-1])
            data["y"] = y
        else:
            data["x"] = np.cumsum(np.ones(len(data["y"])))

        ideal = np.array([ideal_function(x) for x in data["x"]])
        # squeeze so a column vector y of shape (n, 1) is not broadcast
        # against the (n,) ideal into an (n, n) difference matrix
        error = np.sum(np.abs(ideal - np.squeeze(data["y"])))

        run_data.append(data)
        errors.append(error)

    if not run_data:
        print("No runs to plot")
        return

    indices = np.array(errors).argsort()[:num_to_plot]
    print("Plotting...")
    plt.figure()
    for ind in indices:
        data = run_data[ind]
        label = "%i: err:%.2f \n%s: %s" % (
            ind,
            errors[ind],
            data["intercept_bounds"],
            data["intercept_mode"],
        )
        if data["title"] == "proportion_time_neurons_active":
            plt.bar(
                data["x"],
                data["y"],
                width=1 / (2 * n_bins),
                edgecolor="white",
                alpha=0.5,
                label=label,
            )
        else:
            plt.plot(np.squeeze(data["x"]), np.squeeze(data["y"]), label=label)

    plt.title(data["title"])
    # evaluate the ideal on the x values it is drawn against so both always
    # have the same length
    ideal = [ideal_function(x) for x in data["x"]]
    plt.plot(np.squeeze(data["x"]), ideal, c="k", lw=3, linestyle="--", label="ideal")
    plt.legend()
示例#20
0
import numpy as np
import pytest

from abr_analyze.plotting import TrajectoryError
from abr_analyze.data_handler import DataHandler
from abr_analyze.utils import random_trajectories

# Shared fixture data for the tests in this module, written once at import.
# handle on the 'test' database the tests below read from
dat = DataHandler('test')
# run location the data is saved to; also the default save_location argument
# of the tests below
save_location = 'traj_err_test/session000/run000'
# generate a random trajectory and an ideal
data = random_trajectories.generate(steps=100)
# generate another trajectory and ideal so we can test passing in a custom ideal
data_alt = random_trajectories.generate(steps=100)
# save the second ideal to the first data dict
data['alt_traj'] = data_alt['ideal_trajectory']
# overwrite=True so re-running the tests replaces the previous fixture data
dat.save(data=data, save_location=save_location, overwrite=True)


@pytest.mark.parametrize('ideal', ((None), ('ideal_trajectory'), ('alt_traj')))
def test_calculate_error(ideal, save_location=save_location):
    """calculate_error matches the Euclidean error computed by hand."""
    dat = DataHandler('test')
    if ideal is None:
        ideal = 'ideal_trajectory'
    fake_data = dat.load(parameters=[ideal, 'ee_xyz'],
                         save_location=save_location)
    manual_error = np.linalg.norm((fake_data['ee_xyz'] - fake_data[ideal]),
                                  axis=1)

    traj = TrajectoryError(db_name='test',
                           time_derivative=0,
                           interpolated_samples=None)
    # without this call and comparison the test asserted nothing
    data = traj.calculate_error(save_location=save_location, ideal=ideal)

    assert np.array_equal(manual_error, data['error'])
def run(encoders, intercept_vals, input_signal, seed=1,
        db_name='intercepts_scan', save_name='example', notes='',
        analysis_fncs=None, **kwargs):
    '''
    runs a scan for the proportion of neurons that are active over time

    PARAMETERS
    ----------
    encoders: array of floats (3D)
        the values that specify along what vector a neuron will be
        sensitive to. encoders.shape[1] is used as the number of neurons and
        encoders.shape[2] as the number of inputs (the leading axis is
        presumably the ensemble — TODO confirm)
    intercept_vals: array of floats (n_intercepts to try x 3)
        [left_bound, right_bound, mode] (indices 0, 1, 2) of the triangular
        distribution the intercepts are drawn from
    input_signal: array of floats (n_timesteps x n_inputs)
        the input signal that we want to check our networks response to
    seed: int, Optional (Default: 1)
        the seed used for any randomization in the sim
    db_name: string, Optional (Default: 'intercepts_scan')
        the database to save the results to
    save_name: string, Optional (Default: 'example')
        the name to save the data under in the db_name database
    notes: string, Optional (Default: '')
        any additional notes to save with the scan
    analysis_fncs: function, or list / tuple of functions
        network_utils functions to apply to the spike trains. Each is called
        with network, input_signal and pscs keyword arguments and must return
        (y, activity). Results are saved under save_name/<function name>
    **kwargs:
        passed on to signals.DynamicsAdaptation

    RAISES
    ------
    ValueError
        if analysis_fncs is None
    '''
    if analysis_fncs is None:
        # fail before the first (expensive) simulation instead of after it
        raise ValueError('analysis_fncs must be a function or a list of '
                         'functions')
    if not isinstance(analysis_fncs, (list, tuple)):
        analysis_fncs = [analysis_fncs]

    print('Input Signal Shape: ', np.asarray(input_signal).shape)

    n_intercepts = len(intercept_vals)
    dat = DataHandler(db_name)

    loop_time = 0
    elapsed_time = 0
    for ii, intercept in enumerate(intercept_vals):
        start = timeit.default_timer()
        elapsed_time += loop_time
        print('%i/%i | ' % (ii+1, n_intercepts)
              + '%.2f%% Complete | ' % (ii/n_intercepts*100)
              + '%.2f min elapsed | ' % (elapsed_time/60)
              + '%.2f min for last sim | ' % (loop_time/60)
              + '~%.2f min remaining...'
              % ((n_intercepts-ii)*loop_time/60),
              end='\r')

        # create our intercept distribution from the intercepts vals
        # Generates intercepts for a d-dimensional ensemble, such that, given a
        # random uniform input (from the interior of the d-dimensional ball),
        # the probability of a neuron firing has the probability density
        # function given by rng.triangular(left, mode, right, size=n)
        np.random.seed(seed)
        triangular = np.random.triangular(
            # intercept_vals = [left, right, mode]
            left=intercept[0],
            right=intercept[1],
            mode=intercept[2],
            size=encoders.shape[1],
        )
        intercepts = nengo.dists.CosineSimilarity(
            encoders.shape[2] + 2).ppf(1 - triangular)
        intercept_list = intercepts.reshape((1, encoders.shape[1]))

        print()
        print(intercept)
        print(intercept_list)

        # create a network with the new intercepts
        network = signals.DynamicsAdaptation(
            n_input=encoders.shape[2],
            n_output=1,  # number of output is irrelevant
            n_neurons=encoders.shape[1],
            intercepts=intercept_list,
            seed=seed,
            encoders=encoders,
            **kwargs)

        # get the spike trains from the sim
        spike_trains = network_utils.get_activities(
            network=network, input_signal=input_signal,
            synapse=0.005)

        # loop through the analysis functions
        for func in analysis_fncs:
            func_name = func.__name__
            y, activity = func(network=network, input_signal=input_signal,
                               pscs=spike_trains)

            # get the number of active and inactive neurons
            num_active, num_inactive = (
                network_utils.n_neurons_active_and_inactive(activity=activity))

            if ii == 0:
                # scan header for this analysis, written once per scan
                dat.save(
                    data={'total_intercepts': n_intercepts,
                          'notes': notes},
                    save_location='%s/%s' % (save_name, func_name),
                    overwrite=True)

            # not saving activity because takes up a lot of disk space
            data = {'intercept_bounds': intercept[:2],
                    'intercept_mode': intercept[2],
                    'y': y,
                    'num_active': num_active,
                    'num_inactive': num_inactive,
                    'title': func_name
                    }
            dat.save(data=data, save_location='%s/%s/%05d' %
                     (save_name, func_name, ii), overwrite=True)

        # duration of the whole iteration (simulation + every analysis)
        loop_time = timeit.default_timer() - start
class LiveFigure():
    '''
    Class that handles the live plotting of the figure in the gui

    GUI state (button flags, plot colours, the intercept entry widgets and the
    intercept -> test-number lookup) is shared through module-level globals.
    '''
    def __init__(self, db_name, save_location):
        '''
        PARAMETERS
        ----------
        db_name: string
            the name of the database that holds the intercept scans
        save_location: string
            the location in the database the data is saved
        '''
        self.save_location = save_location
        self.dat = DataHandler(db_name)
        self.fig = Figure()  # figsize=(10,12), dpi=100)
        self.a = self.fig.add_subplot(111)
        # number of histogram bin edges for proportion_time_neurons_active
        self.bins = 30
        # NOTE(review): loaded but not drawn by plot(); toggle_ideal_val is
        # declared there but unused — confirm whether plotting it is pending
        self.ideal = self.dat.load(
                parameters=['ideal'],
                save_location=self.save_location)['ideal']
        # tests the user chose to keep on the plot
        self.test_que = []

    def plot(self, i):
        '''
        plots the data specified in the global variables that get updated with
        the functions linked to the gui buttons

        PARAMETERS
        ----------
        i: int
            unused (presumably the frame index supplied by an animation
            callback — TODO confirm)
        '''
        global save_val
        global keep_test_val
        global toggle_ideal_val
        global clear_plot_val
        global legend_loc_val
        global update_plot
        global plt_col

        if not update_plot:
            return
        update_plot = False

        # reuse the axes created in __init__: calling add_subplot(111) again
        # may add another, overlapping axes on newer matplotlib versions
        a = self.a

        if clear_plot_val:
            self.test_que = []
            a.clear()
            clear_plot_val = False
            return

        intercept_keys = self.get_intercept_vals_from_buttons()
        data = self.load(keys=intercept_keys)
        if data is None:
            return

        a.clear()
        label = '(%.1f, %.1f), %.1f\nACTV:%i | INACTV:%i | %.2f' % (
            data['intercept_bounds'][0],
            data['intercept_bounds'][1],
            data['intercept_mode'],
            data['num_active'],
            data['num_inactive'],
            data['y_mean'])
        is_histogram = data['title'] == 'proportion_time_neurons_active'
        if is_histogram:
            a.set_xlim(0, 1)
            a.bar(data['x'], data['y'], width=1/(2*self.bins),
                  label=label, edgecolor='white',
                  alpha=0.5, color=plt_col[0])
        else:
            a.set_ylim(0, 1)
            a.plot(data['x'], data['y'], label=label, c=plt_col[0])
        a.set_title(data['title'])

        if keep_test_val:
            self.test_que.append(
                {'label': label, 'y': data['y'], 'x': data['x']})
            keep_test_val = False

        # overlay the kept tests; plt_col grows with None (default colour)
        # when there are more kept tests than configured colours
        for aa, test in enumerate(self.test_que):
            if aa+1 >= len(plt_col):
                plt_col.append(None)
            if is_histogram:
                a.bar(test['x'], test['y'],
                      width=1/(2*self.bins),
                      label=test['label'],
                      edgecolor='white', alpha=0.5,
                      color=plt_col[aa+1])
            else:
                a.plot(test['x'], test['y'],
                       label=test['label'], c=plt_col[aa+1])
        a.legend(loc=legend_loc_val % 4 + 1)

        if save_val:
            save_path = '%s/intercept_scan_%s.png' % (figures_dir,
                                                      data['title'])
            a.figure.savefig(save_path)
            print('Figure saved to: ' + save_path)
            # TODO: make the font smaller while this is printed out
            save_val = False

    def get_intercept_vals_from_buttons(self):
        '''
        updates the intercept values that are linked to a specific test based
        on the values in the text boxes

        RETURNS
        -------
        list with the .get() value of each entry in the global intercept_vals
        (presumably gui text variables — TODO confirm)
        '''
        global intercept_vals
        return [val.get() for val in intercept_vals]

    def load(self, keys):
        '''
        loads the required data for plotting from the test specified by the
        provided keys

        PARAMETERS
        ----------
        keys: list of 3 strings
            this list is created in the intercept scanner page and it links the
            test numbers to their intercept ranges and mode

        RETURNS
        -------
        dict or None
            the loaded test with 'x' and 'y_mean' added ('y' replaced by
            histogram counts for proportion_time_neurons_active), or None if
            the test could not be found or loaded
        '''
        global key_dict
        global valid_val
        try:
            test_num = int(key_dict[keys[0]][keys[1]][keys[2]])
            data = self.dat.load(
                    parameters=['intercept_bounds', 'intercept_mode', 'y',
                                'title', 'num_active', 'num_inactive'],
                    save_location='%s/%05d' % (self.save_location, test_num))
            # get the mean before converting y to a histogram
            data['y_mean'] = np.mean(data['y'])
            if data['title'] == 'proportion_time_neurons_active':
                y, bins_out = np.histogram(
                    np.squeeze(data['y']), bins=np.linspace(0, 1, self.bins))
                data['x'] = 0.5*(bins_out[1:]+bins_out[:-1])
                data['y'] = y
            else:
                data['x'] = np.cumsum(np.ones(len(data['y'])))
            valid_val.set('Valid')
            return data
        except Exception:
            # best-effort: any failure to look up, load or process the
            # selected test marks the selection invalid. Exception rather
            # than a bare except so KeyboardInterrupt/SystemExit propagate.
            print('Test does not exist')
            valid_val.set('Invalid')
            return None
class TrajectoryError:
    '''
    Compute, cache and plot the error between saved end-effector trajectories
    and a reference (ideal) trajectory stored in the same database.
    '''

    def __init__(self, db_name, time_derivative=0, interpolated_samples=100):
        '''
        PARAMETERS
        ----------
        db_name: string
            the name of the database to load data from
        time_derivative: int, Optional (Default: 0)
            how many times both trajectories are differentiated before the
            error is computed
            0: position
            1: velocity
            2: acceleration
            3: jerk
        interpolated_samples: positive int, Optional (Default=100)
            the number of samples to take (evenly) from the interpolated data
            if set to None, no interpolated or sampling will be done, the raw
            data will be returned
        '''
        # instantiate our data processor
        self.dat = DataHandler(db_name)
        self.db_name = db_name
        self.time_derivative = time_derivative
        self.interpolated_samples = interpolated_samples

    def _stat_location(self, save_location):
        '''Return the database group that caches the statistical error.'''
        # single source of truth so the existence check, load and save all
        # use the same path (the check previously used '%s' instead of '%i')
        return '%s/statistical_error_%i' % (save_location, self.time_derivative)

    def statistical_error(self,
                          save_location,
                          ideal=None,
                          sessions=1,
                          runs=1,
                          save_data=True,
                          regen=False):
        '''
        calls the calculate error function to get the trajectory for all runs
        and sessions specified at the save location and calculates the mean
        error and confidence intervals

        The summed error of each run is collected per session and passed to
        ``proc.get_mean_and_ci``. Results are cached under
        ``<save_location>/statistical_error_<time_derivative>``. A cached
        result with a non-empty 'mean' is returned instead of recomputing,
        unless regen is True.

        PARAMETERS
        ----------
        save_location: string
            location of data in database
        ideal: string, Optional (Default: None)
            This tells the function what trajectory to calculate the error
            relative to
            None: use the saved filter data at save_location
            if string: key of Nx3 data in database to use
        sessions: int, Optional (Default: 1)
            the number of sessions to calculate error for
        runs: int, Optional (Default: 1)
            the number of runs in each session
        save_data: boolean, Optional (Default: True)
            True to save data, this saves the error for each session
        regen: boolean, Optional (Default: False)
            True to regenerate data
            False to load data if it exists

        RETURNS
        -------
        ci_errors: dict
            the output of ``proc.get_mean_and_ci`` with 'time_derivative'
            added, or the cached dict loaded from the database
        '''
        stat_location = self._stat_location(save_location)

        exists = False
        if not regen:
            # pass create=False explicitly so that merely checking for a
            # cached result cannot create an empty group (DataHandler's
            # default for ``create`` is not visible from this file)
            if self.dat.check_group_exists(location=stat_location,
                                           create=False):
                # NOTE(review): 'ee_xyz', 'ideal_trajectory', 'time',
                # 'read_location' and 'error' are keys of calculate_error's
                # output, which this method never saves here; confirm how
                # DataHandler.load treats keys missing from the group.
                ci_errors = self.dat.load(
                    parameters=[
                        'mean', 'upper_bound', 'lower_bound', 'ee_xyz',
                        'ideal_trajectory', 'time', 'time_derivative',
                        'read_location', 'error'
                    ],
                    save_location=stat_location)
                # an empty 'mean' means the cached group holds no results
                exists = len(ci_errors['mean']) > 0

        if not exists:
            total_runs = sessions * runs
            errors = []
            for session in range(sessions):
                session_error = []
                for run in range(runs):
                    completed = (run + 1) + session * runs
                    print('%.3f processing complete...' %
                          (100 * completed / total_runs),
                          end='\r')
                    loc = '%s/session%03d/run%03d' % (save_location, session,
                                                      run)
                    data = self.calculate_error(save_location=loc, ideal=ideal)
                    session_error.append(np.sum(data['error']))
                errors.append(session_error)

            ci_errors = proc.get_mean_and_ci(raw_data=errors)
            ci_errors['time_derivative'] = self.time_derivative

            if save_data:
                self.dat.save(data=ci_errors,
                              save_location=stat_location,
                              overwrite=True)

        return ci_errors

    def calculate_error(self, save_location, ideal=None):
        '''
        Loads 'ee_xyz', 'time' and a reference trajectory from save_location,
        using ``proc.load_and_process`` for loading and interpolation. Both
        trajectories are differentiated ``self.time_derivative`` times, and
        the per-timestep two-norm error between them is returned.

        If ideal is not passed in, the reference is assumed to be saved in
        save_location under the key 'ideal_trajectory'. Only its first three
        columns are used.

        the following dict is returned
        data = {
            'ee_xyz': end-effector trajectory (n_timesteps, xyz),
                differentiated time_derivative times,
            <ideal>: reference trajectory (n_timesteps, xyz), differentiated
                the same way,
            'time': list of timesteps (n_timesteps),
            'time_derivative': int, the order of differentiation applied,
            'read_location': string, the location the raw data was loaded from,
            'error': the two-norm error between the end-effector trajectory and
                the reference trajectory at each timestep
        }

        PARAMETERS
        ----------
        save_location: string
            location of data in database
        ideal: string, Optional (Default: None)
            This tells the function what trajectory to calculate the error
            relative to
            None: use the saved filter data at save_location
            if string: key of Nx3 data in database to use
        '''
        if ideal is None:
            ideal = 'ideal_trajectory'

        # load and interpolate data
        data = proc.load_and_process(
            db_name=self.db_name,
            save_location=save_location,
            parameters=['ee_xyz', 'time', ideal],
            interpolated_samples=self.interpolated_samples)

        if ideal == 'ideal_trajectory':
            # only the first three columns of the saved path are compared
            data['ideal_trajectory'] = data['ideal_trajectory'][:, :3]

        # mean of the 'time' values, used as the uniform spacing for
        # np.gradient; assumes 'time' holds per-step durations — TODO confirm
        # for interpolated data
        dt = np.sum(data['time']) / len(data['time'])

        # differentiate the trajectories time_derivative times
        if self.time_derivative > 0:
            # dict.fromkeys de-duplicates while keeping order, so a key is
            # never differentiated twice even if ideal == 'ee_xyz'
            for key in dict.fromkeys(('ee_xyz', ideal)):
                if key not in data:
                    continue
                for _ in range(self.time_derivative):
                    # column-wise gradient of x, y, z along the time axis
                    data[key][:, :3] = np.gradient(data[key][:, :3], dt,
                                                   axis=0)

        data['time_derivative'] = self.time_derivative
        data['read_location'] = save_location
        data['error'] = np.linalg.norm((data['ee_xyz'] - data[ideal]), axis=1)

        return data

    def plot(self,
             ax,
             save_location,
             step=-1,
             c=None,
             linestyle='--',
             label=None,
             loc=1,
             title='Trajectory Error to Path Planner'):
        '''
        Plots the cached mean error and confidence interval produced by
        statistical_error for this time_derivative onto ax, using
        ``vis.plot_mean_and_ci``.

        PARAMETERS
        ----------
        ax: matplotlib axis
            the axis to plot on
        save_location: string
            the same save_location that was passed to statistical_error
        step: int, Optional (Default: -1)
            accepted for backward compatibility; not used
        c, linestyle, label, loc, title: Optional
            forwarded unchanged to ``vis.plot_mean_and_ci``
        '''
        data = self.dat.load(parameters=['mean', 'upper_bound', 'lower_bound'],
                             save_location=self._stat_location(save_location))
        vis.plot_mean_and_ci(ax=ax,
                             data=data,
                             c=c,
                             linestyle=linestyle,
                             label=label,
                             loc=loc,
                             title=title)
    def __init__(self, parent, db_name, save_location, *args):
        '''
        Builds the intercept-scan viewer page. It contains:
        - a plot canvas;
        - keep / clear / save / legend buttons;
        - up/down selectors with entry boxes for the intercept left bound,
          mode and right bound;
        - a label reporting whether the selected test exists.

        It also (re)initialises the module-level GUI state declared ``global``
        below. ``key_dict`` is filled with a nested lookup:
        left bound -> right bound -> mode -> '%05d' test number.

        PARAMETERS
        ----------
        parent: tk widget
            the widget this frame is created in
        db_name: string
            the name of the database that holds the intercept scans
        save_location: string
            the location in the database the data is saved
        *args:
            accepted but not used here
        '''
        global key_dict
        global intercept_vals
        global save_val
        global keep_test_val
        global toggle_ideal_val
        global clear_plot_val
        global legend_loc_val
        global update_plot
        global plt_col
        global valid_val

        save_val = False
        keep_test_val = False
        toggle_ideal_val = True
        clear_plot_val = False
        legend_loc_val = 1
        update_plot = True
        # step size of the up / down intercept buttons
        mode_step = 0.1
        bound_step = 0.1
        plt_col = ['r', 'b', 'y', 'k', 'tab:grey']

        # instantiate our item creating class
        self.create = GuiItems()

        # instantiate our gui parameter class
        self.pars = FontsAndColors()

        # instantiate our button function class
        self.button = ButtonFun()

        # instantiate our data loading class and get defaults
        self.dat = DataHandler(db_name=db_name)
        data = self.dat.load(
                parameters=['ideal', 'total_intercepts'],
                save_location=save_location)
        self.ideal = data['ideal']
        runs = data['total_intercepts']

        # the entry boxes start at the last test's values; fall back to empty
        # strings so a scan without tests does not leave these unbound
        left_bound = right_bound = mode = ''
        key_dict = {}
        for ii in range(runs):
            data = self.dat.load(
                    parameters=['intercept_bounds', 'intercept_mode'],
                    save_location='%s/%05d' % (save_location, ii))
            left_bound = '%.1f' % data['intercept_bounds'][0]
            right_bound = '%.1f' % data['intercept_bounds'][1]
            mode = '%.1f' % data['intercept_mode']

            # nested left -> right -> mode lookup; if a combination repeats,
            # the first test number seen is kept
            key_dict.setdefault(left_bound, {}).setdefault(
                right_bound, {}).setdefault(mode, '%05d' % ii)

        # instantiate our frame
        tk.Frame.__init__(self, parent)
        self.configure(background=self.pars.background_color)

        # NOTE(review): tk.Frame's second positional parameter is ``cnf``, so
        # tk.Frame(self, parent) copies parent's configuration options into
        # each sub-frame instead of naming a master; tk.Frame(self) is
        # probably what was meant — confirm before changing.

        # instantiate the cells in our main frame
        # CELL 1: cell for our plot
        frame_plot = tk.Frame(self, parent)
        frame_plot.grid(row=0, column=0, padx=10)
        frame_plot.configure(background=self.pars.background_color)

        # CELL 2: cell for our save / hold / clear buttons
        frame_plot_buttons = tk.Frame(self, parent)
        frame_plot_buttons.grid(row=0, column=1, padx=10)
        # previously configured frame_plot here (copy-paste slip)
        frame_plot_buttons.configure(background=self.pars.background_color)

        # CELL 3: cell for our intercept setting buttons and text boxes
        frame_intercept_val = tk.Frame(self, parent)
        frame_intercept_val.grid(row=1, column=0, padx=10)
        frame_intercept_val.configure(background=self.pars.background_color)

        # CELL 4: frame for printouts / notes to user
        frame_notes = tk.Frame(self, parent)
        frame_notes.grid(row=1, column=1, padx=10)
        frame_notes.configure(background=self.pars.background_color)

        # create our plotting buttons
        self.create.button(
            frame=frame_plot_buttons,
            text='Keep Test',
            function=lambda: self.button.keep_test(self),
            row=0, col=0)

        self.create.button(
            frame=frame_plot_buttons,
            text='Clear Plot',
            function=lambda: self.button.clear_plot(self),
            row=1, col=0)

        # row 2 is left free for the disabled 'Toggle Ideal' button:
        # self.create.button(
        #     frame=frame_plot_buttons,
        #     text='Toggle Ideal',
        #     function=lambda: self.button.toggle_ideal(self),
        #     row=2, col=0)

        self.create.button(
            frame=frame_plot_buttons,
            text='Save',
            function=lambda: self.button.save(self),
            row=3, col=0)

        self.create.button(
            frame=frame_plot_buttons,
            text='Legend Loc',
            function=lambda: self.button.legend_loc(self),
            row=4, col=0)

        # create our string variables for intercept values
        left_bound_val = tk.StringVar()
        left_bound_val.set(left_bound)
        right_bound_val = tk.StringVar()
        right_bound_val.set(right_bound)
        mode_val = tk.StringVar()
        mode_val.set(mode)

        # set our intercept_vals to a starting value
        intercept_vals = [left_bound_val, right_bound_val, mode_val]

        # create our intercept setting buttons ('\\/' is a literal
        # backslash-slash; the old '\/' was an invalid escape sequence)
        self.create.button(
            frame=frame_intercept_val,
            text='/\\',
            function=lambda: self.button.val_up(left_bound_val, bound_step),
            row=0, col=0)

        self.create.button(
            frame=frame_intercept_val,
            text='\\/',
            function=lambda: self.button.val_down(left_bound_val, bound_step),
            row=1, col=0)

        self.create.button(
            frame=frame_intercept_val,
            text='/\\',
            function=lambda: self.button.val_up(right_bound_val, bound_step),
            row=0, col=4)

        self.create.button(
            frame=frame_intercept_val,
            text='\\/',
            function=lambda: self.button.val_down(right_bound_val, bound_step),
            row=1, col=4)

        self.create.button(
            frame=frame_intercept_val,
            text='/\\',
            function=lambda: self.button.val_up(mode_val, mode_step),
            row=0, col=2)

        self.create.button(
            frame=frame_intercept_val,
            text='\\/',
            function=lambda: self.button.val_down(mode_val, mode_step),
            row=1, col=2)

        # create our intercept setting entry box
        self.create.entry_box(
                frame=frame_intercept_val,
                text_var=left_bound_val,
                row=0, col=1)

        self.create.entry_box(
                frame=frame_intercept_val,
                text_var=right_bound_val,
                row=0, col=5)

        self.create.entry_box(
                frame=frame_intercept_val,
                text_var=mode_val,
                row=0, col=3)

        # labels for entry boxes
        self.create.label(
                frame=frame_intercept_val,
                text='Left',
                row=1, col=1)

        self.create.label(
                frame=frame_intercept_val,
                text='Right',
                row=1, col=5)

        self.create.label(
                frame=frame_intercept_val,
                text='Mode',
                row=1, col=3)

        # label that triggers if test with selected values does not exist
        valid_val = tk.StringVar()
        valid_val.set('')
        self.create.label(
                frame=frame_notes,
                textvariable=valid_val,
                row=1, col=1,
                font=self.pars.XL_FONT)

        # Plotting Window
        # NOTE(review): the canvas is gridded into self at row 0 / column 0,
        # the same cell as frame_plot; FigureCanvasTkAgg(live.fig, frame_plot)
        # may have been intended — confirm.
        canvas = FigureCanvasTkAgg(live.fig, self)
        canvas.draw()
        canvas.get_tk_widget().grid(row=0, column=0)
        canvas.get_tk_widget().configure(background=self.pars.background_color)
示例#25
0
def run(
    intercept_vals,
    input_signal,
    seed=1,
    db_name="intercepts_scan",
    save_name="example",
    notes="",
    encoders=None,
    analysis_fncs=None,
    network_class=None,
    network_ens_type=None,
    force_params=None,
    angle_params=None,
    means=None,
    variances=None,
    **kwargs
):
    """
    runs a scan for the proportion of neurons that are active over time

    For every intercept triplet, a network is built and its activity in
    response to input_signal is recorded. Each analysis function is applied,
    and its results are saved to ``<save_name>/<function name>/<ii:05d>``
    in db_name.

    PARAMETERS
    ----------
    intercept_vals: array of floats (n_intercepts to try x 3)
        the [left_bound, right_bound, mode] (in that column order) of the
        triangular distribution the intercepts are sampled from
    input_signal: array of floats (n_timesteps x n_inputs)
        the input signal that we want to check our networks response to
    seed: int, Optional (Default: 1)
        the seed passed on to network_class
    db_name: string, Optional (Default: 'intercepts_scan')
        the database the results are saved to
    save_name: string, Optional (Default: 'example')
        the name to save the data under in the db_name database
    notes: string, Optional (Default: '')
        any additional notes to save with the scan
    encoders: Optional
        accepted for backward compatibility; not used, because the ensemble
        to scan is chosen by network_ens_type
    analysis_fncs: function or list of functions
        applied to the spike trains; each is called with the keyword
        arguments pscs, n_neurons and n_ensembles and must return
        (y, activity)
    network_class: class, Optional (Default: signals.DynamicsAdaptation)
        the network to build for each intercept triplet
    network_ens_type: string, 'force' or 'angle'
        which ensemble to scan. Its params dict supplies 'n_neurons',
        'n_ensembles', 'n_input' and 'tau_output'. It also receives the
        sampled intercepts under the 'intercepts' key.
    force_params, angle_params: dicts
        passed on to network_class
    means, variances: Optional
        passed on to network_class
    **kwargs:
        accepted but not used

    RAISES
    ------
    ValueError
        if network_ens_type is not 'force' or 'angle'
    """
    # params dict of each ensemble type that can be scanned
    ens_params = {"force": force_params, "angle": angle_params}
    if network_ens_type not in ens_params:
        # without an ensemble to record from, the scan cannot run (network_ens
        # and synapse would be unbound), so fail early with a clear message
        raise ValueError(
            "network_ens_type must be 'force' or 'angle', got %r"
            % (network_ens_type,)
        )
    params = ens_params[network_ens_type]

    if network_class is None:
        network_class = signals.DynamicsAdaptation
    if not isinstance(analysis_fncs, list):
        analysis_fncs = [analysis_fncs]

    n_neurons = params["n_neurons"]
    n_ensembles = params["n_ensembles"]
    n_input = params["n_input"]

    print("Running intercepts scan on %s" % network_class.__name__)
    print("Input Signal Shape: ", np.asarray(input_signal).shape)

    n_scans = len(intercept_vals)
    loop_time = 0
    elapsed_time = 0
    for ii, intercept in enumerate(intercept_vals):
        start = timeit.default_timer()
        elapsed_time += loop_time
        print(
            "%i/%i | " % (ii + 1, n_scans)
            + "%.2f%% Complete | " % (ii / n_scans * 100)
            + "%.2f min elapsed | " % (elapsed_time / 60)
            + "%.2f min for last sim | " % (loop_time / 60)
            + "~%.2f min remaining..." % ((n_scans - ii) * loop_time / 60),
            end="\r",
        )

        # NOTE(review): sampled from numpy's global RNG, which ``seed`` does
        # not seed, so scans are only repeatable if the caller seeds numpy
        triangular = np.random.triangular(
            left=intercept[0],
            right=intercept[1],
            mode=intercept[2],
            size=n_neurons * n_ensembles,
        )
        intercepts = nengo.dists.CosineSimilarity(n_input + 2).ppf(1 - triangular)
        intercepts = intercepts.reshape((n_ensembles, n_neurons))

        # give the sampled intercepts to the ensemble being scanned (this
        # previously always went into force_params, even for 'angle' scans)
        params["intercepts"] = intercepts

        # get the spike trains from the sim
        network = network_class(
            force_params=force_params,
            angle_params=angle_params,
            means=means,
            variances=variances,
            seed=seed,
        )
        if network_ens_type == "force":
            network_ens = network.force_ens
        else:
            network_ens = network.angle_ens

        spike_trains = network_utils.get_activities(
            network=network,
            network_ens=network_ens,
            input_signal=input_signal,
            synapse=params["tau_output"],
        )

        for func in analysis_fncs:
            func_name = func.__name__
            y, activity = func(
                pscs=spike_trains, n_neurons=n_neurons, n_ensembles=n_ensembles
            )

            # get the number of active and inactive neurons
            num_active, num_inactive = network_utils.n_neurons_active_and_inactive(
                activity=activity
            )

            if ii == 0:
                # write the scan header once per analysis function
                dat = DataHandler(db_name)
                dat.save(
                    data={"total_intercepts": n_scans, "notes": notes},
                    save_location="%s/%s" % (save_name, func_name),
                    overwrite=True,
                )

            # not saving activity because takes up a lot of disk space
            data = {
                "intercept_bounds": intercept[:2],
                "intercept_mode": intercept[2],
                "y": y,
                "num_active": num_active,
                "num_inactive": num_inactive,
                "title": func_name,
            }
            dat.save(
                data=data,
                save_location="%s/%s/%05d" % (save_name, func_name, ii),
                overwrite=True,
            )

        # wall time of this whole intercept iteration (sim + all analyses)
        loop_time = timeit.default_timer() - start
示例#26
0
def test_save(data, overwrite):
    """Saving the supplied ``data`` to 'test_saving' must not raise for the
    given ``overwrite`` setting."""
    handler = DataHandler('tests')
    handler.save(
        data=data,
        save_location='test_saving',
        overwrite=overwrite,
    )
示例#27
0
def test_save_error():
    """Saving to 'test_saving' with overwrite=False is expected to raise
    (presumably because an earlier test already populated that location)."""
    handler = DataHandler('tests')
    payload = {'bool': True}
    with pytest.raises(Exception):
        handler.save(data=payload, save_location='test_saving',
                     overwrite=False)
示例#28
0
def test_save_type_error():
    """Passing a bare array instead of a dict as ``data`` must raise
    TypeError."""
    handler = DataHandler('tests')
    not_a_dict = np.ones(10)
    with pytest.raises(TypeError):
        handler.save(data=not_a_dict, save_location='test_saving',
                     overwrite=True)