Esempio n. 1
0
    def test_opt_1(self):
        """Optimise an ADEXP model against a cached Allen specimen and check the fit.

        Builds a ``DataTC`` from the mean of each ``MODEL_PARAMS`` entry for the
        chosen cell model and finds its rheobase. It then prepares the test
        suite through the cached ``opt_setup`` path and runs the dask-mapped
        optimiser (``NGEN`` generations, population ``MU``). It asserts that
        the best individual's summed fitness is below 0.7. Finally it plots the
        fitness history and the optimised and target ``vm15`` traces and calls
        ``check_bin_vm15`` on them.
        """
        specimen_id = self.ids[1]
        cellmodel = "ADEXP"

        if cellmodel == "IZHI":
            model = model_classes.IzhiModel()
        elif cellmodel == "MAT":
            model = model_classes.MATModel()
        elif cellmodel == "ADEXP":
            model = model_classes.ADEXPModel()
        else:
            raise ValueError("unsupported cell model: {}".format(cellmodel))

        target_num_spikes = 8
        dtc = DataTC()
        dtc.backend = cellmodel
        dtc._backend = model._backend
        dtc.attrs = model.attrs
        # Start from the mean of each MODEL_PARAMS entry (presumably a
        # [low, high] range per parameter -- TODO confirm).
        dtc.params = {
            k: np.mean(v)
            for k, v in MODEL_PARAMS[cellmodel].items()
        }

        dtc = dtc_to_rheo(dtc)
        self.assertIsNotNone(dtc.rheobase)
        # inject_and_plot_model also returns the plotting handle used below
        # (it is used like matplotlib.pyplot).
        _vm, plt, dtc = inject_and_plot_model(dtc, plotly=False)
        _setup_model, suite, nu_tests, target_current, spk_count = opt_setup(
            specimen_id,
            cellmodel,
            target_num_spikes,
            provided_model=model,
            fixed_current=False,
            cached=True)
        # Optimise a fresh model built from the rheobase-annotated dtc rather
        # than the one returned by opt_setup.
        model = dtc.dtc_to_model()
        model.seeded_current = target_current['value']
        model.allen = True
        model.NU = True
        cell_evaluator, _simple_cell = opt_setup_two(model,
                                                     cellmodel,
                                                     suite,
                                                     nu_tests,
                                                     target_current,
                                                     spk_count,
                                                     provided_model=model)
        NGEN = 15
        MU = 12

        _final_pop, hall_of_fame, logs, _hist = opt_exec(MU, NGEN,
                                                         dask_map_function,
                                                         cell_evaluator)
        opt, target = opt_to_model(hall_of_fame, cell_evaluator, suite,
                                   target_current, spk_count)
        best_ind = hall_of_fame[0]
        fitnesses = cell_evaluator.evaluate_with_lists(best_ind)
        # unittest assertion (not a bare assert) so it survives `python -O`.
        self.assertLess(np.sum(fitnesses), 0.7)

        # Fitness history per generation, one curve per statistic.
        gen_numbers = logs.select('gen')
        min_fitness = logs.select('min')
        max_fitness = logs.select('max')
        avg_fitness = logs.select('avg')
        plt.plot(gen_numbers, max_fitness, label='max fitness')
        plt.plot(gen_numbers, avg_fitness, label='avg fitness')
        plt.plot(gen_numbers, min_fitness, label='min fitness')
        plt.semilogy()
        plt.xlabel('generation #')
        plt.ylabel('score (# std)')
        plt.legend()
        plt.xlim(min(gen_numbers) - 1, max(gen_numbers) + 1)

        # Overlay the optimised trace on the recorded target trace.
        plt.plot(opt.vm15.times, opt.vm15)
        target.vm15 = suite.traces['vm15']
        plt.plot(target.vm15.times, target.vm15)
        check_bin_vm15(target, opt)
# Score the optimiser's best individual and plot the fitness history and traces.
# NOTE(review): relies on hall_of_fame, cell_evaluator, logs, plt, opt, target
# and suite being defined earlier (presumably a preceding notebook cell).
best_ind = hall_of_fame[0]
fitnesses = cell_evaluator.evaluate_with_lists(best_ind)
assert np.sum(fitnesses) < 0.55
obnames = [obj.name for obj in cell_evaluator.objectives]

# Fitness history per generation, one curve per statistic.
gen_numbers = logs.select('gen')
min_fitness = logs.select('min')
max_fitness = logs.select('max')
avg_fitness = logs.select('avg')
plt.plot(gen_numbers, max_fitness, label='max fitness')
plt.plot(gen_numbers, avg_fitness, label='avg fitness')
plt.plot(gen_numbers, min_fitness, label='min fitness')
plt.semilogy()
plt.xlabel('generation #')
plt.ylabel('score (# std)')
plt.legend()
plt.xlim(min(gen_numbers) - 1, max(gen_numbers) + 1)

model = opt.dtc_to_model()
# Overlay the optimised trace on the recorded target trace.
plt.plot(opt.vm15.times, opt.vm15)
target.vm15 = suite.traces['vm15']
plt.plot(target.vm15.times, target.vm15)
check_bin_vm15(target, opt)
opt.attrs  # notebook-style display of the optimised attributes; no effect in a script
Esempio n. 3
0
def opt_setup(specimen_id,
              cellmodel,
              target_num,
              provided_model=None,
              cached=None,
              fixed_current=False):
    """Load or build the Allen test suite and a current-injected model.

    Parameters:
        specimen_id: Allen specimen id. It is used to fetch sweeps and to
            name the pickle cache file ``<specimen_id>later_allen_NU_tests.p``.
        cellmodel: backend / MODEL_PARAMS key, e.g. "ADEXP", "MAT", "IZHI".
        target_num: target spike count, passed to the sweep selection when
            the suite is built from scratch.
        provided_model: model on which ``SpikeCountSearch`` runs. It is
            required; the old ``ReducedCellModel`` fallback is deprecated.
        cached: if truthy, load the suite from the pickle cache. Otherwise
            build it from the Allen data and (re)write the cache. Only truthy
            values select the cache, so ``cached=False`` rebuilds.
        fixed_current: if truthy, use it directly as the stimulus amplitude.
            Otherwise search for the current that elicits the observed spike
            count.

    Returns:
        ``(model, suite, nu_tests, target_current, spk_count)``. ``model`` is
        built from the mean MODEL_PARAMS of ``cellmodel`` and has already had
        the square current injected. ``target_current`` is ``None`` when
        ``fixed_current`` is used.

    Raises:
        ValueError: if ``provided_model`` is None or the suite has no
            ``Spikecount_1.5x`` test.
    """
    # Fail fast, before any expensive suite loading or cache writing.
    if provided_model is None:
        raise ValueError(
            'provided_model is required: the ReducedCellModel fallback is '
            'deprecated in favor of the jithub model')

    if cached:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files this code wrote itself.
        with open(str(specimen_id) + 'later_allen_NU_tests.p', 'rb') as f:
            suite = pickle.load(f)
    else:
        sweep_numbers, data_set, _sweeps = make_allen_tests_from_id.allen_id_to_sweeps(
            specimen_id)
        vmm, stimulus, _sn, _spike_times = make_allen_tests_from_id.get_model_parts_sweep_from_spk_cnt(
            target_num, data_set, sweep_numbers, specimen_id)
        suite, specimen_id = make_allen_tests_from_id.make_suite_known_sweep_from_static_models(
            vmm, stimulus, specimen_id)
        # The cache is named after the id returned above. If it can differ
        # from the input id, reads and writes use different files -- TODO confirm.
        with open(str(specimen_id) + 'later_allen_NU_tests.p', 'wb') as f:
            pickle.dump(suite, f)

    target = StaticModel(vm=suite.traces['vm15'])
    target.vm15 = suite.traces['vm15']
    nu_tests = suite.tests

    check_bin_vm15(target, target)
    attrs = {k: np.mean(v) for k, v in MODEL_PARAMS[cellmodel].items()}
    dtc = DataTC(backend=cellmodel, attrs=attrs)

    # Observed spike count from the first 'Spikecount_1.5x' test in the suite.
    spk_count = next(
        (float(t.observation['mean'])
         for t in nu_tests if t.name == 'Spikecount_1.5x'),
        None)
    if spk_count is None:
        raise ValueError("suite has no 'Spikecount_1.5x' test")
    observation_range = {'value': spk_count}

    provided_model.backend = cellmodel
    provided_model.allen = True

    # Stimulus timing is shared by both branches. It used to be assigned only
    # in the search branch, which made the fixed-current branch raise
    # UnboundLocalError.
    # NOTE(review): units are seconds as written; confirm these are not meant
    # to be milliseconds.
    allen_delay = 1000.0 * qt.s
    allen_duration = 2000.0 * qt.s

    if fixed_current:
        amplitude = fixed_current
        target_current = None
    else:
        scs = SpikeCountSearch(observation_range)
        target_current = scs.generate_prediction(provided_model)
        amplitude = target_current['value']

    uc = {
        'amplitude': amplitude,
        'duration': allen_duration,
        'delay': allen_delay
    }

    # The returned model is built from dtc, not from provided_model.
    model = dtc.dtc_to_model()
    model._backend.inject_square_current(**uc)

    return model, suite, nu_tests, target_current, spk_count
Esempio n. 4
0
    def test_opt(self):
        """Optimise a MAT model against an Allen specimen and check the fit.

        Builds a ``DataTC`` from the mean of each ``MODEL_PARAMS`` entry for the
        chosen cell model and finds its rheobase. It then prepares the test
        suite with ``opt_setup`` (uncached, so the suite is rebuilt and the
        pickle cache rewritten) and runs the dask-mapped optimiser (``NGEN``
        generations, population ``MU``). It asserts that the best individual's
        summed fitness is below 0.7. Finally it plots the fitness history and
        the optimised and target ``vm15`` traces and calls ``check_bin_vm15``
        on them.
        """
        ids = [
            324257146, 325479788, 476053392, 623893177, 623960880, 482493761,
            471819401
        ]

        specimen_id = ids[1]
        cellmodel = "MAT"

        if cellmodel == "IZHI":
            model = model_classes.IzhiModel()
        elif cellmodel == "MAT":
            model = model_classes.MATModel()
        elif cellmodel == "ADEXP":
            model = model_classes.ADEXPModel()
        else:
            raise ValueError("unsupported cell model: {}".format(cellmodel))

        target_num_spikes = 8
        dtc = DataTC()
        dtc.backend = cellmodel
        dtc._backend = model._backend
        dtc.attrs = model.attrs
        # Start from the mean of each MODEL_PARAMS entry (presumably a
        # [low, high] range per parameter -- TODO confirm).
        dtc.params = {
            k: np.mean(v)
            for k, v in MODEL_PARAMS[cellmodel].items()
        }
        dtc = dtc_to_rheo(dtc)
        self.assertIsNotNone(dtc.rheobase)
        # inject_and_plot_model also returns the plotting handle used below
        # (it is used like matplotlib.pyplot).
        _vm, plt, dtc = inject_and_plot_model(dtc, plotly=False)
        _setup_model, suite, nu_tests, target_current, spk_count = opt_setup(
            specimen_id,
            cellmodel,
            target_num_spikes,
            provided_model=model,
            fixed_current=False)
        # Optimise a fresh model built from the rheobase-annotated dtc rather
        # than the one returned by opt_setup.
        model = dtc.dtc_to_model()
        model.seeded_current = target_current['value']
        model.allen = True
        model.NU = True
        cell_evaluator, _simple_cell = opt_setup_two(model,
                                                     cellmodel,
                                                     suite,
                                                     nu_tests,
                                                     target_current,
                                                     spk_count,
                                                     provided_model=model)
        NGEN = 15
        MU = 12

        _final_pop, hall_of_fame, logs, _hist = opt_exec(MU, NGEN,
                                                         dask_map_function,
                                                         cell_evaluator)
        opt, target = opt_to_model(hall_of_fame, cell_evaluator, suite,
                                   target_current, spk_count)
        best_ind = hall_of_fame[0]
        fitnesses = cell_evaluator.evaluate_with_lists(best_ind)
        # unittest assertion (not a bare assert) so it survives `python -O`.
        self.assertLess(np.sum(fitnesses), 0.7)

        # Fitness history per generation, one curve per statistic.
        gen_numbers = logs.select('gen')
        min_fitness = logs.select('min')
        max_fitness = logs.select('max')
        avg_fitness = logs.select('avg')
        plt.plot(gen_numbers, max_fitness, label='max fitness')
        plt.plot(gen_numbers, avg_fitness, label='avg fitness')
        plt.plot(gen_numbers, min_fitness, label='min fitness')
        plt.semilogy()
        plt.xlabel('generation #')
        plt.ylabel('score (# std)')
        plt.legend()
        plt.xlim(min(gen_numbers) - 1, max(gen_numbers) + 1)

        # Overlay the optimised trace on the recorded target trace.
        plt.plot(opt.vm15.times, opt.vm15)
        target.vm15 = suite.traces['vm15']
        plt.plot(target.vm15.times, target.vm15)
        check_bin_vm15(target, opt)