예제 #1
0
파일: test_input.py 프로젝트: wthun/Snudda
    def setUp(self):
        """Build the small test network used by the input-generation tests.

        Steps: write a striatum config (5 dSPN + 5 FS neurons in a cube),
        place the neurons, run synapse detection and prune the result. All
        generated files live under ``networks/network_testing_input``
        relative to this test file's directory, which becomes the working
        directory.

        Sets ``network_path``, ``config_file``, ``position_file``,
        ``save_file``, ``sd`` (the SnuddaDetect instance) and
        ``network_file`` on ``self``.
        """
        # Resolve the test directory to an absolute path *before* changing
        # into it: for a bare filename os.path.dirname(__file__) is "" and
        # os.chdir("") raises, and for a relative path the same dirname would
        # no longer be valid once the working directory has changed.
        test_dir = os.path.dirname(os.path.abspath(__file__))
        os.chdir(test_dir)

        self.network_path = os.path.join("networks", "network_testing_input")
        self.config_file = os.path.join(self.network_path,
                                        "network-config.json")
        self.position_file = os.path.join(self.network_path,
                                          "network-neuron-positions.hdf5")
        self.save_file = os.path.join(self.network_path, "voxels",
                                      "network-putative-synapses.hdf5")

        # Set up the network so we can test input generation
        from snudda.init.init import SnuddaInit
        cell_spec = os.path.join(test_dir, "validation")
        cnc = SnuddaInit(struct_def={},
                         config_file=self.config_file,
                         random_seed=1234)
        cnc.define_striatum(num_dSPN=5,
                            num_iSPN=0,
                            num_FS=5,
                            num_LTS=0,
                            num_ChIN=0,
                            volume_type="cube",
                            neurons_dir=cell_spec)
        cnc.write_json(self.config_file)

        # Place neurons
        from snudda.place.place import SnuddaPlace
        npn = SnuddaPlace(
            config_file=self.config_file,
            log_file=None,
            verbose=True,
            d_view=None,  # TODO: d_view=None runs in serial; add a parallel test
            h5libver="latest")
        npn.parse_config()
        npn.write_data(self.position_file)

        # Detect
        self.sd = SnuddaDetect(config_file=self.config_file,
                               position_file=self.position_file,
                               save_file=self.save_file,
                               rc=None,
                               hyper_voxel_size=120,
                               verbose=True)

        self.sd.detect(restart_detection_flag=True)

        # Prune
        self.network_file = os.path.join(self.network_path,
                                         "network-synapses.hdf5")

        sp = SnuddaPrune(network_path=self.network_path,
                         config_file=None)  # None: use the default config file
        sp.prune(pre_merge_only=False)
예제 #2
0
    def prune_synapses(self, args):
        """Prune the putative synapses of the network in ``self.network_path``.

        Parameters
        ----------
        args : argparse-style namespace
            Attributes read: ``randomseed``, ``parallel``, ``merge_only``,
            ``h5legacy``, ``config_file`` and ``verbose``.

        The log is written to ``<network_path>/log/logFile-synapse-pruning.txt``.
        Parallel workers are stopped and the log file is closed even when
        pruning raises.
        """
        print("Prune synapses")
        print("Network path: " + str(self.network_path))

        from snudda.detect.prune import SnuddaPrune

        log_filename = os.path.join(self.network_path, "log",
                                    "logFile-synapse-pruning.txt")

        random_seed = args.randomseed

        self.setup_log_file(log_filename)  # sets self.logfile

        try:
            if args.parallel:
                self.setup_parallel()  # sets self.d_view and self.lb_view

            # Optionally set this
            scratch_path = None

            pre_merge_only = bool(args.merge_only)
            print(f"preMergeOnly : {pre_merge_only}")

            # --h5legacy requests the "earliest" HDF5 library version,
            # otherwise "latest" is used (default).
            h5libver = "earliest" if args.h5legacy else "latest"

            sp = SnuddaPrune(network_path=self.network_path,
                             logfile=self.logfile,
                             logfile_name=log_filename,
                             config_file=args.config_file,
                             d_view=self.d_view,
                             lb_view=self.lb_view,
                             scratch_path=scratch_path,
                             h5libver=h5libver,
                             random_seed=random_seed,
                             verbose=args.verbose)

            sp.prune(pre_merge_only=pre_merge_only)
        finally:
            # Always release workers and the log file, even if pruning failed;
            # the nested try makes sure the log is closed if stopping fails.
            try:
                self.stop_parallel()
            finally:
                self.close_log_file()
예제 #3
0
    def prune_network(self, pruning_config=None, fig_name=None, title=None, verbose=False, plot_network=True,
                      random_seed=None, n_repeats=None):
        """Prune the network in ``self.network_path`` and optionally plot it.

        Parameters
        ----------
        pruning_config : str or None
            Pruning config file. If the path does not exist as given it is
            looked up relative to ``self.network_path``. ``None`` lets
            SnuddaPrune use its default config file.
        fig_name : str or None
            Figure file name passed to the plot; when ``n_repeats > 1`` the
            annotated figure is saved to this name again (only if given).
        title : str or None
            Plot title.
        verbose : bool
            Passed to SnuddaPrune.
        plot_network : bool
            If False, skip plotting entirely.
        random_seed : int or None
            Passed to SnuddaPrune.
        n_repeats : int or None
            Number of pruning repeats used for the mean/std annotation;
            defaults to ``self.n_repeats``.

        Returns
        -------
        tuple
            ``(n_synapses, n_gap_junctions)``: the number of rows in
            ``network/synapses`` and ``network/gapJunctions`` of the pruned
            output.
        """
        if n_repeats is None:
            n_repeats = self.n_repeats

        pruned_output = os.path.join(self.network_path, "network-synapses.hdf5")

        if pruning_config is not None and not os.path.exists(pruning_config):
            pruning_config = os.path.join(self.network_path, pruning_config)

        # keep_files=True keeps the temporary files
        sp = SnuddaPrune(network_path=self.network_path, config_file=pruning_config,
                         verbose=verbose, keep_files=True, random_seed=random_seed)
        sp.prune()

        n_synapses = sp.out_file["network/synapses"].shape[0]
        n_gap_junctions = sp.out_file["network/gapJunctions"].shape[0]

        # Drop our reference to the pruner (which holds out_file) before
        # PlotNetwork opens the same output file.
        del sp

        plot_axon = True
        plot_dendrite = True

        if plot_network:
            pn = PlotNetwork(pruned_output)
            plt, _ax = pn.plot(fig_name=fig_name, show_axis=False,
                               plot_axon=plot_axon, plot_dendrite=plot_dendrite,
                               title=title, title_pad=-14,
                               elev_azim=(90, 0))

            if n_repeats > 1:
                n_syn_mean, n_syn_std, _, _ = self.gather_pruning_statistics(pruning_config=pruning_config,
                                                                              n_repeats=n_repeats)
                # Raw f-string: "\pm" is a LaTeX command, not a Python escape sequence.
                plt.figtext(0.5, 0.15, rf"(${n_syn_mean:.1f} \pm {n_syn_std:.1f}$)", ha="center", fontsize=16)
                if fig_name is not None:
                    plt.savefig(fig_name, dpi=300, bbox_inches='tight')

        return n_synapses, n_gap_junctions
예제 #4
0
    def prune_network(self, pruning_config=None, fig_name=None, title=None):
        """Prune the network in ``self.network_path`` and plot the result.

        Parameters
        ----------
        pruning_config : str or None
            Pruning config file. If the path does not exist as given it is
            looked up relative to ``self.network_path``. ``None`` lets
            SnuddaPrune use its default config file.
        fig_name : str or None
            Figure file name passed to ``PlotNetwork.plot``.
        title : str or None
            Plot title.
        """
        pruned_output = os.path.join(self.network_path, "network-synapses.hdf5")

        if pruning_config is not None and not os.path.exists(pruning_config):
            pruning_config = os.path.join(self.network_path, pruning_config)

        sp = SnuddaPrune(network_path=self.network_path, config_file=pruning_config)
        sp.prune(pre_merge_only=False)
        # Drop the pruner before PlotNetwork opens its output file.
        del sp

        pn = PlotNetwork(pruned_output)
        pn.plot(fig_name=fig_name, show_axis=False,
                plot_axon=True, plot_dendrite=True,
                title=title, title_pad=-14,
                elev_azim=(90, 0))
예제 #5
0
    def test_prune(self):
        """Prune the test network with several configs and check the output.

        The first sub-tests prune with the default config and check synapse
        and gap-junction counts, sort order and the SnuddaLoad helpers. Each
        later sub-test prunes with a dedicated ``network-config-test-N.json``
        and checks the effect of one pruning parameter (f1, softmax, mu2, a3,
        distance dependence) on synapses or gap junctions.
        """
        pruned_output = os.path.join(self.network_path,
                                     "network-synapses.hdf5")

        # Expected counts for the unpruned test network
        n_syn_per_channel = 20 * 8 + 10 * 2  # synapses per channel model (AMPA or GABA)
        n_gap_junctions = 4 * 4 * 4

        def prune_and_load(config_name=None, load_verbose=False):
            """Prune with config_name (relative to network_path; None = default) and load the result."""
            config_file = (None if config_name is None
                           else os.path.join(self.network_path, config_name))
            sp = SnuddaPrune(network_path=self.network_path,
                             config_file=config_file,
                             verbose=True,
                             keep_files=True)
            sp.prune()
            del sp  # release the pruner before reading its output file
            load_kwargs = {"verbose": True} if load_verbose else {}
            return SnuddaLoad(pruned_output, **load_kwargs)

        def assert_between(value, low, high):
            """Assert low < value < high, reporting the value on failure."""
            self.assertGreater(value, low)
            self.assertLess(value, high)

        def assert_sorted_dest_source(pairs):
            """Assert rows (column 0 = source, column 1 = dest) are sorted by dest, then source."""
            order = pairs[:, 1] * len(self.sd.neurons) + pairs[:, 0]
            self.assertTrue((np.diff(order) >= 0).all())

        with self.subTest(stage="No-pruning"):
            sl = prune_and_load()
            # TODO: Call a plot function to plot entire network with synapses and all

            # AMPA + GABA, hence two channel models per connection
            self.assertEqual(sl.data["nSynapses"], n_syn_per_channel * 2)

            # All synapses must be in order. The synapse sort order is
            # destID, sourceID, synapse type (channel model id, column 6).
            syn = sl.data["synapses"][:sl.data["nSynapses"], :]
            syn_order = (syn[:, 1] * len(self.sd.neurons) + syn[:, 0]
                         ) * 12 + syn[:, 6]  # The 12 is maxChannelModelID
            self.assertTrue((np.diff(syn_order) >= 0).all())

            # Channel model id is dynamically allocated, starting from 10
            # (GJ have ID 3). Check the number of each type.
            channel_ids = sl.data["synapses"][:, 6]
            self.assertEqual(np.sum(channel_ids == 10), n_syn_per_channel)
            self.assertEqual(np.sum(channel_ids == 11), n_syn_per_channel)

            self.assertEqual(sl.data["nGapJunctions"], n_gap_junctions)
            assert_sorted_dest_source(
                sl.data["gapJunctions"][:sl.data["nGapJunctions"], :2])

        with self.subTest(stage="load-testing"):
            sl = SnuddaLoad(pruned_output, verbose=True)

            # Try and load a neuron
            n = sl.load_neuron(0)
            self.assertIsInstance(n, NeuronMorphology)

            syn_ctr = sum(s.shape[0] for s in sl.synapse_iterator(chunk_size=50))
            self.assertEqual(syn_ctr, sl.data["nSynapses"])

            gj_ctr = sum(gj.shape[0] for gj in sl.gap_junction_iterator(chunk_size=50))
            self.assertEqual(gj_ctr, sl.data["nGapJunctions"])

            syn, _syn_coords = sl.find_synapses(pre_id=14)
            self.assertTrue((syn[:, 0] == 14).all())
            self.assertEqual(syn.shape[0], 40)

            syn, _syn_coords = sl.find_synapses(post_id=3)
            self.assertTrue((syn[:, 1] == 3).all())
            self.assertEqual(syn.shape[0], 36)

            cell_id_perm = sl.get_cell_id_of_type("ballanddoublestick",
                                                  random_permute=True,
                                                  num_neurons=28)
            cell_id = sl.get_cell_id_of_type("ballanddoublestick",
                                             random_permute=False)

            self.assertEqual(len(cell_id_perm), 28)
            self.assertEqual(len(cell_id), 28)

            for cid in cell_id_perm:
                self.assertIn(cid, cell_id)

        # It is important that the merge file has synapses sorted with
        # dest_id, source_id as sort order, since pruning assumes this to
        # quickly find all synapses on a post-synaptic cell.
        # TODO: Also include the ChannelModelID in sorting check
        with self.subTest("Checking-merge-file-sorted"):

            for mf in [
                    "temp/synapses-for-neurons-0-to-28-MERGE-ME.hdf5",
                    "temp/gapJunctions-for-neurons-0-to-28-MERGE-ME.hdf5",
                    "network-synapses.hdf5"
            ]:
                merge_file = os.path.join(self.network_path, mf)

                sl = SnuddaLoad(merge_file, verbose=True)
                if "synapses" in sl.data:
                    assert_sorted_dest_source(
                        sl.data["synapses"][:sl.data["nSynapses"], :2])

                if "gapJunctions" in sl.data:
                    assert_sorted_dest_source(
                        sl.data["gapJunctions"][:sl.data["nGapJunctions"], :2])

        with self.subTest("synapse-f1"):
            # f1=0.5 in config should randomly remove ~50% of GABA synapses;
            # for AMPA f1=0.9 is used.
            sl = prune_and_load("network-config-test-1.json", load_verbose=True)
            conn_dist = sl.data["connectivityDistributions"][
                "ballanddoublestick",
                "ballanddoublestick"]
            gaba_id = conn_dist["GABA"]["channelModelID"]
            ampa_id = conn_dist["AMPA"]["channelModelID"]

            channel_ids = sl.data["synapses"][:, 6]
            n_gaba = np.sum(channel_ids == gaba_id)
            n_ampa = np.sum(channel_ids == ampa_id)

            assert_between(n_gaba, n_syn_per_channel * 0.5 - 10, n_syn_per_channel * 0.5 + 10)
            assert_between(n_ampa, n_syn_per_channel * 0.9 - 10, n_syn_per_channel * 0.9 + 10)

        with self.subTest("synapse-softmax"):
            # Only GABA synapses in this config; softmax reduces the number of synapses
            sl = prune_and_load("network-config-test-2.json")
            self.assertLess(sl.data["nSynapses"], n_syn_per_channel)

        with self.subTest("synapse-mu2"):
            # With mu2, having 2 synapses means 50% chance to keep them;
            # having 1 will likely get it removed
            sl = prune_and_load("network-config-test-3.json")
            assert_between(sl.data["nSynapses"], 20 * 8 * 0.5 - 10, 20 * 8 * 0.5 + 10)

        with self.subTest("synapse-a3"):
            # a3=0.6 means 40% chance to remove all synapses between a pair
            sl = prune_and_load("network-config-test-4.json")
            assert_between(sl.data["nSynapses"],
                           n_syn_per_channel * 0.6 - 14,
                           n_syn_per_channel * 0.6 + 14)

        with self.subTest("synapse-distance-dependent-pruning"):
            # "1*(d >= 100e-6)" means we remove all synapses closer than 100 micrometers
            sl = prune_and_load("network-config-test-5.json")
            self.assertEqual(sl.data["nSynapses"], 20 * 6)
            self.assertTrue(
                (sl.data["synapses"][:, 8] >=
                 100).all())  # Column 8 -- distance to soma in micrometers

        # TODO: Need to do same test for Gap Junctions also -- but should be same results, since same codebase
        with self.subTest("gap-junction-f1"):
            # f1=0.7 in config should randomly remove ~30% of gap junctions
            sl = prune_and_load("network-config-test-6.json")
            assert_between(sl.data["nGapJunctions"],
                           n_gap_junctions * 0.7 - 10,
                           n_gap_junctions * 0.7 + 10)

        with self.subTest("gap-junction-softmax"):
            # Softmax reduces the number of gap junctions
            sl = prune_and_load("network-config-test-7.json")
            self.assertLess(sl.data["nGapJunctions"], 16 * 2 + 10)

        with self.subTest("gap-junction-mu2"):
            # With mu2, having 4 synapses means 50% chance to keep them;
            # having 1 will likely get it removed
            sl = prune_and_load("network-config-test-8.json")
            assert_between(sl.data["nGapJunctions"],
                           n_gap_junctions * 0.5 - 10,
                           n_gap_junctions * 0.5 + 10)

        with self.subTest("gap-junction-a3"):
            # a3=0.7 means 30% chance to remove all gap junctions between a pair
            sl = prune_and_load("network-config-test-9.json", load_verbose=True)
            assert_between(sl.data["nGapJunctions"],
                           n_gap_junctions * 0.7 - 10,
                           n_gap_junctions * 0.7 + 10)

        # Distance dependent pruning is currently not implemented for gap
        # junctions; enable this check once it is.
        gap_junction_distance_pruning_implemented = False
        if gap_junction_distance_pruning_implemented:
            with self.subTest("gap-junction-distance-dependent-pruning"):
                # "1*(d <= 120e-6)" means we remove all gap junctions further away than 120 micrometers
                sl = prune_and_load("network-config-test-10.json", load_verbose=True)
                self.assertEqual(sl.data["nGapJunctions"], 2 * 4 * 4)
                self.assertTrue(
                    (sl.data["gapJunctions"][:, 8] <=
                     120).all())  # Column 8 -- distance to soma in micrometers