class TestPhysTree():
    """Tests for setting and fitting physiological parameters on a `PhysTree`."""

    def loadTree(self, reinitialize=0, segments=False):
        """
        Load the T-tree morphology in memory

          6--5--4--7--8
                |
                |
                1

        Parameters
        ----------
        reinitialize: bool
            Reload the morphology even if `self.tree` already exists
        segments: bool
            If ``True``, load 'Ttree_segments.swc', otherwise 'Ttree.swc'
        """
        if not hasattr(self, 'tree') or reinitialize:
            print('>>> loading T-tree <<<')
            fname = 'Ttree_segments.swc' if segments else 'Ttree.swc'
            self.tree = PhysTree('test_morphologies/' + fname, types=[1, 3, 4])

    def testLeakDistr(self):
        """
        Test fitting the leak current to a uniform and to a distance-dependent
        membrane time scale.
        """
        self.loadTree(reinitialize=1)
        # a negative membrane time scale must be rejected
        with pytest.raises(AssertionError):
            self.tree.fitLeakCurrent(-75., -10.)
        # test simple distribution
        self.tree.fitLeakCurrent(-75., 10.)
        for node in self.tree:
            assert np.abs(node.c_m - 1.0) < 1e-9
            assert np.abs(node.currents['L'][0] - 1. / (10. * 1e-3)) < 1e-9
            assert np.abs(node.e_eq + 75.) < 1e-9
        # create complex distribution
        tau_distr = lambda x: x + 100.
        for node in self.tree:
            d2s = self.tree.pathLength({'node': node.index, 'x': 1.}, (1., 0.5))
            node.fitLeakCurrent(self.tree.channel_storage,
                                e_eq_target=-75., tau_m_target=tau_distr(d2s))
            assert np.abs(node.c_m - 1.0) < 1e-9
            assert np.abs(node.currents['L'][0] -
                          1. / (tau_distr(d2s) * 1e-3)) < 1e-9
            assert np.abs(node.e_eq + 75.) < 1e-9

    def testPhysiologySetting(self):
        """
        Test setting passive parameters, equilibrium potentials, leak currents
        and ion channel conductances as floats, dicts and functions.
        """
        self.loadTree(reinitialize=1)
        # per-node distances at which the distribution functions below are
        # evaluated (presumably distance to node 1 -- confirm against
        # `PhysTree` distribution handling)
        d2s = {1: 0., 4: 50., 5: 125., 6: 175., 7: 125., 8: 175.}

        # passive parameters as float
        c_m = 1.
        r_a = 100. * 1e-6
        self.tree.setPhysiology(c_m, r_a)
        for node in self.tree:
            assert np.abs(node.c_m - c_m) < 1e-10
            assert np.abs(node.r_a - r_a) < 1e-10
        # passive parameters as function
        c_m = lambda x: .5 * x + 1.
        r_a = lambda x: np.exp(0.01 * x) * 100 * 1e-6
        self.tree.setPhysiology(c_m, r_a)
        for node in self.tree:
            assert np.abs(node.c_m - c_m(d2s[node.index])) < 1e-10
            assert np.abs(node.r_a - r_a(d2s[node.index])) < 1e-10
        # passive parameters as incomplete dict
        r_a = 100. * 1e-6
        c_m = {1: 1., 4: 1.2}
        with pytest.raises(KeyError):
            self.tree.setPhysiology(c_m, r_a)
        # passive parameters as complete dict
        c_m.update({5: 1.1, 6: 0.9, 7: 0.8, 8: 1.})
        self.tree.setPhysiology(c_m, r_a)
        for node in self.tree:
            assert np.abs(node.c_m - c_m[node.index]) < 1e-10

        # equilibrium potential as float
        e_eq = -75.
        self.tree.setEEq(e_eq)
        for node in self.tree:
            assert np.abs(node.e_eq - e_eq) < 1e-10
        # equilibrium potential as dict
        e_eq = {1: -75., 4: -74., 5: -73., 6: -72., 7: -71., 8: -70.}
        self.tree.setEEq(e_eq)
        for node in self.tree:
            assert np.abs(node.e_eq - e_eq[node.index]) < 1e-10
        # equilibrium potential as function
        e_eq = lambda x: -70. + 0.1 * x
        self.tree.setEEq(e_eq)
        for node in self.tree:
            assert np.abs(node.e_eq - e_eq(d2s[node.index])) < 1e-10
        # as wrong type: each call needs its own context, otherwise the second
        # call is never reached once the first one raises
        with pytest.raises(TypeError):
            self.tree.setEEq([])
        with pytest.raises(TypeError):
            self.tree.setPhysiology([], [])

        # leak as float
        g_l, e_l = 100., -75.
        self.tree.setLeakCurrent(g_l, e_l)
        for node in self.tree:
            g, e = node.currents['L']
            assert np.abs(g - g_l) < 1e-10
            assert np.abs(e - e_l) < 1e-10
        # leak as dict
        g_l = {1: 101., 4: 103., 5: 105., 6: 107., 7: 108., 8: 109.}
        e_l = {1: -75., 4: -74., 5: -73., 6: -72., 7: -71., 8: -70.}
        self.tree.setLeakCurrent(g_l, e_l)
        for node in self.tree:
            g, e = node.currents['L']
            assert np.abs(g - g_l[node.index]) < 1e-10
            assert np.abs(e - e_l[node.index]) < 1e-10
        # leak as function
        g_l = lambda x: 100. + 0.05 * x
        e_l = lambda x: -70. + 0.05 * x
        self.tree.setLeakCurrent(g_l, e_l)
        for node in self.tree:
            g, e = node.currents['L']
            assert np.abs(g - g_l(d2s[node.index])) < 1e-10
            assert np.abs(e - e_l(d2s[node.index])) < 1e-10
        # as wrong type: pass both `g_l` and `e_l` (as every other call here
        # does) so that a TypeError cannot simply stem from the call arity
        with pytest.raises(TypeError):
            self.tree.setLeakCurrent([], [])

        # gmax as float
        e_rev = 100.
        g_max = 100.
        channel = channelcollection.TestChannel2()
        self.tree.addCurrent(channel, g_max, e_rev)
        for node in self.tree:
            g_m = node.currents['TestChannel2'][0]
            assert np.abs(g_m - g_max) < 1e-10
        # gmax as dict
        g_max = {1: 101., 4: 103., 5: 104., 6: 106., 7: 107., 8: 110.}
        self.tree.addCurrent(channel, g_max, e_rev)
        for node in self.tree:
            g_m = node.currents['TestChannel2'][0]
            assert np.abs(g_m - g_max[node.index]) < 1e-10
        # gmax as function
        g_max = lambda x: 100. + 0.005 * x**2
        self.tree.addCurrent(channel, g_max, e_rev)
        for node in self.tree:
            g_m = node.currents['TestChannel2'][0]
            assert np.abs(g_m - g_max(d2s[node.index])) < 1e-10
        # test if channel is stored
        assert isinstance(
            self.tree.channel_storage[channel.__class__.__name__],
            channelcollection.TestChannel2)
        # check if error is thrown if an ionchannel is not given
        with pytest.raises(IOError):
            self.tree.addCurrent('TestChannel2', g_max, e_rev)

    def testMembraneFunctions(self):
        """
        Test the leak fit in the presence of an ion channel, and conversion
        of the membrane to a passive membrane.
        """
        self.loadTree(reinitialize=1)
        # passive parameters
        c_m = 1.
        r_a = 100. * 1e-6
        e_eq = -75.
        self.tree.setPhysiology(c_m, r_a)
        self.tree.setEEq(e_eq)
        # open fraction of TestChannel2, hard-coded for this channel
        p_open = .9 * .3**3 * .5**2 + .1 * .4**2 * .6**1
        # TestChannel2
        g_chan, e_chan = 100., 100.
        channel = channelcollection.TestChannel2()
        self.tree.addCurrent(channel, g_chan, e_chan)
        # fit the leak current
        self.tree.fitLeakCurrent(-30., 10.)
        # test if fit was correct
        for node in self.tree:
            g_leak, e_leak = node.currents['L'][0], node.currents['L'][1]
            tau_mem = c_m / (g_leak + g_chan * p_open) * 1e3
            assert np.abs(tau_mem - 10.) < 1e-10
            e_eq_node = (g_leak * e_leak + g_chan * p_open * e_chan) / \
                        (g_leak + g_chan * p_open)
            assert np.abs(e_eq_node - (-30.)) < 1e-10
        # test if warning is raised for impossible to reach time scale
        with pytest.warns(UserWarning):
            tree = copy.deepcopy(self.tree)
            tree.fitLeakCurrent(-30., 100000.)
        # total membrane conductance
        g_pas = self.tree[1].currents['L'][0] + g_chan * p_open
        # make passive membrane
        tree = copy.deepcopy(self.tree)
        tree.asPassiveMembrane()
        # test if fit was correct
        for node in tree:
            assert np.abs(node.currents['L'][0] - g_pas) < 1e-10

    def testCompTree(self):
        """
        Test which nodes are kept in the computational tree when passive
        parameters, leak currents or shunts change along the morphology.
        """
        self.loadTree(reinitialize=1, segments=True)
        # capacitance axial resistance constant
        c_m = 1.
        r_a = 100. * 1e-6
        self.tree.setPhysiology(c_m, r_a)
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 8, 10, 12]
        # capacitance and axial resistance change
        c_m = lambda x: 1. if x < 200. else 1.6
        r_a = lambda x: 1. if x < 300. else 1.6
        self.tree.setPhysiology(c_m, r_a)
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 5, 6, 8, 10, 12]
        # leak current changes
        g_l = lambda x: 100. if x < 400. else 160.
        self.tree.setLeakCurrent(g_l, -75.)
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 5, 6, 7, 8, 10, 12]
        # leak current & reversal change
        g_l = 100.
        e_l = {ind: -75. for ind in [1, 4, 5, 6, 7, 8, 11, 12]}
        e_l.update({ind: -55. for ind in [9, 10]})
        self.tree.setLeakCurrent(g_l, e_l)
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 5, 6, 8, 10, 12]
        # leak current & reversal change
        g_l = 100.
        e_l = {ind: -75. for ind in [1, 4, 5, 6, 7, 8, 10, 11, 12]}
        e_l.update({9: -55.})
        self.tree.setLeakCurrent(g_l, e_l)
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 5, 6, 8, 9, 10, 12]
        # shunt
        self.tree.treetype = 'original'
        self.tree[7].g_shunt = 1.
        self.tree.setCompTree()
        self.tree.treetype = 'computational'
        assert [n.index for n in self.tree] == [1, 5, 6, 7, 8, 9, 10, 12]
