def test_returns_correct_distance_for_multiunits(self):
    a0 = neo.SpikeTrain(sp.array([1.0, 5.0, 7.0]) * pq.s, t_stop=8.0 * pq.s)
    a1 = neo.SpikeTrain(sp.array([1.0, 2.0, 5.0]) * pq.s, t_stop=8.0 * pq.s)
    b0 = neo.SpikeTrain(sp.array([2.0, 4.0, 5.0]) * pq.s, t_stop=8.0 * pq.s)
    b1 = neo.SpikeTrain(sp.array([3.0, 8.0]) * pq.s, t_stop=9.0 * pq.s)
    units = {0: [a0, a1], 1: [b0, b1]}
    reassignment_cost = 0.7
    expected = sp.array([[0.0, 4.4], [4.4, 0.0]])
    actual = stm.victor_purpura_multiunit_dist(units, reassignment_cost)
    assert_array_almost_equal(expected, actual)


import timeit
import pstats

import neo
import quantities as pq
import scipy as sp

import spykeutils.signal_processing as sigproc
import spykeutils.spike_train_metrics as stm

tau = 5.0 * pq.ms
sampling_rate = 1000.0 * pq.Hz  # assumed value for the kernel based metrics


def trains_as_multiunits(trains, trains_per_unit=None, num_units=2):
    """ Groups a flat list of spike trains into the dictionary of units
    expected by the multi-unit metrics. By default the trains are split
    evenly between two units. """
    if trains_per_unit is None:
        trains_per_unit = len(trains) // num_units
    units = {}
    for i in xrange(num_units):
        units[i] = trains[i * trains_per_unit:(i + 1) * trains_per_unit]
    return units
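

# Illustrative sketch (not part of the original script): with six spike
# trains, three trains per unit and two units, trains_as_multiunits returns
# {0: trains[0:3], 1: trains[3:6]}, i.e. the dictionary format used by the
# *_multiunit_dist functions and the tests above.
def _example_unit_grouping():
    trains = [
        neo.SpikeTrain(sp.array([float(i + 1)]) * pq.s, t_stop=10.0 * pq.s)
        for i in xrange(6)]
    return trains_as_multiunits(trains, trains_per_unit=3, num_units=2)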


metrics = {
    'vp': (r'$D_{\mathrm{V}}$ (multi-unit)',
           lambda units: stm.victor_purpura_multiunit_dist(
               units, 0.7, 2.0 / tau)),
    'vr': (r'$D_{\mathrm{R}}$ (multi-unit)',
           lambda units: stm.van_rossum_multiunit_dist(units, 0.5, tau))}


class BenchmarkData(object):
    def __init__(
            self, spike_count_range, train_count_range, num_units_range,
            firing_rate=50 * pq.Hz):
        self.spike_count_range = spike_count_range
        self.train_count_range = train_count_range
        self.num_units_range = num_units_range
        self.firing_rate = firing_rate


def test_raises_exception_if_number_of_trials_differs(self):
    st = create_empty_spike_train()
    with self.assertRaises(ValueError):
        stm.victor_purpura_multiunit_dist({0: [st], 1: [st, st]}, 1.0)


def test_returns_empty_array_if_trials_are_empty(self):
    expected = sp.zeros((0, 0))
    actual = stm.victor_purpura_multiunit_dist({0: [], 1: []}, 1.0)
    assert_array_equal(expected, actual)


def test_returns_empty_array_if_empty_dict_is_passed(self):
    expected = sp.zeros((0, 0))
    actual = stm.victor_purpura_multiunit_dist({}, 1.0)
    assert_array_equal(expected, actual)


def calc_metric(self, trains):
    return stm.victor_purpura_multiunit_dist({0: trains}, 1)


metrics = {
    'hm': ('Hunter-Milton similarity measure',
           lambda trains: stm.hunter_milton_similarity(trains, tau)),
    'norm': ('norm distance',
             lambda trains: stm.norm_dist(
                 trains, sigproc.CausalDecayingExpKernel(tau), sampling_rate)),
    'ss': ('Schreiber et al. similarity measure',
           lambda trains: stm.schreiber_similarity(
               trains, sigproc.GaussianKernel(tau), sort=False)),
    'vr': ('van Rossum distance',
           lambda trains: stm.van_rossum_dist(trains, tau, sort=False)),
    'vp': ('Victor-Purpura\'s distance',
           lambda trains: stm.victor_purpura_dist(trains, 2.0 / tau)),
    'vr_mu': ('van Rossum multi-unit distance',
              lambda trains: stm.van_rossum_multiunit_dist(
                  trains_as_multiunits(trains), 0.5, tau)),
    'vp_mu': ('Victor-Purpura\'s multi-unit distance',
              lambda trains: stm.victor_purpura_multiunit_dist(
                  trains_as_multiunits(trains), 0.3, 2.0 / tau))}
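

# Illustrative sketch (assumed usage, not part of the original script): time
# one of the metrics above with timeit. key is one of the keys printed by
# print_available_metrics() and trains is a list of neo.SpikeTrain objects.
def time_metric(key, trains, repetitions=3):
    label, metric = metrics[key]
    timer = timeit.Timer(lambda: metric(trains))
    return label, min(timer.repeat(repeat=repetitions, number=1))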


def print_available_metrics():
    for key in metrics:
        print "%s  (%s)" % (key, metrics[key][0])


def print_summary(profile_file):
    # Restrict the cumulative timing statistics to the public functions of
    # spike_train_metrics.py; names starting with '_' or '<' and names ending
    # in 'compute' are filtered out by the regular expression.
    stats = pstats.Stats(profile_file)
    stats.strip_dirs().sort_stats('cumulative').print_stats(
        r'^spike_train_metrics.py:\d+\([^_<].*(?<!compute)\)')
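

# Illustrative sketch (assumed usage, not part of the original script): one
# way to produce a profile file that print_summary() can read, using cProfile
# from the standard library. The file name is arbitrary.
def profile_single_metric(trains, key, profile_file='metric.profile'):
    import cProfile
    cProfile.runctx(
        'metric(trains)', {}, {'metric': metrics[key][1], 'trains': trains},
        profile_file)
    print_summary(profile_file)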


def profile_metrics(trains, to_profile):
    for key in to_profile: