Ejemplo n.º 1
0
  def __init__(self, networkFile, inputFile):
    """Store the file names, load the network and open the input file.

    The network is loaded without synapses, since it is only used to tell
    which neuron is which. The input HDF5 file is opened read-only and kept
    open in self.inputData.
    """
    self.networkFile, self.inputFile = networkFile, inputFile

    # Only neuron identities are needed, so skip loading synapses
    self.network = SnuddaLoad(self.networkFile, load_synapses=False)

    self.inputData = h5py.File(self.inputFile, 'r')
Ejemplo n.º 2
0
    def __init__(self, network):
        """Resolve the network file and load it.

        Args:
            network: either a directory holding "network-synapses.hdf5" or
                the path of a network HDF5 file.
        """
        self.network_file = (os.path.join(network, "network-synapses.hdf5")
                             if os.path.isdir(network)
                             else network)
        self.network_path = os.path.dirname(self.network_file)

        self.sl = SnuddaLoad(self.network_file)
Ejemplo n.º 3
0
    def read_neuron_positions(self):
        """Load neuron positions plus simulation origo and voxel size.

        Positions come from "network-neuron-positions.hdf5" in
        self.network_path (stored in self.network_info). Origo and voxel
        size are read from the detection work log
        "log/network-detect-worklog.hdf5".
        """
        self.network_info = SnuddaLoad(
            os.path.join(self.network_path, "network-neuron-positions.hdf5"))

        # The detection work log keeps the simulation origo and voxel size
        worklog_path = os.path.join(self.network_path, "log",
                                    "network-detect-worklog.hdf5")
        with h5py.File(worklog_path, "r") as worklog:
            origo = worklog["meta/simulationOrigo"][()]
            voxel = worklog["meta/voxelSize"][()]
            self.simulation_origo = origo
            self.voxel_size = voxel
Ejemplo n.º 4
0
    def __init__(self,
                 network_path,
                 blender_save_file=None,
                 blender_output_image=None,
                 network_json=None):
        """Set up paths and load the neuron positions of a network.

        Args:
            network_path: network directory; "network-synapses.hdf5" in it is
                loaded unless network_json is given.
            blender_save_file: .blend file to save to; defaults to
                "visualise-network.blend" inside network_path.
            blender_output_image: optional path of a rendered image.
            network_json: optional JSON network description, loaded with
                FakeLoad instead of the HDF5 network file.
        """
        self.network_path = network_path

        # Exactly one of network_json / network_file is used, the other is None
        self.network_json = network_json if network_json else None
        self.network_file = (None if network_json
                             else os.path.join(network_path, "network-synapses.hdf5"))

        self.blender_save_file = (blender_save_file if blender_save_file
                                  else os.path.join(network_path, "visualise-network.blend"))

        self.blender_output_image = blender_output_image

        self.neuron_cache = {}

        # Load the neuron positions, from HDF5 or from the JSON description
        if self.network_file:
            self.sl = SnuddaLoad(self.network_file)
        elif self.network_json:
            from snudda.utils.fake_load import FakeLoad
            self.sl = FakeLoad()
            self.sl.import_json(self.network_json)
        self.data = self.sl.data
Ejemplo n.º 5
0
    def __init__(self, input_file, network_path=None):
        """Load input data and, if it can be found, the matching network.

        Args:
            input_file: input HDF5 file, loaded with self.load_input(). May be
                None/empty, in which case no input is loaded.
            network_path: directory containing "network-synapses.hdf5";
                defaults to the directory of input_file. When no network file
                is found, self.network_info is None (neuron types are then
                not available for figures).
        """
        self.input_data = None

        if input_file:
            self.load_input(input_file)

        # Fall back on the input file's directory. With no input file there is
        # nothing to fall back on (os.path.dirname(None) would raise TypeError).
        if not network_path:
            network_path = os.path.dirname(input_file) if input_file is not None else None

        network_file = (os.path.join(network_path, "network-synapses.hdf5")
                        if network_path is not None else None)

        if network_file is not None and os.path.exists(network_file):
            self.network_info = SnuddaLoad(network_file)
        else:
            print("Specify a network_path with a network file, to get neuron type in figure")
            self.network_info = None
Ejemplo n.º 6
0
class PlotDensity(object):
    """Histogram of how the neurons of one type are spread along an axis."""

    def __init__(self, network):
        """Load the network to plot.

        Args:
            network: directory containing "network-synapses.hdf5", or the
                path of a network HDF5 file.
        """
        if os.path.isdir(network):
            network_file = os.path.join(network, "network-synapses.hdf5")
        else:
            network_file = network

        self.network_file = network_file
        self.network_path = os.path.dirname(self.network_file)

        self.sl = SnuddaLoad(self.network_file)

    def close(self):
        """Close the underlying SnuddaLoad object."""
        self.sl.close()

    def plot(self, neuron_type, plot_axis, n_bins=5):
        """Histogram the positions of all neurons of neuron_type along plot_axis.

        The figure is saved to "figures/density-hist-<plot_axis>.png" next to
        the network file and then shown interactively.
        NOTE(review): the file name does not include neuron_type, so plots of
        different neuron types along the same axis overwrite each other.

        Args:
            neuron_type: neuron type to include, e.g. "dSPN".
            plot_axis: one of "x", "y", "z".
            n_bins: number of histogram bins.

        Raises:
            AssertionError: if plot_axis is not "x", "y" or "z".
        """
        p_axis = {"x": 0, "y": 1, "z": 2}
        assert plot_axis in p_axis, f"plot_axis must be one of {', '.join(p_axis.keys())}"
        neuron_pos = self.sl.data["neuronPositions"]
        cell_id = self.sl.get_cell_id_of_type(neuron_type)

        plt.figure()
        plt.hist(neuron_pos[cell_id, p_axis[plot_axis]], bins=n_bins)
        plt.title(f"Density of {neuron_type}")
        plt.xlabel(plot_axis)
        plt.ylabel("Count")

        # exist_ok avoids the check-then-create race of exists() + mkdir()
        fig_path = os.path.join(self.network_path, "figures")
        os.makedirs(fig_path, exist_ok=True)

        plt.savefig(os.path.join(fig_path, f"density-hist-{plot_axis}.png"))
        plt.ion()
        plt.show()
Ejemplo n.º 7
0
    def __init__(self, in_file, out_file, save_sparse=True):
        """Export a network's connection matrix as CSV plus a metadata file.

        Args:
            in_file: network HDF5 file, loaded with SnuddaLoad.
            out_file: CSV file for the connection matrix (row = source,
                column = destination). Metadata is written to
                out_file + "-meta", one line per neuron:
                "index,type,name,x,y,z,morphology".
            save_sparse: if True, write one "src,dest,value" row per non-zero
                matrix entry instead of the full dense matrix.
        """
        self.sl = SnuddaLoad(in_file)

        self.outFile = out_file
        self.out_file_meta = out_file + "-meta"

        data = self.sl.data

        con_mat = self.create_con_mat()
        neuron_type = [x["type"] for x in data["neurons"]]
        neuron_name = [x["name"] for x in data["neurons"]]
        morph_file = [data["morph"][nn]["location"] for nn in neuron_name]
        pos = data["neuronPositions"]

        print("Writing " + self.outFile + " (row = src, column=dest)")
        if save_sparse:
            # One (source, destination, value) row per non-zero entry, built in
            # one vectorised step. Assumes con_mat is a dense 2D ndarray.
            src_idx, dest_idx = np.nonzero(con_mat)
            sparse_data = np.column_stack(
                (src_idx, dest_idx, con_mat[src_idx, dest_idx])).astype(int)

            np.savetxt(self.outFile, sparse_data, delimiter=",", fmt="%d")
        else:
            np.savetxt(self.outFile, con_mat, delimiter=",", fmt="%d")

        print("Writing " + self.out_file_meta)
        # The with block closes the file; no explicit close() needed
        with open(self.out_file_meta, "w") as f_out_meta:
            for i, (nt, nn, p, mf) in enumerate(zip(neuron_type, neuron_name, pos, morph_file)):
                f_out_meta.write("%d,%s,%s,%f,%f,%f,%s\n" % (i, nt, nn, p[0], p[1], p[2], mf))
Ejemplo n.º 8
0
    def __init__(self,
                 spike_file_name,
                 network_file=None,
                 skip_time=0.0,
                 type_order=None,
                 end_time=2.0,
                 figsize=None):
        """Read a spike file and draw a raster plot of it.

        Args:
            spike_file_name: spike file, read by self.read_csv().
            network_file: optional network HDF5 file; when given, traces are
                sorted and the raster is coloured by neuron type.
            skip_time: passed on to the plot function.
            type_order: order of neuron types in the coloured raster plot.
            end_time: stored as self.end_time.
            figsize: not used by this constructor.
        """
        self.spike_file_name = spike_file_name

        self.time = []
        self.spike_id = []
        self.end_time = end_time

        # Best-effort guess of the run ID: the first number in the file name,
        # or 0 when there is none.
        try:
            self.ID = int(re.findall(r'\d+', ntpath.basename(spike_file_name))[0])
        except (IndexError, TypeError, ValueError):
            self.ID = 0

        self.neuron_name_remap = {"FSN": "FS"}

        self.read_csv()

        if network_file is not None:
            self.network_info = SnuddaLoad(network_file)
            self.network_file = network_file
        else:
            self.network_info = None
            self.network_file = None

        if self.network_info is None:
            print("If you also give network file, then the plot shows neuron types")
            self.plot_raster(skip_time=skip_time)
        else:
            self.sort_traces()
            self.plot_colour_raster(skip_time=skip_time, type_order=type_order)

        time.sleep(1)
Ejemplo n.º 9
0
    def __init__(self, file_name, network_file=None):
        """Read a voltage file and optionally the network it belongs to.

        Args:
            file_name: voltage file, read by self.read_csv().
            network_file: optional network HDF5 file, loaded with SnuddaLoad
                into self.network_info (None otherwise).
        """
        self.file_name = file_name
        self.network_file = network_file

        self.time = []
        self.voltage = {}

        self.neuron_name_remap = {"FSN": "FS"}

        self.read_csv()

        # Best-effort guess of the run ID: the first number in the file name
        try:
            self.ID = int(re.findall(r'\d+', ntpath.basename(file_name))[0])
        except (IndexError, TypeError, ValueError):
            print("Unable to guess ID, using 666.")
            self.ID = 666

        if self.network_file is not None:
            self.network_info = SnuddaLoad(self.network_file)
        else:
            self.network_info = None