class TestCompartmentFitter():
    """Tests for `CompartmentFitter` on T-tree, ball-and-stick and ball models."""

    def loadTTree(self):
        '''
        Load the T-tree model

          6--5--4--7--8
                |
                |
                1
        '''
        fname = os.path.join(MORPHOLOGIES_PATH_PREFIX, 'Tsovtree.swc')
        self.tree = PhysTree(fname, types=[1, 3, 4])
        self.tree.setPhysiology(0.8, 100. / 1e6)
        self.tree.fitLeakCurrent(-75., 10.)
        self.tree.setCompTree()

    def loadBallAndStick(self):
        '''
        Load the ball and stick model

        1--4
        '''
        self.tree = PhysTree(file_n=os.path.join(MORPHOLOGIES_PATH_PREFIX,
                                                 'ball_and_stick.swc'))
        self.tree.setPhysiology(0.8, 100. / 1e6)
        self.tree.setLeakCurrent(100., -75.)
        self.tree.setCompTree()

    def loadBall(self):
        '''
        Load point neuron model with Kv3_1 and Na_Ta channels
        '''
        self.tree = PhysTree(
            file_n=os.path.join(MORPHOLOGIES_PATH_PREFIX, 'ball.swc'))
        # capacitance and axial resistance
        self.tree.setPhysiology(0.8, 100. / 1e6)
        # ion channels
        k_chan = channelcollection.Kv3_1()
        self.tree.addCurrent(k_chan, 0.766 * 1e6, -85.)
        na_chan = channelcollection.Na_Ta()
        self.tree.addCurrent(na_chan, 1.71 * 1e6, 50.)
        # fit leak current
        self.tree.fitLeakCurrent(-75., 10.)
        # set equilibrium potentials
        self.tree.setEEq(-75.)
        # set computational tree
        self.tree.setCompTree()

    def loadTSegmentTree(self):
        '''
        Load the segmented T-tree morphology ('Ttree_segments.swc'). The
        Kv3_1 conductance is high at node 1 and elsewhere decays inversely
        with the path length to node 1. Na_Ta is only inserted at node 1.
        '''
        self.tree = PhysTree(file_n=os.path.join(MORPHOLOGIES_PATH_PREFIX,
                                                 'Ttree_segments.swc'))
        # capacitance and axial resistance
        self.tree.setPhysiology(0.8, 100. / 1e6)
        # ion channels
        k_chan = channelcollection.Kv3_1()
        g_k = {1: 0.766 * 1e6}
        g_k.update({
            n.index: 0.034 * 1e6 / self.tree.pathLength((1, .5), (n.index, .5))
            for n in self.tree if n.index != 1
        })
        self.tree.addCurrent(k_chan, g_k, -85.)
        na_chan = channelcollection.Na_Ta()
        self.tree.addCurrent(na_chan, 1.71 * 1e6, 50., node_arg=[self.tree[1]])
        # fit leak current
        self.tree.fitLeakCurrent(-75., 10.)
        # set equilibrium potentials
        self.tree.setEEq(-75.)
        # set computational tree
        self.tree.setCompTree()

    def testTreeStructure(self):
        self.loadTTree()
        cm = CompartmentFitter(self.tree)
        # set of locations
        fit_locs1 = [(1, .5), (4, .5), (5, .5)] # no bifurcations
        fit_locs2 = [(1, .5), (4, .5), (5, .5), (8, .5)] # w bifurcation, should be added
        fit_locs3 = [(1, .5), (4, 1.), (5, .5), (8, .5)] # w bifurcation, already added

        # test fit_locs1, no bifurcation are added
        # input paradigm 1
        cm.setCTree(fit_locs1, extend_w_bifurc=True)
        # NOTE(review): the structural checks below are disabled, so this test
        # currently only verifies that `setCTree` runs; re-enable or remove.
        # fl1_a = cm.tree.getLocs('fit locs')
        # with pytest.warns(UserWarning):
        #     cm.setCTree(fit_locs1, extend_w_bifurc=False)
        # fl1_b = cm.tree.getLocs('fit locs')
        # assert len(fl1_a) == len(fl1_b)
        # for fla, flb in zip(fl1_a, fl1_b): assert fla == flb
        # # input paradigm 2
        # cm.tree.storeLocs(fit_locs1, 'fl1')
        # cm.setCTree('fl1', extend_w_bifurc=True)
        # fl1_a = cm.tree.getLocs('fit locs')
        # assert len(fl1_a) == len(fl1_b)
        # for fla, flb in zip(fl1_a, fl1_b): assert fla == flb
        # # test tree structure
        # assert len(cm.ctree) == 3
        # for cn in cm.ctree: assert len(cn.child_nodes) <= 1
        # # test fit_locs2, a bifurcation should be added
        # with pytest.warns(UserWarning):
        #     cm.setCTree(fit_locs2, extend_w_bifurc=False)
        # fl2_b = cm.tree.getLocs('fit locs')
        # cm.setCTree(fit_locs2, extend_w_bifurc=True)
        # fl2_a = cm.tree.getLocs('fit locs')
        # assert len(fl2_a) == len(fl2_b) + 1
        # for fla, flb in zip(fl2_a, fl2_b): assert fla == flb
        # assert fl2_a[-1] == (4,1.)
        # # test tree structure
        # assert len(cm.ctree) == 5
        # for cn in cm.ctree:
        #     assert len(cn.child_nodes) <= 1 if cn.loc_ind != 4 else \
        #            len(cn.child_nodes) == 2
        # # test fit_locs2, no bifurcation should be added as it is already present
        # cm.setCTree(fit_locs3, extend_w_bifurc=True)
        # fl3 = cm.tree.getLocs('fit locs')
        # for fl_, fl3 in zip(fit_locs3, fl3): assert fl_ == fl3
        # # test tree structure
        # assert len(cm.ctree) == 4
        # for cn in cm.ctree:
        #     assert len(cn.child_nodes) <= 1 if cn.loc_ind != 1 else \
        #            len(cn.child_nodes) == 2

    def _checkChannels(self, tree, channel_names):
        """Check that `tree` is a `FitTreeGF` holding exactly `channel_names`
        (plus the leak 'L' at every node)."""
        assert isinstance(tree, compartmentfitter.FitTreeGF)
        assert set(tree.channel_storage.keys()) == set(channel_names)
        for node in tree:
            assert set(node.currents.keys()) == set(channel_names + ['L'])

    def testCreateTreeGF(self):
        self.loadBall()
        cm = CompartmentFitter(self.tree)
        # create tree with only 'L'
        tree_pas = cm.createTreeGF()
        self._checkChannels(tree_pas, [])
        # create tree with only 'Na_Ta'
        tree_na = cm.createTreeGF(['Na_Ta'])
        self._checkChannels(tree_na, ['Na_Ta'])
        # create tree with only 'Kv3_1'
        tree_k = cm.createTreeGF(['Kv3_1'])
        self._checkChannels(tree_k, ['Kv3_1'])
        # create tree with all channels
        tree_all = cm.createTreeGF(['Na_Ta', 'Kv3_1'])
        self._checkChannels(tree_all, ['Na_Ta', 'Kv3_1'])

    def _computeBallImpedances(self, freqs, locs, e_eqs):
        """
        Compute, on `GreensTree` copies of `self.tree`, the impedance matrices
        shared by `reduceExplicit` and `fitBall`.

        Parameters
        ----------
        freqs: np.ndarray
            Frequencies at which the impedances are evaluated
        locs: list of tuples
            Locations at which the impedance matrices are computed
        e_eqs: list of float
            Equilibrium potentials at which channel impedances are computed

        Returns
        -------
        greens_tree_pas: GreensTree
            Copy of the tree with only the leak current at node 1
        z_mat_pas:
            Passive impedance matrix
        z_mats_k: list
            Impedance matrices with only Kv3_1 (and leak), one per `e_eqs`
        z_mats_na: list
            Impedance matrices with only Na_Ta (and leak), one per `svs`
        svs: list of dict
            Na_Ta expansion points; 'h' is evaluated at each entry of `e_eqs`
            and combined with 'm' evaluated at each entry of `e_eqs`
        e_eqs_: list of float
            Equilibrium potential associated with each entry of `svs`
        """
        # create tree with only leak
        greens_tree_pas = self.tree.__copy__(new_tree=GreensTree())
        greens_tree_pas[1].currents = {'L': greens_tree_pas[1].currents['L']}
        greens_tree_pas.setCompTree()
        greens_tree_pas.setImpedance(freqs)
        # compute the passive impedance matrix
        z_mat_pas = greens_tree_pas.calcImpedanceMatrix(locs)[0]

        # create tree with only potassium
        greens_tree_k = self.tree.__copy__(new_tree=GreensTree())
        greens_tree_k[1].currents = {key: val for key, val in
                                     greens_tree_k[1].currents.items()
                                     if key != 'Na_Ta'}
        # compute potassium impedance matrices
        z_mats_k = []
        for e_eq in e_eqs:
            greens_tree_k.setEEq(e_eq)
            greens_tree_k.setCompTree()
            greens_tree_k.setImpedance(freqs)
            z_mats_k.append(greens_tree_k.calcImpedanceMatrix(locs))

        # create tree with only sodium
        greens_tree_na = self.tree.__copy__(new_tree=GreensTree())
        greens_tree_na[1].currents = {key: val for key, val in
                                      greens_tree_na[1].currents.items()
                                      if key != 'Kv3_1'}
        # create state variable expansion points
        svs = []
        e_eqs_ = []
        na_chan = greens_tree_na.channel_storage['Na_Ta']
        for e_eq1 in e_eqs:
            sv1 = na_chan.computeVarinf(e_eq1)
            for e_eq2 in e_eqs:
                e_eqs_.append(e_eq2)
                sv2 = na_chan.computeVarinf(e_eq2)
                svs.append({'m': sv2['m'], 'h': sv1['h']})
        # compute sodium impedance matrices
        z_mats_na = []
        for sv, eh in zip(svs, e_eqs_):
            greens_tree_na.setEEq(eh)
            greens_tree_na[1].setExpansionPoint('Na_Ta', sv)
            greens_tree_na.setCompTree()
            greens_tree_na.setImpedance(freqs)
            z_mats_na.append(greens_tree_na.calcImpedanceMatrix(locs))

        return greens_tree_pas, z_mat_pas, z_mats_k, z_mats_na, svs, e_eqs_

    def reduceExplicit(self):
        """
        Compute the Kv3_1 and Na_Ta fit matrices of the ball model explicitly.

        Returns
        -------
        fit_mats_na, fit_mats_k: lists of [feature matrix, target vector]
        """
        self.loadBall()
        freqs = np.array([0.])
        locs = [(1, 0.5)]
        e_eqs = [-75., -55., -35., -15.]
        # create compartment tree
        ctree = self.tree.createCompartmentTree(locs)
        ctree.addCurrent(channelcollection.Na_Ta(), 50.)
        ctree.addCurrent(channelcollection.Kv3_1(), -85.)

        _, z_mat_pas, z_mats_k, z_mats_na, svs, e_eqs_ = \
            self._computeBallImpedances(freqs, locs, e_eqs)

        # passive fit
        ctree.computeGMC(z_mat_pas)
        # potassium channel fit matrices
        fit_mats_k = []
        for z_mat_k, e_eq in zip(z_mats_k, e_eqs):
            mf, vt = ctree.computeGSingleChanFromImpedance(
                'Kv3_1', z_mat_k, e_eq, freqs,
                other_channel_names=['L'], action='return')
            fit_mats_k.append([mf, vt])
        # sodium channel fit matrices
        fit_mats_na = []
        for z_mat_na, e_eq, sv in zip(z_mats_na, e_eqs_, svs):
            mf, vt = ctree.computeGSingleChanFromImpedance(
                'Na_Ta', z_mat_na, e_eq, freqs, sv=sv,
                other_channel_names=['L'], action='return')
            fit_mats_na.append([mf, vt])

        return fit_mats_na, fit_mats_k

    def testChannelFitMats(self):
        self.loadBall()
        cm = CompartmentFitter(self.tree)
        cm.setCTree([(1, .5)])
        # check if reversals are correct
        for key in set(cm.ctree[0].currents) - {'L'}:
            assert np.abs(cm.ctree[0].currents[key][1] -
                          self.tree[1].currents[key][1]) < 1e-10

        # fit the passive model
        cm.fitPassive(use_all_channels=False)

        fit_mats_cm_na = cm.evalChannel('Na_Ta', parallel=False)
        fit_mats_cm_k = cm.evalChannel('Kv3_1', parallel=False)
        fit_mats_control_na, fit_mats_control_k = self.reduceExplicit()
        # test whether potassium fit matrices agree
        for fm_cm, fm_control in zip(fit_mats_cm_k, fit_mats_control_k):
            assert np.allclose(np.sum(fm_cm[0]), fm_control[0][0, 0]) # feature matrices
            assert np.allclose(fm_cm[1], fm_control[1]) # target vectors
        # test whether sodium fit matrices agree
        # NOTE(review): the first 4 matrices returned by `evalChannel` are
        # skipped; confirm this offset against `CompartmentFitter.evalChannel`
        for fm_cm, fm_control in zip(fit_mats_cm_na[4:], fit_mats_control_na):
            assert np.allclose(np.sum(fm_cm[0]), fm_control[0][0, 0]) # feature matrices
            assert np.allclose(fm_cm[1], fm_control[1]) # target vectors

    def _checkPasCondProps(self, ctree1, ctree2):
        """Check that leak and coupling conductances of both trees agree."""
        assert len(ctree1) == len(ctree2)
        for n1, n2 in zip(ctree1, ctree2):
            assert np.allclose(n1.currents['L'][0], n2.currents['L'][0])
            assert np.allclose(n1.g_c, n2.g_c)

    def _checkPasCaProps(self, ctree1, ctree2):
        """Check that the capacitances of both trees agree."""
        assert len(ctree1) == len(ctree2)
        for n1, n2 in zip(ctree1, ctree2):
            assert np.allclose(n1.ca, n2.ca)

    def _checkAllCurrProps(self, ctree1, ctree2):
        """Check coupling conductances and all currents of both trees agree."""
        assert len(ctree1) == len(ctree2)
        assert ctree1.channel_storage.keys() == ctree2.channel_storage.keys()
        for n1, n2 in zip(ctree1, ctree2):
            assert np.allclose(n1.g_c, n2.g_c)
            for key in n1.currents:
                assert np.allclose(n1.currents[key][0], n2.currents[key][0])
                assert np.allclose(n1.currents[key][1], n2.currents[key][1])

    def _checkPhysTrees(self, tree1, tree2):
        """Check axial resistance, capacitance and currents of both trees agree."""
        assert len(tree1) == len(tree2)
        assert tree1.channel_storage.keys() == tree2.channel_storage.keys()
        for n1, n2 in zip(tree1, tree2):
            assert np.allclose(n1.r_a, n2.r_a)
            assert np.allclose(n1.c_m, n2.c_m)
            for key in n1.currents:
                assert np.allclose(n1.currents[key][0], n2.currents[key][0])
                assert np.allclose(n1.currents[key][1], n2.currents[key][1])

    def _checkEL(self, ctree, e_l):
        """Check that the leak reversal of every node equals `e_l`."""
        for n in ctree:
            assert np.allclose(n.currents['L'][1], e_l)

    def testPassiveFit(self):
        self.loadTTree()
        fit_locs = [(1, .5), (4, 1.), (5, .5), (8, .5)]

        # fit a tree directly from CompartmentTree
        greens_tree = self.tree.__copy__(new_tree=GreensTree())
        greens_tree.setCompTree()
        freqs = np.array([0.])
        greens_tree.setImpedance(freqs)
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs)[0].real
        ctree = greens_tree.createCompartmentTree(fit_locs)
        ctree.computeGMC(z_mat)
        sov_tree = self.tree.__copy__(new_tree=SOVTree())
        sov_tree.calcSOVEquations()
        alphas, phimat = sov_tree.getImportantModes(locarg=fit_locs)
        ctree.computeC(-alphas[0:1].real * 1e3, phimat[0:1, :].real)

        # fit a tree with compartment fitter
        cm = CompartmentFitter(self.tree)
        cm.setCTree(fit_locs)
        cm.fitPassive()
        cm.fitCapacitance()
        cm.fitEEq()

        # check whether both trees are the same
        self._checkPasCondProps(ctree, cm.ctree)
        self._checkPasCaProps(ctree, cm.ctree)
        self._checkEL(cm.ctree, -75.)

        # test whether all channels are used correctly for passive fit
        self.loadBall()
        fit_locs = [(1, .5)]
        # fit ball model with only leak
        greens_tree = self.tree.__copy__(new_tree=GreensTree())
        greens_tree.channel_storage = {}
        for n in greens_tree:
            n.currents = {'L': n.currents['L']}
        greens_tree.setCompTree()
        freqs = np.array([0.])
        greens_tree.setImpedance(freqs)
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs)[0].real
        ctree_leak = greens_tree.createCompartmentTree(fit_locs)
        ctree_leak.computeGMC(z_mat)
        sov_tree = greens_tree.__copy__(new_tree=SOVTree())
        sov_tree.calcSOVEquations()
        alphas, phimat = sov_tree.getImportantModes(locarg=fit_locs)
        ctree_leak.computeC(-alphas[0:1].real * 1e3, phimat[0:1, :].real)
        # make ball model with leak based on all channels
        tree = self.tree.__copy__()
        tree.asPassiveMembrane()
        greens_tree = tree.__copy__(new_tree=GreensTree())
        greens_tree.setCompTree()
        freqs = np.array([0.])
        greens_tree.setImpedance(freqs)
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs)[0].real
        ctree_all = greens_tree.createCompartmentTree(fit_locs)
        ctree_all.computeGMC(z_mat)
        sov_tree = tree.__copy__(new_tree=SOVTree())
        sov_tree.calcSOVEquations()
        alphas, phimat = sov_tree.getImportantModes(locarg=fit_locs)
        ctree_all.computeC(-alphas[0:1].real * 1e3, phimat[0:1, :].real)

        # new compartment fitter
        cm = CompartmentFitter(self.tree)
        cm.setCTree(fit_locs)
        # test fitting
        cm.fitPassive(use_all_channels=False)
        cm.fitCapacitance()
        cm.fitEEq()
        self._checkPasCondProps(ctree_leak, cm.ctree)
        self._checkPasCaProps(ctree_leak, cm.ctree)
        with pytest.raises(AssertionError):
            self._checkEL(cm.ctree, self.tree[1].currents['L'][1])
        cm.fitPassive(use_all_channels=True)
        cm.fitCapacitance()
        cm.fitEEq()
        self._checkPasCondProps(ctree_all, cm.ctree)
        self._checkPasCaProps(ctree_all, cm.ctree)
        self._checkEL(cm.ctree, greens_tree[1].currents['L'][1])
        with pytest.raises(AssertionError):
            self._checkEL(cm.ctree, self.tree[1].currents['L'][1])
        with pytest.raises(AssertionError):
            self._checkPasCondProps(ctree_leak, ctree_all)

    def testRecalcImpedanceMatrix(self, g_inp=np.linspace(0., 0.01, 20)):
        self.loadBall()
        fit_locs = [(1, .5)]
        cm = CompartmentFitter(self.tree)
        cm.setCTree(fit_locs)

        # test only leak
        # compute impedances explicitly
        greens_tree = cm.createTreeGF(channel_names=[])
        greens_tree.setEEq(-75.)
        greens_tree.setImpedancesInTree()
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs, explicit_method=False)[0]
        z_test = z_mat[:, :, None] / (1. + z_mat[:, :, None] * g_inp[None, None, :])
        # compute impedances with compartmentfitter function
        z_calc = np.array([
            cm.recalcImpedanceMatrix('fit locs', [g_i], channel_names=[])
            for g_i in g_inp
        ])
        z_calc = np.swapaxes(z_calc, 0, 2)
        assert np.allclose(z_calc, z_test)

        # test with z based on all channels (passive)
        # compute impedances explicitly
        greens_tree = cm.createTreeGF(
            channel_names=list(cm.tree.channel_storage.keys()))
        greens_tree.setEEq(-75.)
        greens_tree.setImpedancesInTree()
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs, explicit_method=False)[0]
        z_test = z_mat[:, :, None] / (1. + z_mat[:, :, None] * g_inp[None, None, :])
        # compute impedances with compartmentfitter function
        z_calc = np.array([
            cm.recalcImpedanceMatrix(
                'fit locs', [g_i],
                channel_names=list(cm.tree.channel_storage.keys()))
            for g_i in g_inp
        ])
        z_calc = np.swapaxes(z_calc, 0, 2)
        assert np.allclose(z_calc, z_test)

    def testSynRescale(self, g_inp=np.linspace(0., 0.01, 20)):
        e_rev, v_eq = 0., -75.

        self.loadBallAndStick()
        fit_locs = [(4, .7)]
        syn_locs = [(4, 1.)]
        cm = CompartmentFitter(self.tree)
        cm.setCTree(fit_locs)
        # compute impedance matrix
        greens_tree = cm.createTreeGF(channel_names=[])
        greens_tree.setEEq(-75.)
        greens_tree.setImpedancesInTree()
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs + syn_locs)[0]
        # analytical synapse scale factors
        beta_calc = 1. / (1. + (z_mat[1, 1] - z_mat[0, 0]) * g_inp)
        beta_full = z_mat[0, 1] / z_mat[0, 0] * (e_rev - v_eq) / \
                    ((1. + (z_mat[1, 1] - z_mat[0, 0]) * g_inp) * (e_rev - v_eq))
        # synapse scale factors from compartment fitter
        beta_cm = np.array([
            cm.fitSynRescale(fit_locs, syn_locs, [0], [g_i], e_revs=[0.])[0]
            for g_i in g_inp
        ])
        assert np.allclose(beta_calc, beta_cm, atol=.020)
        assert np.allclose(beta_full, beta_cm, atol=.015)

    def fitBall(self):
        """Explicitly fit a compartment model of the ball; store it as `self.ctree`."""
        self.loadBall()
        freqs = np.array([0.])
        locs = [(1, 0.5)]
        e_eqs = [-75., -55., -35., -15.]
        # create compartment tree
        ctree = self.tree.createCompartmentTree(locs)
        ctree.addCurrent(channelcollection.Na_Ta(), 50.)
        ctree.addCurrent(channelcollection.Kv3_1(), -85.)

        greens_tree_pas, z_mat_pas, z_mats_k, z_mats_na, svs, e_eqs_ = \
            self._computeBallImpedances(freqs, locs, e_eqs)

        # passive fit
        ctree.computeGMC(z_mat_pas)
        # get SOV constants for capacitance fit
        sov_tree = greens_tree_pas.__copy__(new_tree=SOVTree())
        sov_tree.setCompTree()
        sov_tree.calcSOVEquations()
        alphas, phimat, importance = sov_tree.getImportantModes(
            locarg=locs, sort_type='importance', eps=1e-12,
            return_importance=True)
        # fit the capacitances from SOV time-scales
        ctree.computeC(-alphas[0:1].real * 1e3, phimat[0:1, :].real,
                       weights=importance[0:1])
        # potassium channel fit
        for z_mat_k, e_eq in zip(z_mats_k, e_eqs):
            ctree.computeGSingleChanFromImpedance('Kv3_1', z_mat_k, e_eq, freqs,
                                                  other_channel_names=['L'])
        ctree.runFit()
        # sodium channel fit
        for z_mat_na, e_eq, sv in zip(z_mats_na, e_eqs_, svs):
            ctree.computeGSingleChanFromImpedance('Na_Ta', z_mat_na, e_eq, freqs,
                                                  sv=sv,
                                                  other_channel_names=['L'])
        ctree.runFit()

        ctree.setEEq(-75.)
        ctree.removeExpansionPoints()
        ctree.fitEL()

        self.ctree = ctree

    def testFitModel(self):
        self.loadTTree()
        fit_locs = [(1, .5), (4, 1.), (5, .5), (8, .5)]

        # fit a tree directly from CompartmentTree
        greens_tree = self.tree.__copy__(new_tree=GreensTree())
        greens_tree.setCompTree()
        freqs = np.array([0.])
        greens_tree.setImpedance(freqs)
        z_mat = greens_tree.calcImpedanceMatrix(fit_locs)[0]
        ctree = greens_tree.createCompartmentTree(fit_locs)
        ctree.computeGMC(z_mat)
        sov_tree = self.tree.__copy__(new_tree=SOVTree())
        sov_tree.calcSOVEquations()
        alphas, phimat = sov_tree.getImportantModes(locarg=fit_locs)
        ctree.computeC(-alphas[0:1].real * 1e3, phimat[0:1, :].real)

        # fit a tree with compartmentfitter
        cm = CompartmentFitter(self.tree)
        ctree_cm = cm.fitModel(fit_locs)

        # compare the two trees
        self._checkPasCondProps(ctree_cm, ctree)
        self._checkPasCaProps(ctree_cm, ctree)
        self._checkEL(ctree_cm, -75.)

        # check active channel
        self.fitBall()
        locs = [(1, 0.5)]
        cm = CompartmentFitter(self.tree)
        ctree_cm_1 = cm.fitModel(locs, parallel=False,
                                 use_all_channels_for_passive=False)
        ctree_cm_2 = cm.fitModel(locs, parallel=False,
                                 use_all_channels_for_passive=True)

        self._checkAllCurrProps(self.ctree, ctree_cm_1)
        self._checkAllCurrProps(self.ctree, ctree_cm_2)

    def testPickling(self):
        self.loadBall()

        # of PhysTree
        ss = pickle.dumps(self.tree)
        pt_ = pickle.loads(ss)
        self._checkPhysTrees(self.tree, pt_)

        # of GreensTree
        greens_tree = self.tree.__copy__(new_tree=GreensTree())
        greens_tree.setCompTree()
        freqs = np.array([0.])
        greens_tree.setImpedance(freqs)

        ss = dill.dumps(greens_tree)
        gt_ = dill.loads(ss)
        self._checkPhysTrees(greens_tree, gt_)

        # of SOVTree
        sov_tree = self.tree.__copy__(new_tree=SOVTree())
        sov_tree.calcSOVEquations()

        # fails with pickle (lambda functions)
        with pytest.raises(AttributeError):
            ss = pickle.dumps(sov_tree)

        # works with dill
        ss = dill.dumps(sov_tree)
        st_ = dill.loads(ss)
        self._checkPhysTrees(sov_tree, st_)

    def testParallel(self, w_benchmark=False):
        self.loadTSegmentTree()
        locs = [(nn.index, 0.5) for nn in self.tree.nodes[:30]]
        cm = CompartmentFitter(self.tree)

        # fit once so that `fitChannels(recompute=False)` below can be timed
        cm.fitModel(locs, parallel=False, use_all_channels_for_passive=True)

        if w_benchmark:
            from timeit import default_timer as timer
            t0 = timer()
            cm.fitChannels(recompute=False, pprint=False, parallel=False)
            t1 = timer()
            print('Not parallel: %.8f s' % (t1 - t0))
            t0 = timer()
            cm.fitChannels(recompute=False, pprint=False, parallel=True)
            t1 = timer()
            print('Parallel: %.8f s' % (t1 - t0))