Ejemplo n.º 10
0
    def analyseNetwork(self, simName, simType=None, nPlotMax=10):
        """Plot the recorded synaptic currents of a calibration simulation.

        Reads "<simType>-network-stimulation-current.txt" from simName, groups
        the recorded neurons by type and, per type, saves a figure with the
        current traces and peak amplitudes (plus experimental data when
        available) and a histogram of the peak amplitudes to simName/figures/.
        Ends in a pdb breakpoint so the interactive figures can be inspected.

        Args:
            simName: simulation directory, containing the current file and
                "network-synapses.hdf5".
            simType: one of "Straub2016LTS", "Straub2016FS", "Chuhma2011",
                "Szydlowski2013"; defaults to self.simType. Unknown values exit.
            nPlotMax: maximum number of traces drawn per neuron type.
        """
        figDir = simName + "/figures/"
        os.makedirs(figDir, exist_ok=True)

        if simType is None:
            simType = self.simType

        if simType == "Straub2016LTS":
            preType = "LTS"
            self.setupExpDataDict()
        elif simType == "Straub2016FS":
            preType = "FSN"
            self.setupExpDataDict()
        elif simType == "Chuhma2011":
            preType = "SPN"
            self.setupExpDataDict()
        elif simType == "Szydlowski2013":
            # NOTE(review): setupExpDataDict() is not called here, but
            # self.expDataDict is read below -- confirm it is set elsewhere.
            preType = "FSN"
        else:
            print("Unknown simType : " + simType)
            exit(-1)

        print("Analysing data in " + simName)
        voltFile = simName + "/" + simType + "-network-stimulation-current.txt"

        # First row: -1 followed by the time samples; other rows: cell ID
        # followed by its current trace. Scaled by 1e-3 and 1e-9 respectively
        # (presumably ms -> s and nA -> A).
        data = np.genfromtxt(voltFile, delimiter=",")

        assert data[0, 0] == -1  # First column should be time
        time = data[0, 1:] * 1e-3

        current = {int(row[0]): row[1:] * 1e-9 for row in data[1:, :]}

        # Read the network info
        networkFile = simName + "/network-synapses.hdf5"
        self.snuddaLoad = SnuddaLoad(networkFile)
        self.data = self.snuddaLoad.data

        # Postsynaptic neuron types to analyse for each experiment
        if simType == "Chuhma2011":
            neuronTypeList = ["dSPN", "iSPN", "FSN", "ChIN"]
        elif simType in ("Straub2016FS", "Straub2016LTS"):
            neuronTypeList = ["dSPN", "iSPN", "ChIN"]
        elif simType == "Szydlowski2013":
            neuronTypeList = ["LTS"]
        else:
            print("simulate: Unknown simType: " + simType)
            exit(-1)

        # Response window: first sample after tInj up to the first sample
        # after tInj + tWindow
        minTimeIdx = np.where(time > self.tInj)[0][0]
        maxTimeIdx = np.where(time > self.tInj + self.tWindow)[0][0]

        # Samples strictly inside (tInj, tInj + tWindow). Identical for every
        # neuron, so computed once rather than per trace.
        tIdx = np.where(np.logical_and(time > self.tInj,
                                       time < self.tInj + self.tWindow))[0]

        neuronPlotList = []
        for nt in neuronTypeList:
            IDList = [x for x in current if self.data["neurons"][x]["type"] == nt]
            # Index of the largest deviation from the current at minTimeIdx
            maxIdx = [np.argmax(np.abs(current[x][minTimeIdx:maxTimeIdx]
                                       - current[x][minTimeIdx])) + minTimeIdx
                      for x in IDList]

            neuronPlotList.append((IDList, maxIdx))

        matplotlib.rcParams.update({'font.size': 22})

        for plotID, maxIdx in neuronPlotList:

            if len(plotID) == 0:
                continue

            plotType = self.data["neurons"][plotID[0]]["type"]
            figName = figDir + "/" + simType + "-" + plotType + "-current-traces.pdf"
            figNameHist = figDir + "/" + simType + "-" + plotType + "-current-histogram.pdf"

            plt.figure()

            peakAmp = []
            peakTime = []
            voltCurve = []

            for pID, mIdx in zip(plotID, maxIdx):
                # Baseline is the sample just before the response window
                baseline = current[pID][tIdx[0] - 1]
                curAmp = current[pID][tIdx] - baseline
                maxAmp = current[pID][mIdx] - baseline

                if (mIdx < minTimeIdx or mIdx > maxTimeIdx
                        or abs(maxAmp) < 1e-12):
                    # No peaks
                    continue

                peakAmp.append(maxAmp * 1e9)
                peakTime.append((time[mIdx] - time[tIdx[0]]) * 1e3)
                voltCurve.append(
                    ((time[tIdx] - time[tIdx[0]]) * 1e3, curAmp * 1e9))

            # Pick which curves to plot: all, or nPlotMax spread evenly over
            # the amplitude-sorted traces
            sortIdx = np.argsort(peakAmp)
            if len(sortIdx) < nPlotMax:
                keepIdx = sortIdx
            else:
                keepIdx = [sortIdx[int(np.round(x))]
                           for x in np.linspace(0, len(sortIdx) - 1, nPlotMax)]

            for x in keepIdx:
                plt.plot(voltCurve[x][0], voltCurve[x][1], 'k-')

            plt.scatter(peakTime, peakAmp, marker=".", c="blue", s=100)

            if (simType, plotType) in self.expDataDict:
                expData = self.expDataDict[(simType, plotType)]
                # Jitter the experimental points in time so they do not overlap
                t = self.tWindow * 1e3 * (
                    1 + 0.03 * np.random.rand(expData.shape[0]))
                plt.scatter(t, -expData, marker=".", c="red", s=100)

            if self.plotExpTrace and (simType, plotType) in self.expTraceDict:
                expTrace = self.expTraceDict[(simType, plotType)]
                tExp = expTrace[:, 0]
                vExp = expTrace[:, 1:]
                expIdx = np.where(tExp < self.tWindow * 1e3)[0]
                plt.plot(tExp[expIdx], vExp[expIdx, :], c="red")

            plt.title(
                self.neuronName(preType) + " to " + self.neuronName(plotType))
            plt.xlabel("Time (ms)")
            plt.ylabel("Current (nA)")

            # Remove part of the frame
            plt.gca().spines["right"].set_visible(False)
            plt.gca().spines["top"].set_visible(False)

            plt.tight_layout()
            plt.ion()
            plt.show()
            plt.savefig(figName, dpi=300)

            # Also plot histogram of the peak amplitudes
            plt.figure()
            plt.hist(peakAmp)
            plt.xlabel("Current (nA)")
            plt.title(
                self.neuronName(preType) + " to " + self.neuronName(plotType))

            # Remove part of the frame
            plt.gca().spines["right"].set_visible(False)
            plt.gca().spines["top"].set_visible(False)

            plt.tight_layout()
            plt.ion()
            plt.show()
            plt.savefig(figNameHist, dpi=300)

        # Stop here so the interactive figures can be inspected
        import pdb
        pdb.set_trace()
Ejemplo n.º 11
0
class SnuddaCalibrateSynapses(object):
  """Calibrate synapses by stimulating presynaptic neurons one at a time.

  setup() writes a slice network config and prints the follow-up commands.
  runSim() injects a current pulse into every neuron of preType in turn and
  records the voltages. analyse() measures the resulting voltage deflections
  in postType neurons and compares them with the experimental values from
  setupExpData().
  """

  def __init__(self,networkFile,
               preType,postType,
               curInj = 10e-9,
               holdV = -80e-3,
               maxDist = 50e-6,
               logFile=None):
    """
    Args:
      networkFile: network HDF5 file, or a directory containing
        "network-synapses.hdf5".
      preType: type of the presynaptic neurons that receive current pulses.
      postType: type of the postsynaptic neurons to record from; "ALL"
        records from every neuron in runSim().
      curInj: amplitude of each current pulse (presumably A, SI units).
      holdV: holding voltage in V (converted to mV for NEURON); None skips
        the holding current.
      maxDist: in analyse(), pairs whose somata are further apart than this
        (m) are excluded.
      logFile: log file passed on to SnuddaSimulate.
    """

    if(os.path.isdir(networkFile)):
      self.networkFile = os.path.join(networkFile, "network-synapses.hdf5")
    else:
      self.networkFile = networkFile

    self.preType = preType
    self.postType = postType
    self.curInj = curInj
    self.holdV = holdV
    self.logFile = logFile
    self.maxDist = maxDist

    print("Checking depolarisation/hyperpolarisation of " + preType \
          + " to " + postType + " synapses")

    self.injSpacing = 0.2 # Time between pulse onsets (s); previously 0.5
    self.injDuration = 1e-3 # Length of each current pulse (s)

    # Voltage files are stored next to the network file. Use the resolved
    # self.networkFile: when a directory was passed in,
    # os.path.dirname(networkFile) would point at its parent instead.
    networkDir = os.path.dirname(self.networkFile)
    self.voltFile = os.path.join(networkDir,
                                 "synapse-calibration-volt-"
                                 + self.preType + "-" + self.postType + ".txt")
    self.voltFileAltMask = os.path.join(networkDir,
                                        "synapse-calibration-volt-"
                                        + self.preType + "-*.txt")

    self.neuronNameRemap = {"FSN" : "FS"}

  ############################################################################

  def neuronName(self,neuronType):
    """Return the display name of neuronType (e.g. "FSN" -> "FS")."""

    return self.neuronNameRemap.get(neuronType, neuronType)

  ############################################################################

  def setup(self,simName,expType,nMSD1=120,nMSD2=120,nFS=20,nLTS=0,nChIN=0):
    """Write a striatal slice config to simName/network-config.json and
    print the commands for the remaining steps (place, detect, prune, cut,
    run, analyse).

    Args:
      simName: simulation directory.
      expType: experiment name used in the printed run/analyse commands.
      nMSD1, nMSD2, nFS, nLTS, nChIN: number of dSPN, iSPN, FS, LTS and
        ChIN neurons.
    """

    from snudda.init.init import SnuddaInit

    configName= simName + "/network-config.json"
    cnc = SnuddaInit(struct_def={}, config_file=configName, nChannels=1)
    cnc.define_striatum(num_dSPN=nMSD1, num_iSPN=nMSD2, num_FS=nFS, num_LTS=nLTS, num_ChIN=nChIN,
                        volume_type="slice", side_len=200e-6, slice_depth=150e-6)

    dirName = os.path.dirname(configName)

    if not os.path.exists(dirName):
      os.makedirs(dirName)

    cnc.write_json(configName)


    print("\n\npython3 snudda.py place " + str(simName))
    print("python3 snudda.py detect " + str(simName))
    print("python3 snudda.py prune " + str(simName))
    print("python3 snudda_cut.py " + str(simName) \
          + '/network-synapses.hdf5 "abs(z)<100e-6"')

    print("\nThe last command will pop up a figure and enter debug mode, press ctrl+D in the terminal window after inspecting the plot to continue")

    print("\n!!! Remember to compile the mod files: nrnivmodl data/neurons/mechanisms")

    print("\nTo run for example dSPN -> iSPN (and dSPN->dSPN) calibration:")
    print("mpiexec -n 12 -map-by socket:OVERSUBSCRIBE python3 snudda_calibrate_synapses.py run " + str(expType) + " " + str(simName) + "/network-cut-slice.hdf5 dSPN iSPN")

    # Both analyse commands need expType, like the run command above
    print("\npython3 snudda_calibrate_synapses.py analyse " + str(expType) + " " + str(simName) + "/network-cut-slice.hdf5 --pre dSPN --post iSPN\npython3 snudda_calibrate_synapses.py analyse " + str(expType) + " " + str(simName) + "/network-cut-slice.hdf5 --pre iSPN --post dSPN")

  ############################################################################

  def setupHoldingVolt(self,holdV=None,simEnd=None):
    """Add somatic current clamps that hold every neuron at holdV.

    The holding currents are found by voltage-clamping each soma at holdV
    for 100 ms. The final clamp current is then applied through an IClamp
    lasting 2*simEnd; these are kept in self.holdingIClampList.

    Args:
      holdV: holding voltage (V); defaults to self.holdV. If both are None,
        nothing is done.
      simEnd: simulation end time (s); required.
    """
    assert simEnd is not None, \
      "setupHoldingVolt: Please set simEnd, for holding current"

    if(holdV is None):
      holdV = self.holdV

    if(holdV is None):
      print("Not using holding voltage, skipping.")
      return

    # Setup vClamps to calculate what holding current will be needed
    somaVClamp = []

    somaList = [self.snuddaSim.neurons[x].icell.soma[0] \
                for x in self.snuddaSim.neurons]

    for s in somaList:
      vc = neuron.h.SEClamp(s(0.5))
      vc.rs = 1e-9
      vc.amp1 = holdV*1e3
      vc.dur1 = 100

      somaVClamp.append((s,vc))

    neuron.h.finitialize(holdV*1e3)
    neuron.h.tstop = 100
    neuron.h.run()

    self.holdingIClampList = []

    # Setup iClamps. NEURON's current-clamp point process is IClamp.
    for s,vc in somaVClamp:
      cur = float(vc.i)
      ic = neuron.h.IClamp(s(0.5))
      ic.amp = cur
      ic.dur = 2*simEnd*1e3
      self.holdingIClampList.append(ic)

    # Remove vClamps by dropping our references to them
    somaVClamp = None
    vc = None

  ############################################################################

  def setGABArev(self,vRevCl):
    """Set the reversal potential of every synapse to vRevCl (V).

    Each synapse must still have e == -65 (mV) beforehand; this guards
    against modifying anything other than the expected (GABA) synapses.
    """

    print("Setting GABA reversal potential to " + str(vRevCl*1e3) + " mV")

    for s in self.snuddaSim.synapse_list:
      assert s.e == -65, "It should be GABA synapses only that we modify!"
      s.e = vRevCl * 1e3

  ############################################################################

  def runSim(self,GABArev):
    """Simulate one current pulse per presynaptic neuron and save voltages.

    Pulses of amplitude curInj and length injDuration start every
    injSpacing seconds, one presynaptic neuron at a time. Voltages of the
    postType and preType neurons (all neurons if postType is "ALL") are
    written to self.voltFile.

    Args:
      GABArev: GABA reversal potential (V), applied with setGABArev().
    """

    self.snuddaSim = SnuddaSimulate(network_file=self.networkFile,
                                    input_file=None,
                                    log_file=self.logFile,
                                    disable_gap_junctions=True)


    # A current pulse to all pre synaptic neurons, one at a time
    self.preID = [x["neuronID"] \
                  for x in self.snuddaSim.network_info["neurons"] \
                  if x["type"] == self.preType]

    # injInfo contains (preID,injStartTime)
    self.injInfo = list(zip(self.preID, \
                            self.injSpacing\
                            +self.injSpacing*np.arange(0,len(self.preID))))

    simEnd = self.injInfo[-1][1] + self.injSpacing

    # Set the holding voltage
    self.setupHoldingVolt(holdV=self.holdV,simEnd=simEnd)

    self.setGABArev(GABArev)


    # Add current injections defined in init
    for (nid,t) in self.injInfo:
      print("Current injection to " + str(nid) + " at " + str(t) + " s")
      self.snuddaSim.add_current_injection(neuron_id=nid,
                                           start_time=t,
                                           end_time=t + self.injDuration,
                                           amplitude=self.curInj)

    if(self.postType == "ALL"):
      self.snuddaSim.add_recording()
    else:
      # Record from all the potential post synaptic neurons
      self.snuddaSim.add_recording_of_type(self.postType)

      # Also save the presynaptic traces for debugging, to make sure they spike
      self.snuddaSim.add_recording_of_type(self.preType)


    # Run simulation
    self.snuddaSim.run(simEnd * 1e3, hold_v=self.holdV)

    # Write results to disk
    self.snuddaSim.write_voltage(self.voltFile)


  ############################################################################

  def readVoltage(self,voltFile):
    """Read a voltage file written by runSim().

    If voltFile is missing, the "-ALL" file for preType is used instead,
    or (when preType == postType) any file matching voltFileAltMask. If no
    alternative exists, "Aborting" is printed and the program exits.

    Returns:
      (time, voltage): the time row divided by 1e3 and a dict mapping
      neuron ID to its trace multiplied by 1e-3 (presumably ms -> s and
      mV -> V).
    """

    if(not os.path.exists(voltFile)):
      print("Missing " + voltFile)

      allFile = self.voltFileAltMask.replace("*","ALL")

      if(os.path.exists(allFile)):
        print("Using " + allFile + " instead")
        voltFile = allFile
      elif(self.preType == self.postType):
        fileList = glob.glob(self.voltFileAltMask)
        if(len(fileList) > 0):
          voltFile = fileList[0]
          print("Using " + voltFile + " instead, since pre and post are same")
        else:
          # No fallback either; abort instead of reading a missing file
          print("Aborting")
          exit(-1)
      else:
        print("Aborting")
        exit(-1)

    data = np.genfromtxt(voltFile, delimiter=',')
    assert(data[0,0] == -1) # First column should be time
    time = data[0,1:] / 1e3

    voltage = {int(row[0]): row[1:] * 1e-3 for row in data[1:,:]}

    return (time,voltage)

  ############################################################################

  # This extracts all the voltage deflections, to see how strong they are

  def analyse(self,expType,maxDist=None,nMaxShow=10):
    """Extract and plot the postsynaptic voltage deflections of a run.

    For every presynaptic neuron (preType), the voltage of each connected
    postsynaptic neuron (postType) within maxDist is cut out from the
    pulse onset to 0.05 s later. The largest deflection (with sign) of each
    trace is reported. Up to nMaxShow traces are plotted together with the
    experimental data for expType (if any), and a histogram of the
    amplitudes is saved. Ends in a pdb breakpoint so the interactive
    figures can be inspected.

    Args:
      expType: experiment key used in figure names and to look up
        experimental data, e.g. "Planert2010".
      maxDist: exclude pairs whose somata are further apart than this (m).
        Defaults to self.maxDist; if that is also None, no distance filter
        is applied.
      nMaxShow: maximum number of traces drawn in the trace figure.
    """

    self.setupExpData()

    if(maxDist is None):
      maxDist = self.maxDist

    # Read the data
    self.snuddaLoad = SnuddaLoad(self.networkFile)
    self.data = self.snuddaLoad.data

    time,voltage = self.readVoltage(self.voltFile)
    checkWidth = 0.05 # Length of the response window after each pulse

    # Generate current info structure
    # A current pulse to all pre synaptic neurons, one at a time
    self.preID = [x["neuronID"] \
                  for x in self.data["neurons"] \
                  if x["type"] == self.preType]

    self.possiblePostID = [x["neuronID"] \
                           for x in self.data["neurons"] \
                           if x["type"] == self.postType]

    # injInfo contains (preID,injStartTime). Kept as a list, as in runSim(),
    # so self.injInfo is not exhausted by the loop below.
    self.injInfo = list(zip(self.preID, \
                            self.injSpacing\
                            +self.injSpacing*np.arange(0,len(self.preID))))

    # For each pre synaptic neuron, find the voltage deflection in each
    # of its post synaptic neurons

    synapseData = []
    tooFarAway = 0

    for (preID,t) in self.injInfo:
      # Post synaptic neuron to preID
      synapses,coords = self.snuddaLoad.find_synapses(pre_id=preID)

      postIDset = set(synapses[:,1]).intersection(self.possiblePostID)
      prePos = self.snuddaLoad.data["neuronPositions"][preID,:]

      # There is a bit of synaptic delay, so we can take voltage
      # at first timestep as baseline. The window only depends on t.
      tIdx = np.where(np.logical_and(t <= time, time <= t + checkWidth))[0]

      for postID in postIDset:

        if(maxDist is not None):
          postPos = self.snuddaLoad.data["neuronPositions"][postID,:]
          if(np.linalg.norm(prePos-postPos) > maxDist):
            tooFarAway += 1
            continue

        synapseData.append((time[tIdx],voltage[postID][tIdx]))

    if(maxDist is not None):
      print("Number of pairs excluded, distance > " \
            + str(maxDist*1e6) + "mum : " + str(tooFarAway))

    # Fig names. os.path.join keeps them relative when the network file has
    # no directory part (plain "/figures/" concatenation would be absolute).
    figDir = os.path.join(os.path.dirname(self.networkFile), "figures")
    traceFig = os.path.join(figDir,
                            expType + "synapse-calibration-volt-traces-"
                            + self.preType + "-" + self.postType + ".pdf")
    histFig = os.path.join(figDir,
                           expType + "synapse-calibration-volt-histogram-"
                           + self.preType + "-" + self.postType + ".pdf")

    os.makedirs(figDir, exist_ok=True)

    # Extract the amplitude of all voltage pulses
    amp = np.zeros((len(synapseData),))
    idxMax = np.zeros((len(synapseData),),dtype=int)
    tMax = np.zeros((len(synapseData),))

    for i,(t,v) in enumerate(synapseData):

      # Save the largest deflection -- with sign
      idxMax[i] = np.argmax(np.abs(v-v[0]))
      tMax[i] = t[idxMax[i]] - t[0]
      amp[i] = v[idxMax[i]]-v[0]

    assert len(amp) > 0, "No responses... too short distance!"

    print("Min amp: " + str(np.min(amp)))
    print("Max amp: " + str(np.max(amp)))
    print("Mean amp: " + str(np.mean(amp)) + " +/- " + str(np.std(amp)))
    print("Amps: " + str(amp))


    # Now we have all synapse deflections in synapseData
    matplotlib.rcParams.update({'font.size': 22})

    # Plot all traces, or nMaxShow spread evenly over the sorted amplitudes
    sortIdx = np.argsort(amp)
    if(len(sortIdx) > nMaxShow):
      keepIdx = [sortIdx[int(np.round(x))] \
                 for x in np.linspace(0,len(sortIdx)-1,nMaxShow)]
    else:
      keepIdx = sortIdx

    plt.figure()
    for x in keepIdx:

      t,v = synapseData[x]

      plt.plot((t-t[0])*1e3,(v-v[0])*1e3,color="black")

    plt.scatter(tMax*1e3,amp*1e3,color="blue",marker=".",s=100)

    if((expType,self.preType,self.postType) in self.expData):
      expMean,expStd = self.expData[(expType,self.preType,self.postType)]

      # t is the last plotted trace; its duration places the error bars
      tEnd = (t[-1]-t[0])*1e3

      axes = plt.gca()
      ay = axes.get_ylim()
      # Plot SD or 1.96 SD?
      plt.errorbar(tEnd,expMean,expStd,ecolor="red",
                   marker='o',color="red")

      modelMean = np.mean(amp)*1e3
      modelStd = np.std(amp)*1e3
      plt.errorbar(tEnd-2,modelMean,modelStd,ecolor="blue",
                   marker="o",color="blue")

      axes.set_ylim(ay)


    plt.xlabel("Time (ms)")
    plt.ylabel("Voltage (mV)")
    plt.title(self.neuronName(self.preType) \
              + " to " + self.neuronName(self.postType))

    # Remove part of the frame
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)

    plt.tight_layout()
    plt.ion()
    plt.show()
    plt.savefig(traceFig,dpi=300)



    plt.figure()
    plt.hist(amp*1e3,bins=20)
    plt.title(self.neuronName(self.preType) \
              + " to " + self.neuronName(self.postType))
    plt.xlabel("Voltage deflection (mV)")

    # Remove part of the frame
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)

    plt.tight_layout()
    plt.show()
    plt.savefig(histFig,dpi=300)

    # Stop here so the interactive figures can be inspected
    import pdb
    pdb.set_trace()

############################################################################

  def setupExpData(self):
    """Fill self.expData with experimental (mean, std) PSP amplitudes.

    Keys are (experiment, preType, postType). The values are plotted on the
    same axis as the model amplitudes in mV, so presumably they are in mV.
    """

    self.expData = dict()

    planertD1D1 = (0.24,0.15)
    planertD1D2 = (0.33,0.15)
    planertD2D1 = (0.27,0.09)
    planertD2D2 = (0.45,0.44)
    planertFSD1 = (4.8,4.9)
    planertFSD2 = (3.1,4.1)

    self.expData[("Planert2010","dSPN","dSPN")] = planertD1D1
    self.expData[("Planert2010","dSPN","iSPN")] = planertD1D2
    self.expData[("Planert2010","iSPN","dSPN")] = planertD2D1
    self.expData[("Planert2010","iSPN","iSPN")] = planertD2D2
    self.expData[("Planert2010","FSN","dSPN")]  = planertFSD1
    self.expData[("Planert2010","FSN","iSPN")]  = planertFSD2
Ejemplo n.º 12
0
  def analyse(self,expType,maxDist=None,nMaxShow=10):
    """Extract and plot the postsynaptic voltage deflections of a run.

    For every presynaptic neuron (preType), the voltage of each connected
    postsynaptic neuron (postType) within maxDist is cut out from the
    pulse onset to 0.05 s later. The largest deflection (with sign) of each
    trace is reported. Up to nMaxShow traces are plotted together with the
    experimental data for expType (if any), and a histogram of the
    amplitudes is saved. Ends in a pdb breakpoint so the interactive
    figures can be inspected.

    Args:
      expType: experiment key used in figure names and to look up
        experimental data, e.g. "Planert2010".
      maxDist: exclude pairs whose somata are further apart than this (m).
        Defaults to self.maxDist; if that is also None, no distance filter
        is applied.
      nMaxShow: maximum number of traces drawn in the trace figure.
    """

    self.setupExpData()

    if(maxDist is None):
      maxDist = self.maxDist

    # Read the data
    self.snuddaLoad = SnuddaLoad(self.networkFile)
    self.data = self.snuddaLoad.data

    time,voltage = self.readVoltage(self.voltFile)
    checkWidth = 0.05 # Length of the response window after each pulse

    # Generate current info structure
    # A current pulse to all pre synaptic neurons, one at a time
    self.preID = [x["neuronID"] \
                  for x in self.data["neurons"] \
                  if x["type"] == self.preType]

    self.possiblePostID = [x["neuronID"] \
                           for x in self.data["neurons"] \
                           if x["type"] == self.postType]

    # injInfo contains (preID,injStartTime). Kept as a list so
    # self.injInfo is not exhausted by the loop below.
    self.injInfo = list(zip(self.preID, \
                            self.injSpacing\
                            +self.injSpacing*np.arange(0,len(self.preID))))

    # For each pre synaptic neuron, find the voltage deflection in each
    # of its post synaptic neurons

    synapseData = []
    tooFarAway = 0

    for (preID,t) in self.injInfo:
      # Post synaptic neuron to preID
      synapses,coords = self.snuddaLoad.find_synapses(pre_id=preID)

      postIDset = set(synapses[:,1]).intersection(self.possiblePostID)
      prePos = self.snuddaLoad.data["neuronPositions"][preID,:]

      # There is a bit of synaptic delay, so we can take voltage
      # at first timestep as baseline. The window only depends on t.
      tIdx = np.where(np.logical_and(t <= time, time <= t + checkWidth))[0]

      for postID in postIDset:

        if(maxDist is not None):
          postPos = self.snuddaLoad.data["neuronPositions"][postID,:]
          if(np.linalg.norm(prePos-postPos) > maxDist):
            tooFarAway += 1
            continue

        synapseData.append((time[tIdx],voltage[postID][tIdx]))

    if(maxDist is not None):
      print("Number of pairs excluded, distance > " \
            + str(maxDist*1e6) + "mum : " + str(tooFarAway))

    # Fig names. os.path.join keeps them relative when the network file has
    # no directory part (plain "/figures/" concatenation would be absolute).
    figDir = os.path.join(os.path.dirname(self.networkFile), "figures")
    traceFig = os.path.join(figDir,
                            expType + "synapse-calibration-volt-traces-"
                            + self.preType + "-" + self.postType + ".pdf")
    histFig = os.path.join(figDir,
                           expType + "synapse-calibration-volt-histogram-"
                           + self.preType + "-" + self.postType + ".pdf")

    os.makedirs(figDir, exist_ok=True)

    # Extract the amplitude of all voltage pulses
    amp = np.zeros((len(synapseData),))
    idxMax = np.zeros((len(synapseData),),dtype=int)
    tMax = np.zeros((len(synapseData),))

    for i,(t,v) in enumerate(synapseData):

      # Save the largest deflection -- with sign
      idxMax[i] = np.argmax(np.abs(v-v[0]))
      tMax[i] = t[idxMax[i]] - t[0]
      amp[i] = v[idxMax[i]]-v[0]

    assert len(amp) > 0, "No responses... too short distance!"

    print("Min amp: " + str(np.min(amp)))
    print("Max amp: " + str(np.max(amp)))
    print("Mean amp: " + str(np.mean(amp)) + " +/- " + str(np.std(amp)))
    print("Amps: " + str(amp))


    # Now we have all synapse deflections in synapseData
    matplotlib.rcParams.update({'font.size': 22})

    # Plot all traces, or nMaxShow spread evenly over the sorted amplitudes
    sortIdx = np.argsort(amp)
    if(len(sortIdx) > nMaxShow):
      keepIdx = [sortIdx[int(np.round(x))] \
                 for x in np.linspace(0,len(sortIdx)-1,nMaxShow)]
    else:
      keepIdx = sortIdx

    plt.figure()
    for x in keepIdx:

      t,v = synapseData[x]

      plt.plot((t-t[0])*1e3,(v-v[0])*1e3,color="black")

    plt.scatter(tMax*1e3,amp*1e3,color="blue",marker=".",s=100)

    if((expType,self.preType,self.postType) in self.expData):
      expMean,expStd = self.expData[(expType,self.preType,self.postType)]

      # t is the last plotted trace; its duration places the error bars
      tEnd = (t[-1]-t[0])*1e3

      axes = plt.gca()
      ay = axes.get_ylim()
      # Plot SD or 1.96 SD?
      plt.errorbar(tEnd,expMean,expStd,ecolor="red",
                   marker='o',color="red")

      modelMean = np.mean(amp)*1e3
      modelStd = np.std(amp)*1e3
      plt.errorbar(tEnd-2,modelMean,modelStd,ecolor="blue",
                   marker="o",color="blue")

      axes.set_ylim(ay)


    plt.xlabel("Time (ms)")
    plt.ylabel("Voltage (mV)")
    plt.title(self.neuronName(self.preType) \
              + " to " + self.neuronName(self.postType))

    # Remove part of the frame
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)

    plt.tight_layout()
    plt.ion()
    plt.show()
    plt.savefig(traceFig,dpi=300)



    plt.figure()
    plt.hist(amp*1e3,bins=20)
    plt.title(self.neuronName(self.preType) \
              + " to " + self.neuronName(self.postType))
    plt.xlabel("Voltage deflection (mV)")

    # Remove part of the frame
    plt.gca().spines["right"].set_visible(False)
    plt.gca().spines["top"].set_visible(False)

    plt.tight_layout()
    plt.show()
    plt.savefig(histFig,dpi=300)

    # Stop here so the interactive figures can be inspected
    import pdb
    pdb.set_trace()
Ejemplo n.º 13
0
    def test_prune(self):
        """Exercise SnuddaPrune on the network in ``self.network_path``.

        The first subtest prunes with the default config (config_file=None);
        later subtests prune again with network-config-test-N.json files and
        check synapse / gap junction counts and sort order in the pruned
        network-synapses.hdf5. The expected numbers depend on the test
        network created elsewhere in this test case (TODO confirm against its
        setUp). Subtests reload ``pruned_output`` after each prune(), so they
        presumably rely on each prune() rewriting that file.
        """

        pruned_output = os.path.join(self.network_path,
                                     "network-synapses.hdf5")

        with self.subTest(stage="No-pruning"):

            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=None,
                             verbose=True,
                             keep_files=True)  # Use default config file
            sp.prune()
            # Drop the reference to the pruner before reading its output
            sp = []

            # Load the pruned data and check it

            sl = SnuddaLoad(pruned_output)
            # TODO: Call a plot function to plot entire network with synapses and all

            self.assertEqual(sl.data["nSynapses"], (20 * 8 + 10 * 2) *
                             2)  # Update, now AMPA+GABA, hence *2 at end

            # This checks that all synapses are in order
            # The synapse sort order is destID, sourceID, synapsetype (channel model id).

            syn = sl.data["synapses"][:sl.data["nSynapses"], :]
            syn_order = (syn[:, 1] * len(self.sd.neurons) + syn[:, 0]
                         ) * 12 + syn[:, 6]  # The 12 is maxChannelModelID
            self.assertTrue((np.diff(syn_order) >= 0).all())

            # Note that channel model id is dynamically allocated, starting from 10 (GJ have ID 3)
            # Check that correct number of each type
            self.assertEqual(np.sum(sl.data["synapses"][:, 6] == 10),
                             20 * 8 + 10 * 2)
            self.assertEqual(np.sum(sl.data["synapses"][:, 6] == 11),
                             20 * 8 + 10 * 2)

            # Gap junctions must also be sorted by (dest, source)
            self.assertEqual(sl.data["nGapJunctions"], 4 * 4 * 4)
            gj = sl.data["gapJunctions"][:sl.data["nGapJunctions"], :2]
            gj_order = gj[:, 1] * len(self.sd.neurons) + gj[:, 0]
            self.assertTrue((np.diff(gj_order) >= 0).all())

        with self.subTest(stage="load-testing"):
            sl = SnuddaLoad(pruned_output, verbose=True)

            # Try and load a neuron
            n = sl.load_neuron(0)
            # Exact type match (a subclass of NeuronMorphology would fail here)
            self.assertTrue(type(n) == NeuronMorphology)

            # Chunked iteration must visit every synapse exactly once
            syn_ctr = 0
            for s in sl.synapse_iterator(chunk_size=50):
                syn_ctr += s.shape[0]
            self.assertEqual(syn_ctr, sl.data["nSynapses"])

            gj_ctr = 0
            for gj in sl.gap_junction_iterator(chunk_size=50):
                gj_ctr += gj.shape[0]
            self.assertEqual(gj_ctr, sl.data["nGapJunctions"])

            # Column 0 is the presynaptic id, column 1 the postsynaptic id
            syn, syn_coords = sl.find_synapses(pre_id=14)
            self.assertTrue((syn[:, 0] == 14).all())
            self.assertEqual(syn.shape[0], 40)

            syn, syn_coords = sl.find_synapses(post_id=3)
            self.assertTrue((syn[:, 1] == 3).all())
            self.assertEqual(syn.shape[0], 36)

            cell_id_perm = sl.get_cell_id_of_type("ballanddoublestick",
                                                  random_permute=True,
                                                  num_neurons=28)
            cell_id = sl.get_cell_id_of_type("ballanddoublestick",
                                             random_permute=False)

            self.assertEqual(len(cell_id_perm), 28)
            self.assertEqual(len(cell_id), 28)

            # The permuted selection must only contain valid cell ids
            for cid in cell_id_perm:
                self.assertTrue(cid in cell_id)

        # It is important merge file has synapses sorted with dest_id, source_id as sort order since during pruning
        # we assume this to be able to quickly find all synapses on post synaptic cell.
        # TODO: Also include the ChannelModelID in sorting check
        with self.subTest("Checking-merge-file-sorted"):

            for mf in [
                    "temp/synapses-for-neurons-0-to-28-MERGE-ME.hdf5",
                    "temp/gapJunctions-for-neurons-0-to-28-MERGE-ME.hdf5",
                    "network-synapses.hdf5"
            ]:

                merge_file = os.path.join(self.network_path, mf)

                sl = SnuddaLoad(merge_file, verbose=True)
                if "synapses" in sl.data:
                    syn = sl.data["synapses"][:sl.data["nSynapses"], :2]
                    syn_order = syn[:, 1] * len(self.sd.neurons) + syn[:, 0]
                    self.assertTrue((np.diff(syn_order) >= 0).all())

                if "gapJunctions" in sl.data:
                    gj = sl.data["gapJunctions"][:sl.data["nGapJunctions"], :2]
                    gj_order = gj[:, 1] * len(self.sd.neurons) + gj[:, 0]
                    self.assertTrue((np.diff(gj_order) >= 0).all())

        with self.subTest("synapse-f1"):
            # Test of f1
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-1.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it

            sl = SnuddaLoad(pruned_output, verbose=True)
            # Setting f1=0.5 in config should remove 50% of GABA synapses, but does so randomly, for AMPA we used f1=0.9
            gaba_id = sl.data["connectivityDistributions"][
                "ballanddoublestick",
                "ballanddoublestick"]["GABA"]["channelModelID"]
            ampa_id = sl.data["connectivityDistributions"][
                "ballanddoublestick",
                "ballanddoublestick"]["AMPA"]["channelModelID"]

            n_gaba = np.sum(sl.data["synapses"][:, 6] == gaba_id)
            n_ampa = np.sum(sl.data["synapses"][:, 6] == ampa_id)

            # Pruning is random, so accept a +/- 10 window around the expectation
            self.assertTrue((20 * 8 + 10 * 2) * 0.5 -
                            10 < n_gaba < (20 * 8 + 10 * 2) * 0.5 + 10)
            self.assertTrue((20 * 8 + 10 * 2) * 0.9 -
                            10 < n_ampa < (20 * 8 + 10 * 2) * 0.9 + 10)

        with self.subTest("synapse-softmax"):
            # Test of softmax
            testing_config_file = os.path.join(
                self.network_path, "network-config-test-2.json"
            )  # Only GABA synapses in this config
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)
            # Softmax reduces number of synapses
            self.assertTrue(sl.data["nSynapses"] < 20 * 8 + 10 * 2)

        with self.subTest("synapse-mu2"):
            # Test of mu2
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-3.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)
            # With mu2 having 2 synapses means 50% chance to keep them, having 1 will be likely to have it removed
            self.assertTrue(
                20 * 8 * 0.5 - 10 < sl.data["nSynapses"] < 20 * 8 * 0.5 + 10)

        with self.subTest("synapse-a3"):
            # Test of a3
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-4.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)

            # a3=0.6 means 40% chance to remove all synapses between a pair
            self.assertTrue(
                (20 * 8 + 10 * 2) * 0.6 -
                14 < sl.data["nSynapses"] < (20 * 8 + 10 * 2) * 0.6 + 14)

        with self.subTest("synapse-distance-dependent-pruning"):
            # Testing distance dependent pruning
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-5.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)

            # "1*(d >= 100e-6)" means we remove all synapses closer than 100 micrometers
            self.assertEqual(sl.data["nSynapses"], 20 * 6)
            self.assertTrue(
                (sl.data["synapses"][:, 8] >=
                 100).all())  # Column 8 -- distance to soma in micrometers

        # TODO: Need to do same test for Gap Junctions also -- but should be same results, since same codebase
        with self.subTest("gap-junction-f1"):
            # Test of f1
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-6.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it

            sl = SnuddaLoad(pruned_output)
            # Setting f1=0.7 in config should remove 30% of gap junctions, but does so randomly
            self.assertTrue(
                64 * 0.7 - 10 < sl.data["nGapJunctions"] < 64 * 0.7 + 10)

        with self.subTest("gap-junction-softmax"):
            # Test of softmax
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-7.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)
            # Softmax reduces number of gap junctions
            self.assertTrue(sl.data["nGapJunctions"] < 16 * 2 + 10)

        with self.subTest("gap-junction-mu2"):
            # Test of mu2
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-8.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output)
            # With mu2 having 4 synapses means 50% chance to keep them, having 1 will be likely to have it removed
            self.assertTrue(
                64 * 0.5 - 10 < sl.data["nGapJunctions"] < 64 * 0.5 + 10)

        with self.subTest("gap-junction-a3"):
            # Test of a3
            testing_config_file = os.path.join(self.network_path,
                                               "network-config-test-9.json")
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=testing_config_file,
                             verbose=True,
                             keep_files=True)  # Uses test-specific config file
            sp.prune()

            # Load the pruned data and check it
            sl = SnuddaLoad(pruned_output, verbose=True)

            # a3=0.7 means 30% chance to remove all gap junctions between a pair
            self.assertTrue(
                64 * 0.7 - 10 < sl.data["nGapJunctions"] < 64 * 0.7 + 10)

        if False:  # Distance dependent pruning currently not implemented for gap junctions
            # Disabled subtest, kept for when GJ distance dependent pruning exists
            with self.subTest("gap-junction-distance-dependent-pruning"):
                # Testing distance dependent pruning
                testing_config_file = os.path.join(
                    self.network_path, "network-config-test-10.json")
                sp = SnuddaPrune(network_path=self.network_path,
                                 config_file=testing_config_file,
                                 verbose=True,
                                 keep_files=True)  # Uses test-specific config file
                sp.prune()

                # Load the pruned data and check it
                sl = SnuddaLoad(pruned_output, verbose=True)

                # "1*(d <= 120e-6)" means we remove all gap junctions further away than 120 micrometers
                self.assertEqual(sl.data["nGapJunctions"], 2 * 4 * 4)
                self.assertTrue(
                    (sl.data["gapJunctions"][:, 8] <=
                     120).all())  # Column 8 -- distance to soma in micrometers
Ejemplo n.º 14
0
    def load_network_info(self, network_file):
        """Store the per-neuron description list of a Snudda network.

        Args:
            network_file: path of the network HDF5 file to read with SnuddaLoad.

        Sets ``self.neuron_info`` to the loaded ``data["neurons"]`` entry.
        """
        loader = SnuddaLoad(network_file)
        self.neuron_info = loader.data["neurons"]
Ejemplo n.º 15
0
    def load_data(self, skip_time=0.0):
        """Load simulated spikes and group them by neuron name and input count.

        Reads ``network-synapses.hdf5`` and ``output_spikes.txt`` from
        ``self.network_path`` and the input configuration returned by
        ``self.load_input_config()``.

        Args:
            skip_time: forwarded unchanged to ``self.extract_spikes``.

        Returns:
            dict mapping neuron name -> {total "nInputs" of a neuron -> value
            returned by ``self.extract_spikes`` for that neuron}. If two
            neurons with the same name have the same total input count, the
            later neuron ID overwrites the earlier one.
        """

        network_file = os.path.join(self.network_path, "network-synapses.hdf5")
        network_info = SnuddaLoad(network_file)

        spike_data_file = os.path.join(self.network_path, "output_spikes.txt")
        n_neurons = network_info.data["nNeurons"]
        spike_data = self.load_spike_data(spike_data_file, n_neurons)

        # Group neuron IDs by neuron (morphology/model) name, keeping the order
        # in which names first appear in the network.
        neuron_id_lookup = dict()
        for neuron in network_info.data["neurons"]:
            neuron_id_lookup.setdefault(neuron["name"], []).append(neuron["neuronID"])

        # Total number of inputs (summed over input types) per neuron ID
        input_config = self.load_input_config()
        n_inputs_lookup = {
            int(neuron_label): sum(input_info["nInputs"]
                                   for input_info in input_types.values())
            for neuron_label, input_types in input_config.items()
        }

        frequency_data = {neuron_name: dict() for neuron_name in neuron_id_lookup}

        # Iterate unique names only. Looping over every neuron's name would
        # repeat the extraction (and warnings) once per neuron sharing a name.
        for neuron_name, neuron_ids in neuron_id_lookup.items():
            for neuron_id in neuron_ids:

                if neuron_id not in n_inputs_lookup:
                    print(
                        f"No inputs for neuron_id={neuron_id}, ignoring. Please update your setup."
                    )
                    continue

                n_inputs = n_inputs_lookup[neuron_id]
                frequency_data[neuron_name][n_inputs] = self.extract_spikes(
                    spike_data=spike_data,
                    config_data=input_config,
                    neuron_id=neuron_id,
                    skip_time=skip_time)

        # TODO: Load voltage trace and warn for depolarisation blocking

        return frequency_data
Ejemplo n.º 16
0
class InspectInput(object):
  """Summarise the inputs that neurons of a given type receive.

  A Snudda network file maps neuron IDs to morphologies. An HDF5 input file
  provides the inputs, read from ``input/<neuronID>/<inputType>/nSpikes``.
  The length of each ``nSpikes`` dataset is used as the number of inputs of
  that type.
  """

  def __init__(self,networkFile,inputFile):

    self.networkFile = networkFile
    self.inputFile = inputFile

    # We just need this to identify which neuron is which
    self.network = SnuddaLoad(self.networkFile, load_synapses=False)

    # Kept open for the lifetime of the object; call close() when done
    self.inputData = h5py.File(inputFile,'r')


  def close(self):
    """Close the input file opened in __init__ (safe to call more than once)."""
    inputData = getattr(self, "inputData", None)
    if inputData is not None:
      inputData.close()
      self.inputData = None


  def getMorphologies(self):
    """Return the morphology file name (basename) of every neuron, by index."""

    return [os.path.basename(c["morphology"]) \
            for c in self.network.data["neurons"]]


  def getInputTypes(self,cellID):
    """Return the input type names present for any of the given neuron IDs.

    Neurons without an entry under ``input`` in the input file are skipped.
    The order of the returned list is arbitrary.
    """

    inputGroup = self.inputData["input"]
    inputTypes = set()

    for scID in (str(c) for c in cellID):
      if scID in inputGroup:
        # Iterating the per-neuron group yields the input type names
        inputTypes.update(inputGroup[scID])

    return list(inputTypes)


  def checkInputRatio(self, neuronType, verbose=True):
    """Count inputs per input type and morphology for one neuron type.

    Args:
      neuronType: neuron type name, passed to ``get_cell_id_of_type``.
      verbose: if True, print the average number of inputs per neuron for
        every (input type, morphology) pair.

    Returns:
      dict mapping input type -> {morphology file name -> total number of
      inputs of that type summed over the neurons with that morphology}.
    """

    if verbose:
      print(f"Counting inputs for {neuronType}")

    cellID = self.network.get_cell_id_of_type(neuronType)
    cellIDstr = set(str(c) for c in cellID)

    cellMorph = self.getMorphologies()
    uniqueMorph = set(cellMorph[c] for c in cellID)

    inputTypeList = self.getInputTypes(cellID)

    inputCount = {inp: {um: 0 for um in uniqueMorph} for inp in inputTypeList}
    # Number of neurons per morphology that actually have input data
    morphCounter = {um: 0 for um in uniqueMorph}

    inputGroup = self.inputData['input']
    for cID in inputGroup:
      if cID not in cellIDstr:
        continue

      morph = cellMorph[int(cID)]
      morphCounter[morph] += 1
      cellInput = inputGroup[cID]

      for inp in inputTypeList:
        if inp in cellInput:
          inputCount[inp][morph] += len(cellInput[inp]['nSpikes'])

    if verbose:
      for inp in inputTypeList:
        for um in uniqueMorph:
          if morphCounter[um] == 0:
            # No neuron with this morphology has inputs; avoid dividing by 0
            print(f"{inp} morphology {um}: no neurons with input")
            continue
          avgInp = round(inputCount[inp][um]/morphCounter[um],1)
          print(f"{inp} morphology {um}: {avgInp} inputs")

    return inputCount
Ejemplo n.º 17
0
class SnuddaProject(object):
    """Create synapses for projections defined by a projection map.

    Connectivity entries in ``network-config.json`` that contain a
    ``projectionFile`` are handled here. Each presynaptic neuron's position is
    mapped through the projection map (linear ``griddata``) to a target
    centre. Postsynaptic neurons near that centre are connected. The synapses
    are written to ``network-projection-synapses.hdf5`` in ``network_path``.
    """

    # TODO: Add support for log files!!
    def __init__(self,
                 network_path,
                 rng=None,
                 random_seed=None,
                 h5libver=None):
        """
        Args:
            network_path: directory containing network-config.json,
                network-neuron-positions.hdf5 and
                log/network-detect-worklog.hdf5.
            rng: numpy Generator to use; takes precedence over random_seed.
            random_seed: seed used when rng is None. If both are None, the
                seed is read from config["RandomSeed"]["project"].
            h5libver: HDF5 libver used when writing, default "latest".
        """

        self.network_path = network_path
        self.network_info = None
        self.work_history_file = os.path.join(self.network_path, "log",
                                              "network-detect-worklog.hdf5")
        self.output_file_name = os.path.join(
            self.network_path, "network-projection-synapses.hdf5")

        # Initial capacity of the synapse matrix, grown on demand
        max_synapses = 100000
        self.synapses = np.zeros((max_synapses, 13), dtype=np.int32)
        self.synapse_ctr = 0
        self.connectivity_distributions = dict()
        self.prototype_neurons = dict()
        self.next_channel_model_id = 10

        self.simulation_origo = None
        self.voxel_size = None

        # Parameters for the HDF5 writing, this affects write speed
        self.synapse_chunk_size = 10000
        self.h5compression = "lzf"

        config_file = os.path.join(self.network_path, "network-config.json")

        with open(config_file, "r") as f:
            self.config = json.load(f)

        self.h5libver = h5libver if h5libver else "latest"

        # Setup random generator,
        # TODO: this assumes serial execution. Update for parallel version
        if rng is not None:
            self.rng = rng
        elif random_seed is not None:
            # "is not None" so that an explicit seed of 0 is honoured
            self.rng = np.random.default_rng(random_seed)
        else:
            random_seed = self.config["RandomSeed"]["project"]
            self.rng = np.random.default_rng(random_seed)

        self.read_neuron_positions()
        self.read_prototypes()

    def read_neuron_positions(self):
        """Load neuron positions, simulation origo and voxel size."""

        position_file = os.path.join(self.network_path,
                                     "network-neuron-positions.hdf5")
        self.network_info = SnuddaLoad(position_file)

        # We also need simulation origo and voxel size
        with h5py.File(self.work_history_file, "r") as work_hist:
            self.simulation_origo = work_hist["meta/simulationOrigo"][()]
            self.voxel_size = work_hist["meta/voxelSize"][()]

    # This is a simplified version of the prototype load in detect
    def read_prototypes(self):
        """Load prototype morphologies and assign channel model IDs.

        Gap junctions get channelModelID 3. Other connection types get
        consecutive IDs starting at ``self.next_channel_model_id``. A scalar
        conductance is turned into [conductance, 0] (mean, std).
        """

        for name, definition in self.config["Neurons"].items():

            morph = definition["morphology"]

            self.prototype_neurons[name] = NeuronMorphology(name=name,
                                                            swc_filename=morph)

        # TODO: The code below is duplicate from detect.py, update so both use same code base
        for name, definition in self.config["Connectivity"].items():

            pre_type, post_type = name.split(",")

            # NOTE: shallow copy -- the nested per-type dicts are shared with
            # self.config, so the updates below are also visible in the
            # config that write() stores.
            con_def = definition.copy()

            for key in con_def:
                if key == "GapJunction":
                    con_def[key]["channelModelID"] = 3
                else:
                    con_def[key]["channelModelID"] = self.next_channel_model_id
                    self.next_channel_model_id += 1

                # Also if conductance is just a number, add std 0
                if not isinstance(con_def[key]["conductance"], (list, tuple)):
                    con_def[key]["conductance"] = [
                        con_def[key]["conductance"], 0
                    ]

            self.connectivity_distributions[pre_type, post_type] = con_def

    def project(self):
        """Create projection synapses for every (pre, post) type pair."""

        for (pre_type, post_type
             ), connection_info in self.connectivity_distributions.items():
            print(f"pre {pre_type}, post {post_type}")
            self.connect_projection_helper(pre_type, post_type,
                                           connection_info)

    @staticmethod
    def _mean_std(value):
        """Return value as np.array([mean, std]); a scalar gets std 0."""
        if isinstance(value, list):
            return np.array(value)  # mean, std
        return np.array([value, 0])

    def _ensure_synapse_capacity(self, n_extra):
        """Grow self.synapses so n_extra more rows fit after synapse_ctr."""
        n_needed = self.synapse_ctr + n_extra
        n_rows, n_cols = self.synapses.shape
        if n_needed <= n_rows:
            return

        # Grow geometrically so repeated appends stay cheap
        grown = np.zeros((max(n_needed, 2 * n_rows), n_cols),
                         dtype=self.synapses.dtype)
        grown[:self.synapse_ctr, :] = self.synapses[:self.synapse_ctr, :]
        self.synapses = grown

    def connect_projection_helper(self, pre_neuron_type, post_neuron_type,
                                  connection_info):
        """Create projection synapses from pre_neuron_type to post_neuron_type.

        Only entries of connection_info that contain "projectionFile" are
        used; other entries are skipped.

        Raises:
            ValueError: if a projection entry lacks "numberOfTargets" or
                "dendriteSynapseDensity".
        """

        for connection_type, con_info in connection_info.items():

            if "projectionFile" not in con_info:
                # Not a projection, skipping
                continue

            for required_key in ("numberOfTargets", "dendriteSynapseDensity"):
                if required_key not in con_info:
                    raise ValueError(
                        f"Projection {pre_neuron_type} -> {post_neuron_type} "
                        f"({connection_type}) is missing '{required_key}'")

            projection_file = con_info["projectionFile"]
            with open(projection_file, "r") as f:
                projection_data = json.load(f)

            if "projectionName" in con_info:
                projection_map = projection_data[con_info["projectionName"]]
            else:
                projection_map = projection_data

            # Map coordinates are scaled by 1e-6 (presumably um -> m, to match
            # neuronPositions -- confirm)
            projection_source = np.array(projection_map["source"]) * 1e-6
            projection_destination = np.array(
                projection_map["destination"]) * 1e-6

            # None means: connect to the closest neurons instead
            projection_radius = con_info.get("projectionRadius", None)

            # TODO: Add projectionDensity later
            # if "projectionDensity" in con_info:
            #     projection_density = con_info["projectionDensity"]
            # else:
            #    projection_density = None  # All neurons within projection radius equally likely

            number_of_targets = self._mean_std(con_info["numberOfTargets"])
            number_of_synapses = self._mean_std(
                con_info.get("numberOfSynapses", 1))

            dendrite_synapse_density = con_info["dendriteSynapseDensity"]

            if isinstance(con_info["conductance"], list):
                conductance_mean, conductance_std = con_info["conductance"]
            else:
                conductance_mean, conductance_std = con_info["conductance"], 0

            # The channelModelID is added to config information
            channel_model_id = con_info["channelModelID"]

            # Find all the presynaptic neurons in the network
            pre_id_list = self.network_info.get_cell_id_of_type(
                pre_neuron_type)
            pre_positions = self.network_info.data["neuronPositions"][
                pre_id_list, :]

            # Find all the postsynaptic neurons in the network
            post_id_list = self.network_info.get_cell_id_of_type(
                post_neuron_type)
            post_name_list = [
                self.network_info.data["name"][x] for x in post_id_list
            ]
            post_positions = self.network_info.data["neuronPositions"][
                post_id_list, :]

            # For each presynaptic neuron, find their target regions.
            # -- if you want two distinct target regions, you have to create two separate maps
            target_centres = griddata(points=projection_source,
                                      values=projection_destination,
                                      xi=pre_positions,
                                      method="linear")

            # Sampled counts can be negative; clip at 0 so that [:n] slicing
            # never wraps around from the end of the array.
            num_targets = np.maximum(
                self.rng.normal(number_of_targets[0],
                                number_of_targets[1],
                                len(pre_id_list)).astype(int), 0)

            # For each presynaptic neuron, using the supplied map, find the potential post synaptic targets
            for pre_id, centre_pos, n_targets in zip(pre_id_list,
                                                     target_centres,
                                                     num_targets):

                if np.isnan(centre_pos).any():
                    # Linear griddata gives NaN outside the convex hull of
                    # the projection source points: no meaningful target.
                    continue

                # axis=1: one distance per row (postsynaptic neuron)
                d = np.linalg.norm(centre_pos - post_positions, axis=1)
                d_idx = np.argsort(d)

                if projection_radius:
                    d_idx = d_idx[np.where(d[d_idx] <= projection_radius)[0]]
                    if len(d_idx) > n_targets:
                        d_idx = self.rng.permutation(d_idx)[:n_targets]

                elif len(d_idx) > n_targets:
                    d_idx = d_idx[:n_targets]

                target_id = [post_id_list[x] for x in d_idx]
                target_name = [post_name_list[x] for x in d_idx]
                axon_dist = d[d_idx]

                n_synapses = np.maximum(
                    self.rng.normal(number_of_synapses[0],
                                    number_of_synapses[1],
                                    len(target_id)).astype(int), 0)

                for t_id, t_name, n_syn, ax_dist in zip(
                        target_id, target_name, n_synapses, axon_dist):

                    # We need to place neuron correctly in space (work on clone),
                    # so that synapse coordinates are correct
                    morph_prototype = self.prototype_neurons[t_name]
                    position = self.network_info.data["neurons"][t_id][
                        "position"]
                    rotation = self.network_info.data["neurons"][t_id][
                        "rotation"]
                    morph = morph_prototype.clone(position=position,
                                                  rotation=rotation)

                    # We are not guaranteed to get n_syn positions, so use len(sec_x) to get how many after
                    # TODO: Fix so dendrite_input_locations always returns  n_syn synapses
                    xyz, sec_id, sec_x, dist_to_soma = morph.dendrite_input_locations(
                        dendrite_synapse_density,
                        self.rng,
                        num_locations=n_syn)

                    # We need to convert xyz into voxel coordinates to match data format of synapse matrix
                    xyz = np.round(
                        (xyz - self.simulation_origo) / self.voxel_size)

                    cond = self.rng.normal(conductance_mean, conductance_std,
                                           len(sec_x))
                    cond = np.maximum(cond, conductance_mean *
                                      0.1)  # Lower bound, prevent negative.
                    param_id = self.rng.integers(1000000, size=len(sec_x))

                    self._ensure_synapse_capacity(len(sec_id))

                    # NOTE(review): ax_dist is in the units of neuronPositions
                    # and is truncated to int32 below -- confirm the unit
                    # expected for column 7.
                    for i in range(len(sec_id)):
                        self.synapses[self.synapse_ctr, :] = \
                            [pre_id, t_id,
                             xyz[i, 0], xyz[i, 1], xyz[i, 2],
                             -1,  # Hypervoxelid
                             channel_model_id,
                             ax_dist, dist_to_soma[i],
                             sec_id[i], sec_x[i] * 1000,
                             cond[i] * 1e12, param_id[i]]
                        self.synapse_ctr += 1

    def write(self):
        """Write the projection synapses and update the work history file.

        Synapses are sorted by (dest, src, channel model ID) before being
        written. nProjectionSynapses in the work history file is created or
        overwritten with the synapse count.
        """

        # Before writing synapses, lets make sure they are sorted.
        # Sort order: columns 1 (dest), 0 (src), 6 (synapse type)
        # (np.lexsort uses its LAST key as the primary key)
        sort_idx = np.lexsort(self.synapses[:self.synapse_ctr,
                                            [6, 0, 1]].transpose())
        self.synapses[:self.synapse_ctr, :] = self.synapses[sort_idx, :]

        # Write synapses to file
        with h5py.File(self.output_file_name, "w",
                       libver=self.h5libver) as out_file:

            out_file.create_dataset("config", data=json.dumps(self.config))

            network_group = out_file.create_group("network")
            network_group.create_dataset(
                "synapses",
                data=self.synapses[:self.synapse_ctr, :],
                dtype=np.int32,
                chunks=(self.synapse_chunk_size, 13),
                maxshape=(None, 13),
                compression=self.h5compression)

            network_group.create_dataset("nSynapses",
                                         data=self.synapse_ctr,
                                         dtype=int)
            network_group.create_dataset(
                "nNeurons", data=self.network_info.data["nNeurons"], dtype=int)

            # This is useful so the merge_helper knows if they need to search this file for synapses
            all_target_id = np.unique(self.synapses[:self.synapse_ctr, 1])
            network_group.create_dataset("allTargetId", data=all_target_id)

            # This creates a lookup that is used for merging later
            synapse_lookup = SnuddaDetect.create_lookup_table(
                data=self.synapses,
                n_rows=self.synapse_ctr,
                data_type="synapses",
                num_neurons=self.network_info.data["nNeurons"],
                max_synapse_type=self.next_channel_model_id)

            network_group.create_dataset("synapseLookup", data=synapse_lookup)
            network_group.create_dataset("maxChannelTypeID",
                                         data=self.next_channel_model_id,
                                         dtype=int)

        # We also need to update the work history file with how many synapses we created
        # for the projections between volumes

        with h5py.File(self.work_history_file, "a",
                       libver=self.h5libver) as hist_file:
            if "nProjectionSynapses" in hist_file:
                hist_file["nProjectionSynapses"][()] = self.synapse_ctr
            else:
                hist_file.create_dataset("nProjectionSynapses",
                                         data=self.synapse_ctr,
                                         dtype=int)
Ejemplo n.º 18
0
class PlotNetwork(object):
    """Plot the neurons, and optionally the synapses, of a Snudda network in 3D."""

    def __init__(self, network):
        """Load the network to plot.

        Args:
            network: Either a network directory (``network-synapses.hdf5``
                inside it is used) or the path to a network HDF5 file.
        """

        if os.path.isdir(network):
            network_file = os.path.join(network, "network-synapses.hdf5")
        else:
            network_file = network

        self.network_file = network_file
        self.network_path = os.path.dirname(self.network_file)

        self.sl = SnuddaLoad(self.network_file)

        # Morphologies already read from disk, keyed by neuron name. Each
        # plotted neuron is a placed clone of its prototype (see load_neuron).
        self.prototype_neurons = dict()

    def close(self):
        """Close the underlying SnuddaLoad object."""
        self.sl.close()

    @staticmethod
    def _population_unit_colours(population_unit, cmap, default_colour):
        """Return one colour per neuron, chosen from its population unit.

        Units > 0 are coloured with ``cmap(rank)``, where ``rank`` is the
        position of the unit id among all (sorted, unique) ids present. The
        colour map is sized to the number of units, so indexing by rank keeps
        non-contiguous ids from running past its end (which would make several
        units share the last colour). Units <= 0 get ``default_colour``.

        Args:
            population_unit: Sequence with one population unit id per neuron.
            cmap: Callable mapping an integer index to a colour.
            default_colour: Colour used for neurons not in a population unit.

        Returns:
            List of colours, in the same order as ``population_unit``.
        """
        unit_rank = {unit: rank
                     for rank, unit in enumerate(sorted(set(population_unit)))}
        return [cmap(unit_rank[pu]) if pu > 0 else default_colour
                for pu in population_unit]

    def plot(self,
             plot_axon=True,
             plot_dendrite=True,
             plot_synapses=True,
             title=None,
             title_pad=None,
             show_axis=True,
             elev_azim=None,
             fig_name=None,
             dpi=600,
             colour_population_unit=False):
        """Draw the network in a new 3D figure.

        Args:
            plot_axon: bool applied to every neuron, or a per-neuron sequence.
            plot_dendrite: bool applied to every neuron, or a per-neuron sequence.
            plot_synapses: If True and "synapseCoords" is loaded, scatter the
                synapse coordinates and print the synapse count on the figure.
            title: Figure title (default: empty).
            title_pad: If given, title padding in points; the title is then
                also placed at y=0.95 (axes-relative).
            show_axis: If False, the axes are hidden.
            elev_azim: Optional (elevation, azimuth) passed to ``view_init``.
            fig_name: If given, the figure is saved to
                ``<network_path>/figures/<fig_name>``.
            dpi: Resolution used when saving.
            colour_population_unit: Colour somas by population unit (units <= 0
                are light grey); otherwise somas are black.

        Returns:
            Tuple (matplotlib.pyplot module, the 3D axes).
        """

        num_neurons = self.sl.data["nNeurons"]

        # A scalar flag applies to every neuron.
        if isinstance(plot_axon, (bool, np.bool_)):
            plot_axon = np.full((num_neurons, ), plot_axon, dtype=bool)

        if isinstance(plot_dendrite, (bool, np.bool_)):
            plot_dendrite = np.full((num_neurons, ), plot_dendrite, dtype=bool)

        assert len(plot_axon) == len(plot_dendrite) == len(
            self.sl.data["neurons"])

        fig = plt.figure(figsize=(6, 6.5))
        # Figure.gca() no longer accepts a projection keyword (Matplotlib >= 3.6)
        ax = fig.add_subplot(projection='3d')

        if "populationUnit" in self.sl.data and colour_population_unit:
            population_unit = self.sl.data["populationUnit"]
            cmap = plt.get_cmap('tab20', len(set(population_unit)))
            soma_colours = self._population_unit_colours(
                population_unit, cmap, 'lightgrey')
            colour_lookup = lambda x: soma_colours[x]
        else:
            colour_lookup = lambda x: 'black'

        # Plot neurons
        for neuron_info, pa, pd in zip(self.sl.data["neurons"], plot_axon,
                                       plot_dendrite):

            soma_colour = colour_lookup(neuron_info["neuronID"])
            neuron = self.load_neuron(neuron_info)
            neuron.plot_neuron(
                axis=ax,
                plot_axon=pa,
                plot_dendrite=pd,
                soma_colour=soma_colour,
                axon_colour="darksalmon",  #"maroon",
                dend_colour="silver"
            )  # Can also write colours as (0, 0, 0) -- rgb

        # Plot synapses
        if plot_synapses and "synapseCoords" in self.sl.data:
            synapse_coords = self.sl.data["synapseCoords"]
            ax.scatter(synapse_coords[:, 0],
                       synapse_coords[:, 1],
                       synapse_coords[:, 2],
                       color=(0.1, 0.1, 0.1))

            plt.figtext(0.5,
                        0.20,
                        f"{self.sl.data['nSynapses']} synapses",
                        ha="center",
                        fontsize=18)

        if elev_azim:
            ax.view_init(elev_azim[0], elev_azim[1])

        if not show_axis:
            plt.axis("off")

        plt.tight_layout()

        if title is None:
            title = ""

        title_kwargs = dict()
        if title_pad is not None:
            # Passed to this call only, instead of through plt.rcParams, so
            # later figures are not affected.
            title_kwargs["y"] = 0.95  # y is in axes-relative co-ordinates.
            title_kwargs["pad"] = title_pad  # pad is in points...

        plt.title(title, fontsize=18, **title_kwargs)

        self.equal_axis(ax)

        if fig_name is not None:
            fig_path = os.path.join(self.network_path, "figures", fig_name)
            os.makedirs(os.path.dirname(fig_path), exist_ok=True)
            plt.savefig(fig_path, dpi=dpi, bbox_inches="tight")

        plt.ion()
        plt.show()

        return plt, ax

    def load_neuron(self, neuron_info=None, neuron_id=None):
        """Return a placed copy of a neuron's morphology.

        Exactly one of ``neuron_info`` or ``neuron_id`` must be given.
        Morphologies are read once per neuron name and cached in
        ``self.prototype_neurons``. The returned clone is placed at the
        neuron's stored rotation and position.

        Args:
            neuron_info: One entry of ``self.sl.data["neurons"]``.
            neuron_id: Index into ``self.sl.data["neurons"]``.

        Returns:
            The placed NeuronMorphology clone.
        """

        assert (neuron_info is None) + (
            neuron_id is None) == 1, "Specify neuron_info or neuron_id"

        if neuron_id is not None:
            print(f"Using id {neuron_id}")
            neuron_info = self.sl.data["neurons"][neuron_id]

        neuron_name = neuron_info["name"]

        if neuron_name not in self.prototype_neurons:
            self.prototype_neurons[neuron_name] = NeuronMorphology(
                name=neuron_name, swc_filename=neuron_info["morphology"])

        neuron = self.prototype_neurons[neuron_name].clone()
        neuron.place(rotation=neuron_info["rotation"],
                     position=neuron_info["position"])

        return neuron

    def plot_populations(self):
        """Scatter-plot the soma positions, coloured by population unit.

        Neurons in unit <= 0 are drawn grey. Requires "populationUnit" in the
        loaded network data.
        """

        assert "populationUnit" in self.sl.data

        fig = plt.figure(figsize=(6, 6.5))
        ax = fig.add_subplot(projection='3d')

        population_unit = self.sl.data["populationUnit"]
        cmap = plt.get_cmap('tab20', len(set(population_unit)))

        neuron_colours = np.array([
            list(colour) for colour in self._population_unit_colours(
                population_unit, cmap, (0.7, 0.7, 0.7, 1.0))
        ])
        positions = self.sl.data["neuronPositions"]

        ax.scatter(positions[:, 0],
                   positions[:, 1],
                   positions[:, 2],
                   c=neuron_colours,
                   marker='o',
                   s=20)

        self.equal_axis(ax)

    def equal_axis(self, ax):
        """Give all three axes of a 3D plot the same span.

        Each axis keeps its current centre. Its half-width becomes the largest
        half-width of the three, so the scene is not visually stretched.
        """

        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()
        z_min, z_max = ax.get_zlim()

        x_mean = (x_min + x_max) / 2
        y_mean = (y_min + y_max) / 2
        z_mean = (z_min + z_max) / 2

        max_half_width = np.max([x_max - x_min, y_max - y_min, z_max - z_min
                                 ]) / 2

        ax.set_xlim(x_mean - max_half_width, x_mean + max_half_width)
        ax.set_ylim(y_mean - max_half_width, y_mean + max_half_width)
        ax.set_zlim(z_mean - max_half_width, z_mean + max_half_width)
Ejemplo n.º 19
0
class PlotInput(object):
    """Raster plots of the input spike trains stored in a Snudda input file."""

    def __init__(self, input_file, network_path=None):
        """Open the input file and, if available, the network file.

        Args:
            input_file: Input HDF5 file to open, or None to open one later with
                ``load_input``.
            network_path: Directory containing ``network-synapses.hdf5``, used
                to label neurons by name. Defaults to the directory of
                ``input_file``.
        """

        self.input_data = None

        if input_file:
            self.load_input(input_file)

        if not network_path:
            # Without an input file there is no directory to fall back on.
            network_path = os.path.dirname(
                input_file) if input_file is not None else None

        network_file = None
        if network_path is not None:
            network_file = os.path.join(network_path, "network-synapses.hdf5")

        if network_file is not None and os.path.exists(network_file):
            self.network_info = SnuddaLoad(network_file)
        else:
            print(
                "Specify a network_path with a network file, to get neuron type in figure"
            )
            self.network_info = None

    def load_input(self, input_file):
        """Open ``input_file`` read-only, closing any input file opened before."""
        if getattr(self, "input_data", None) is not None:
            self.input_data.close()
        self.input_data = h5py.File(input_file, "r")

    def close(self):
        """Close the input file, if one is open."""
        if self.input_data is not None:
            self.input_data.close()
            self.input_data = None

    def extract_input(self, input_target):
        """Return the spike trains delivered to one neuron.

        Args:
            input_target: Neuron id as a string (the key under "input").

        Returns:
            OrderedDict mapping each input type to its "spikes" data. The dict
            is empty if the neuron has no input.
        """

        data = OrderedDict()
        all_inputs = self.input_data["input"]

        if input_target in all_inputs:
            for input_type, input_info in all_inputs[input_target].items():
                data[input_type] = input_info["spikes"][()]

        return data

    def get_neuron_name(self, neuron_id):
        """Return the name of neuron ``neuron_id``, or "" if no network is loaded."""

        neuron_id = int(neuron_id)

        if self.network_info:
            neuron_name = self.network_info.data["neurons"][neuron_id]["name"]
        else:
            neuron_name = ""

        return neuron_name

    def plot_input(self, neuron_type, num_neurons, fig_size=None):
        """Plot the input to up to ``num_neurons`` random neurons of ``neuron_type``.

        Requires a loaded network file.
        """

        neuron_id = self.network_info.get_cell_id_of_type(
            neuron_type=neuron_type,
            num_neurons=num_neurons,
            random_permute=True)
        target_id = [str(x) for x in np.sort(neuron_id)]

        if len(target_id) == 0:
            print(f"No neurons of type {neuron_type}")
            return

        self.plot_input_to_target(target_id, fig_size=fig_size)

    def plot_input_population_unit(self,
                                   population_unit_id,
                                   num_neurons,
                                   neuron_type=None,
                                   fig_size=None):
        """Plot the input to neurons of one population unit.

        Args:
            population_unit_id: Population unit id (None/0 means "no population").
            num_neurons: If set, plot at most this many randomly chosen members.
            neuron_type: If set, keep only members of this neuron type.
            fig_size: Matplotlib figure size; defaults to (10, 5).
        """

        if not population_unit_id:
            population_unit_id = 0  # 0 = no population

        # Accept numpy integers as well as Python ints.
        assert isinstance(population_unit_id, (int, np.integer))

        neuron_id = self.network_info.get_population_unit_members(
            population_unit_id)

        assert np.array([
            self.network_info.data["populationUnit"][x] == population_unit_id
            for x in neuron_id
        ]).all()

        if neuron_type:
            neuron_id2 = self.network_info.get_cell_id_of_type(neuron_type)
            neuron_id = list(set(neuron_id).intersection(set(neuron_id2)))

        if num_neurons:
            num_neurons = min(num_neurons, len(neuron_id))
            neuron_id = np.random.permutation(neuron_id)[:num_neurons]

        target_id = [str(x) for x in np.sort(neuron_id)]

        if len(target_id) == 0:
            print(f"No neurons with population id {population_unit_id}")
            return

        assert np.array([
            self.network_info.data["populationUnit"][int(x)] ==
            population_unit_id for x in target_id
        ]).all()

        self.plot_input_to_target(target_id, fig_size=fig_size)

    def plot_input_to_target(self, input_target, fig_size=None):
        """Raster-plot every input spike train of the given target neuron(s).

        Each input type of each target becomes one coloured group of rows. The
        group is labelled "<input type>→<neuron name> (<id>)" on the y axis.

        Args:
            input_target: A neuron id, or a list/tuple/array of neuron ids.
            fig_size: Matplotlib figure size; defaults to (10, 5).
        """

        if not fig_size:
            fig_size = (10, 5)

        if not isinstance(input_target, (list, tuple, np.ndarray)):
            input_target = [input_target]

        # Make sure each target is a str
        input_target = [str(x) for x in input_target]
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot's remains.
        colours = plt.get_cmap('tab20', len(input_target) * 2)

        y_pos = 0
        input_ctr = 0
        plt.figure(figsize=fig_size)

        ytick_pos = []
        ytick_label = []

        for it in input_target:

            data = self.extract_input(it)

            for input_type in data:

                y_pos_start = y_pos
                for spike_train in data[input_type]:
                    # Only positive entries are drawn as spikes.
                    idx = np.where(spike_train > 0)[0]
                    plt.scatter(spike_train[idx],
                                y_pos * np.ones((len(idx), )),
                                color=colours(input_ctr),
                                marker='.',
                                s=7)
                    y_pos += 1

                y_pos_avg = (y_pos + y_pos_start) / 2
                ytick_pos.append(y_pos_avg)
                ytick_label.append(
                    f"{input_type}→{self.get_neuron_name(it)} ({it})")

                input_ctr += 1
                y_pos += 5  # gap between input types

            y_pos += 5  # extra gap between target neurons

        # Add yticks
        ax = plt.gca()
        ax.invert_yaxis()
        ax.set_yticks(ytick_pos)
        ax.set_yticklabels(ytick_label)
        ax.set_xlabel("Time (s)")
        plt.ion()
        plt.show()
Ejemplo n.º 20
0
    def test_project(self):
        """Check projection synapses, then serial vs parallel reproducibility.

        First checks that the serially built network contains dSPN->iSPN
        projection synapses and dSPN->dSPN / iSPN->iSPN synapses. Then re-runs
        detect, project and prune with an ipyparallel cluster and checks that
        the resulting synapse matrix is identical to the serial one. The
        cluster is always stopped, even if a check fails.
        """

        import numpy as np
        from snudda.utils.load import SnuddaLoad
        network_file = os.path.join(self.network_path, "network-synapses.hdf5")
        sl = SnuddaLoad(network_file)

        dspn_id_list = sl.get_cell_id_of_type("dSPN")
        ispn_id_list = sl.get_cell_id_of_type("iSPN")

        def count_synapses(pre_id_list, post_id_list):
            """Total synapses from any neuron in pre_id_list to any in post_id_list."""
            total = 0
            for pre_id in pre_id_list:
                for post_id in post_id_list:
                    synapses, _ = sl.find_synapses(pre_id=pre_id,
                                                   post_id=post_id)
                    if synapses is not None:
                        total += synapses.shape[0]
            return total

        # Are there connections dSPN->iSPN
        tot_proj_ctr = count_synapses(dspn_id_list, ispn_id_list)

        with self.subTest(stage="projection_exists"):
            # There should be projection synapses between dSPN and iSPN in this toy example
            self.assertGreater(tot_proj_ctr, 0)

        tot_dd_syn_ctr = count_synapses(dspn_id_list, dspn_id_list)
        tot_ii_syn_ctr = count_synapses(ispn_id_list, ispn_id_list)

        with self.subTest(stage="normal_synapses_exist"):
            # In this toy example neurons are quite sparsely placed, but we should have at least some
            # synapses
            self.assertGreater(tot_dd_syn_ctr, 0)
            self.assertGreater(tot_ii_syn_ctr, 0)

        # We need to run in parallel also to verify we get same result (same random seed)

        serial_synapses = sl.data["synapses"].copy()
        # Close explicitly (instead of relying on del/GC) so the file can be overwritten
        sl.close()

        os.environ["IPYTHONDIR"] = os.path.join(os.path.abspath(os.getcwd()),
                                                ".ipython")
        os.environ["IPYTHON_PROFILE"] = "default"
        os.system(
            "ipcluster start -n 4 --profile=$IPYTHON_PROFILE --ip=127.0.0.1&")
        time.sleep(10)

        rc = None
        try:
            # Run place, detect and prune in parallel by passing rc
            from ipyparallel import Client
            u_file = os.path.join(".ipython", "profile_default", "security",
                                  "ipcontroller-client.json")
            rc = Client(url_file=u_file, timeout=120, debug=False)
            d_view = rc.direct_view(
                targets='all')  # rc[:] # Direct view into clients

            from snudda.detect.detect import SnuddaDetect
            sd = SnuddaDetect(network_path=self.network_path,
                              hyper_voxel_size=100,
                              rc=rc,
                              verbose=True)
            sd.detect()

            from snudda.detect.project import SnuddaProject
            # TODO: Currently SnuddaProject only runs in serial
            s_project = SnuddaProject(network_path=self.network_path)
            s_project.project()
            s_project.write()

            from snudda.detect.prune import SnuddaPrune
            # Prune has different methods for serial and parallel execution, important to test it!
            s_prune = SnuddaPrune(network_path=self.network_path,
                                  rc=rc,
                                  verbose=True)
            s_prune.prune()

            with self.subTest(stage="check-parallel-identical"):
                sl2 = SnuddaLoad(network_file)
                parallel_synapses = sl2.data["synapses"].copy()
                sl2.close()

                # ParameterID, sec_X etc are randomised in hyper voxel, so you need to use same
                # hypervoxel size for reproducability between serial and parallel execution

                # All synapses should be identical regardless of serial or parallel execution path
                # (array_equal also handles differing shapes cleanly)
                self.assertTrue(
                    np.array_equal(serial_synapses, parallel_synapses))
        finally:
            if rc is not None:
                rc.close()
            os.system("ipcluster stop")
Ejemplo n.º 21
0
class SnuddaExportConnectionMatrix(object):
    """Export a network's synapse-count connection matrix to CSV files.

    Two files are written: ``out_file`` holds the connection matrix (row =
    source neuron, column = destination neuron), and ``out_file + "-meta"``
    holds one line per neuron with its index, type, name, x, y, z position and
    morphology file.
    """

    def __init__(self, in_file, out_file, save_sparse=True):
        """Load ``in_file`` and write the export files immediately.

        Args:
            in_file: Network file readable by SnuddaLoad.
            out_file: Path of the connection matrix CSV.
            save_sparse: If True, write one "src,dest,count" row per non-zero
                entry; otherwise write the full dense matrix.
        """

        self.sl = SnuddaLoad(in_file)

        self.outFile = out_file
        self.out_file_meta = out_file + "-meta"

        data = self.sl.data

        con_mat = self.create_con_mat()
        neuron_type = [x["type"] for x in data["neurons"]]
        neuron_name = [x["name"] for x in data["neurons"]]
        morph_file = [data["morph"][nn]["location"] for nn in neuron_name]
        pos = data["neuronPositions"]

        print("Writing " + self.outFile + " (row = src, column=dest)")
        if save_sparse:
            src_idx, dest_idx = np.nonzero(con_mat)
            sparse_data = np.column_stack(
                (src_idx, dest_idx, con_mat[src_idx, dest_idx])).astype(int)
            np.savetxt(self.outFile, sparse_data, delimiter=",", fmt="%d")
        else:
            np.savetxt(self.outFile, con_mat, delimiter=",", fmt="%d")

        print("Writing " + self.out_file_meta)
        with open(self.out_file_meta, "w") as f_out_meta:
            for i, (nt, nn, p, mf) in enumerate(
                    zip(neuron_type, neuron_name, pos, morph_file)):
                f_out_meta.write("%d,%s,%s,%f,%f,%f,%s\n" %
                                 (i, nt, nn, p[0], p[1], p[2], mf))

    ############################################################################

    def create_con_mat(self):
        """Count synapses between every (source, destination) pair of neurons.

        Returns:
            Integer array of shape (nNeurons, nNeurons). Entry [i, j] is the
            number of synapse rows with source i (column 0) and destination j
            (column 1).
        """

        num_neurons = self.sl.data["nNeurons"]

        con_mat = np.zeros((num_neurons, num_neurons), dtype=int)

        for syn_chunk in self.sl.synapse_iterator(data_type="synapses"):
            chunk = np.asarray(syn_chunk)
            if chunk.size == 0:
                continue
            # np.add.at accumulates repeated (src, dest) pairs, whereas
            # con_mat[src, dest] += 1 would count each pair only once per chunk.
            np.add.at(con_mat, (chunk[:, 0], chunk[:, 1]), 1)

        assert np.sum(con_mat) == self.sl.data["nSynapses"], \
            "Synapse numbers in connection matrix does not match"

        return con_mat