Ejemplo n.º 1
0
def parse_prediction(db_credentials, predict_config_path):
    """Join stored predictions with their synapse and skeleton metadata.

    Args:
        db_credentials: credentials file/handle for ``SynisterDb``.
        predict_config_path: path to the predict config .ini file.

    Returns:
        Tuple ``(predicted_synapses, predict_cfg)`` where
        ``predicted_synapses`` maps synapse id to a merged dict of the
        prediction vector, the synapse document and its skeleton document.
    """
    predict_cfg = read_predict_config(predict_config_path)
    db = SynisterDb(db_credentials, predict_cfg["db_name_data"])

    predictions = db.get_predictions(predict_cfg["split_name"],
                                     predict_cfg["experiment"],
                                     predict_cfg["train_number"],
                                     predict_cfg["predict_number"])

    synapses = db.get_synapses()
    skeletons = db.get_skeletons()

    predicted_synapses = {}
    for synapse_id, prediction in predictions.items():
        synapse_doc = synapses[synapse_id]
        skeleton_doc = skeletons[synapse_doc["skeleton_id"]]
        # Merge order matters: synapse and skeleton fields override the
        # bare prediction entry, matching the original dict-union order.
        merged = {"prediction": prediction["prediction"]}
        merged.update(synapse_doc)
        merged.update(skeleton_doc)
        predicted_synapses[synapse_id] = merged

    return predicted_synapses, predict_cfg
Ejemplo n.º 2
0
def monitor_prediction(predict_config, interval=60):
    """Poll the database and print prediction progress until it stalls.

    Reports done/total counts, elapsed time, throughput and an ETA every
    ``interval`` seconds. Stops once the done-count has been unchanged
    over the most recent polls.

    Args:
        predict_config (dict): must provide ``db_credentials``,
            ``db_name_data``, ``split_name``, ``experiment``,
            ``train_number`` and ``predict_number``.
        interval (int): seconds to sleep between polls.
    """
    db = SynisterDb(predict_config["db_credentials"],
                    predict_config["db_name_data"])
    # Baseline count so rate/ETA reflect only progress made this session.
    done_0, _ = db.count_predictions(predict_config["split_name"],
                                     predict_config["experiment"],
                                     predict_config["train_number"],
                                     predict_config["predict_number"])
    start = time.time()
    done_interval = []
    finished = False  # renamed from `exit`, which shadowed the builtin
    while not finished:
        done, total = db.count_predictions(predict_config["split_name"],
                                           predict_config["experiment"],
                                           predict_config["train_number"],
                                           predict_config["predict_number"])

        time_elapsed = time.time() - start
        if done - done_0 > 0:
            eta = time_elapsed / (done - done_0) * (total - done)
            sps = (done - done_0) / time_elapsed
        else:
            # No progress yet: rate and ETA are undefined.
            eta = "NA"
            sps = "NA"
        print("{} from {} predictions done".format(done, total))
        print("Time elapsed {}".format(time_elapsed))
        print("{} samples/second".format(sps))
        print("ETA: {}".format(eta))
        done_interval.append(done)
        if len(done_interval) > 5:
            # NOTE(review): only the last four samples are compared even
            # though six are collected before the first check — confirm
            # whether more samples were meant to participate.
            if done_interval[-1] == done_interval[-2] == done_interval[
                    -3] == done_interval[-4]:
                finished = True

        time.sleep(interval)
Ejemplo n.º 3
0
def prediction_writer(prediction_queue, db_credentials, db_name_data,
                      split_name, experiment, train_number, predict_number):
    """Consume prediction dicts from the queue and persist them to the db.

    Runs forever; each queue item is expected to carry ``prediction``,
    ``x``, ``y`` and ``z`` keys. ``task_done`` is signalled after every
    successful write so producers can ``join`` the queue.
    """
    db = SynisterDb(db_credentials, db_name_data)

    while True:
        item = prediction_queue.get()

        db.write_prediction(split_name,
                            item["prediction"],
                            experiment,
                            train_number,
                            predict_number,
                            item["x"],
                            item["y"],
                            item["z"])

        prediction_queue.task_done()
Ejemplo n.º 4
0
 def setUp(self):
     """Initialize database handle and fixture parameters for the tests."""
     # The credentials .ini lives two directories above this test file.
     # The old code wrapped a single pre-concatenated string in
     # os.path.join (a no-op); build the path from components instead —
     # the resulting POSIX path is identical.
     self.db_credentials = os.path.join(
         os.path.abspath(os.path.dirname(__file__)),
         "..", "..", "db_credentials.ini")
     self.db_name = "synister_v2_refactor"
     self.split_name = "neuron"
     self.points = PointsKey('SYNAPSES')
     self.db = SynisterDb(self.db_credentials, self.db_name)
     # One singleton tuple per neurotransmitter class under test.
     self.neurotransmitters = [
         ('gaba',),
         # ('acetylcholine',),  # TODO: Fix position query for large number of objects. DB hangs for this.
         ('glutamate',),
         ('dopamine',),
         ('octopamine',),
         ('serotonin',),
     ]
Ejemplo n.º 5
0
class SynapseSourceMongo(CsvPointsSource):
    """Points source that reads training synapse locations from MongoDB.

    Filters synapses of one neurotransmitter type down to the 'train'
    part of a named split and exposes their (z, y, x) coordinates via
    the ``CsvPointsSource`` interface.
    """

    def __init__(self,
                 db_credentials,
                 db_name,
                 split_name,
                 synapse_type,
                 points,
                 points_spec=None,
                 scale=None):
        # Remember the query parameters before delegating to the parent;
        # filename=None because points come from the db, not a CSV.
        self.db = SynisterDb(db_credentials, db_name)
        self.db_name = db_name
        self.split_name = split_name
        self.synapse_type = synapse_type
        super(SynapseSourceMongo, self).__init__(filename=None,
                                                 points=points,
                                                 points_spec=points_spec,
                                                 scale=scale)

    def _read_points(self):
        """Load (z, y, x) 'train' locations for this split into self.data."""
        print("Reading split {} from db {}".format(self.split_name,
                                                   self.db_name))
        synapses = self.db.get_synapses(neurotransmitters=self.synapse_type,
                                        split_name=self.split_name)

        train_locations = []
        for synapse in synapses.values():
            # Only the 'train' portion of the split feeds this source.
            if synapse["splits"][self.split_name] != "train":
                continue
            train_locations.append([int(synapse["z"]),
                                    int(synapse["y"]),
                                    int(synapse["x"])])

        self.data = np.array(train_locations)
        self.ndims = 3
Ejemplo n.º 6
0
class SynapseSourceMongoTestCase(unittest.TestCase):
    """Integration test: SynapseSourceMongo must yield exactly the 'train'
    synapses of each neurotransmitter type in the configured split."""

    def setUp(self):
        # The credentials .ini lives two directories above this test file.
        # Build the path from components instead of wrapping a single
        # pre-concatenated string in os.path.join (which was a no-op);
        # the resulting POSIX path is identical.
        self.db_credentials = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "..", "..", "db_credentials.ini")
        self.db_name = "synister_v2_refactor"
        self.split_name = "neuron"
        self.points = PointsKey('SYNAPSES')
        self.db = SynisterDb(self.db_credentials, self.db_name)
        # One singleton tuple per neurotransmitter class under test.
        self.neurotransmitters = [
            ('gaba',),
            # ('acetylcholine',),  # TODO: Fix position query for large number of objects. DB hangs for this.
            ('glutamate',),
            ('dopamine',),
            ('octopamine',),
            ('serotonin',),
        ]

    def runTest(self):
        for synapse_type in self.neurotransmitters:
            source = SynapseSourceMongo(self.db_credentials,
                                        self.db_name,
                                        self.split_name,
                                        synapse_type,
                                        self.points)

            source.setup()
            points = source.data

            print("query pos...")
            synapses = self.db.get_synapses(positions=points)

            print("get skeletons...")
            skeletons = self.db.get_skeletons()

            # assertEqual (not assertTrue(a == b)) gives informative
            # failure messages showing both operands.
            for synapse in synapses.values():
                self.assertEqual(synapse["splits"][self.split_name], "train")
                nt = skeletons[synapse["skeleton_id"]]["nt_known"]
                self.assertEqual(nt, synapse_type)

            # The source must return exactly the 'train' synapses of this
            # split/type — no extras, none missing.
            synapses_in_split = self.db.get_synapses(
                split_name=self.split_name, neurotransmitters=synapse_type)
            all_synapse_ids_in_train = [
                id_ for id_, s in synapses_in_split.items()
                if s["splits"][self.split_name] == "train"]
            self.assertEqual(sorted(all_synapse_ids_in_train),
                             sorted(synapses))
Ejemplo n.º 7
0
    def __init__(self,
                 db_credentials,
                 db_name,
                 split_name,
                 synapse_type,
                 points,
                 points_spec=None,
                 scale=None):
        """Connect to the synister db and configure the parent source.

        Args:
            db_credentials: credentials file/handle for ``SynisterDb``.
            db_name: name of the database holding the synapses.
            split_name: which split to read synapses from.
            synapse_type: neurotransmitter tuple to filter synapses by.
            points: points key forwarded to the parent source.
            points_spec: optional spec forwarded to the parent source.
            scale: optional scale forwarded to the parent source.
        """
        # Store the query parameters, then delegate to the parent with
        # filename=None since points come from the db rather than a CSV.
        self.db = SynisterDb(db_credentials, db_name)
        self.db_name = db_name
        self.split_name = split_name
        self.synapse_type = synapse_type
        super(SynapseSourceMongo, self).__init__(filename=None,
                                                 points=points,
                                                 points_spec=points_spec,
                                                 scale=scale)
Ejemplo n.º 8
0
import os
import threading
import time

from synister.synister_db import SynisterDb

# Resolve paths relative to this script's directory.
self_path = os.path.realpath(os.path.dirname(__file__))

# Load worker and prediction settings from the .ini files next to this script.
worker_config = read_worker_config(os.path.join(self_path,
                                                "worker_config.ini"))
predict_config = read_predict_config(
    os.path.join(self_path, "predict_config.ini"))

# Command line used to launch one prediction worker process.
base_cmd = "python {}".format(os.path.join(self_path, "predict_pipeline.py"))

num_block_workers = worker_config["num_block_workers"]

db = SynisterDb(predict_config["db_credentials"],
                predict_config["db_name_data"])

# Create (or overwrite) the prediction documents for this run so that
# workers can later claim and fill them in.
db.initialize_prediction(
    predict_config["split_name"],
    predict_config["experiment"],
    predict_config["train_number"],
    predict_config["predict_number"],
    overwrite=predict_config["overwrite"],
    validation=predict_config["split_part"] == "validation")


def monitor_prediction(predict_config, interval=60):

    db = SynisterDb(predict_config["db_credentials"],
                    predict_config["db_name_data"])
    done_0, _ = db.count_predictions(predict_config["split_name"],
Ejemplo n.º 9
0
def test(worker_id,
         train_checkpoint,
         db_credentials,
         db_name_data,
         split_name,
         batch_size,
         input_shape,
         fmaps,
         downsample_factors,
         voxel_size,
         synapse_types,
         raw_container,
         raw_dataset,
         experiment,
         train_number,
         predict_number,
         num_cache_workers,
         num_block_workers,
         split_part="test",
         output_classes=None,
         network="VGG",
         fmap_inc=(2, 2, 2, 2),
         n_convolutions=(2, 2, 2, 2),
         network_appendix=None,
         **kwargs):
    """Run model inference over this worker's share of unpredicted synapses.

    Loads the checkpointed network, selects the synapses of ``split_part``
    that have no stored prediction yet, evenly partitions them across
    ``num_block_workers`` and processes this worker's slice in batches of
    ``batch_size``. Results are pushed to writer processes that persist
    them to the database.

    Args:
        worker_id (int): index of this worker in [0, num_block_workers).
        train_checkpoint (str): path to the torch checkpoint to load.
        split_part (str): "test" or "validation".
        network (str): "VGG" or "Efficient".
        (Remaining arguments configure the model, the raw data source and
        the database collections; ``output_classes`` is currently unused.)

    Raises:
        ValueError: if ``split_part`` or ``network`` is unrecognized.
    """
    if split_part not in ["validation", "test"]:
        raise ValueError("'split_part' must be either 'test' or 'validation'")

    print("Network: ", network)
    if network == "VGG":
        model = Vgg3D(input_size=input_shape,
                      fmaps=fmaps,
                      downsample_factors=downsample_factors,
                      fmap_inc=fmap_inc,
                      n_convolutions=n_convolutions)

    elif network == "Efficient":
        if network_appendix is None:
            network_appendix = "b0"

        print(network_appendix)
        model = EfficientNet3D.from_name(
            "efficientnet-{}".format(network_appendix),
            override_params={'num_classes': len(synapse_types)},
            in_channels=1)
    else:
        # Previously an unknown network fell through to an unbound `model`
        # and raised a confusing NameError below; fail explicitly instead.
        raise ValueError("Unknown network '{}', expected 'VGG' or "
                         "'Efficient'".format(network))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    checkpoint = torch.load(train_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    logger.info('Load test sample locations from db {} and split {}...'.format(
        db_name_data, split_name))
    db = SynisterDb(db_credentials, db_name_data)

    logger.info('Initialize prediction writers...')
    prediction_queue = multiprocessing.JoinableQueue()

    for _ in range(num_cache_workers):
        worker = multiprocessing.Process(
            target=prediction_writer,
            args=(prediction_queue, db_credentials, db_name_data, split_name,
                  experiment, train_number, predict_number))
        #worker.daemon = True
        worker.start()

    logger.info('Start prediction...')

    synapses = db.get_synapses(split_name=split_name)
    predict_synapses = db.get_predictions(split_name, experiment, train_number,
                                          predict_number)

    # Only synapses of the requested split part that do not have a stored
    # prediction yet ("is None": resumable after interruption).
    locations = [(int(synapse["z"]), int(synapse["y"]), int(synapse["x"]))
                 for synapse_id, synapse in synapses.items()
                 if synapse["splits"][split_name] == split_part
                 and predict_synapses[synapse_id]["prediction"] is None]

    # Evenly partition the locations across all block workers; this
    # worker handles the slice [loc_start, loc_end).
    loc_start = int(float(worker_id) / num_block_workers * len(locations))
    loc_end = int(float(worker_id + 1) / num_block_workers * len(locations))
    my_locations = locations[loc_start:loc_end]

    for i in range(0, len(my_locations), batch_size):
        logger.info('Predict location {}/{}'.format(i, len(my_locations)))
        locs = my_locations[i:i + batch_size]
        raw, raw_normalized = get_raw(locs, input_shape, voxel_size,
                                      raw_container, raw_dataset)
        if network == "Efficient":
            shape = tuple(raw_normalized.shape)
            # Use the actual leading dimension (shape[0]) rather than
            # batch_size: the final batch may contain fewer locations.
            raw_normalized = raw_normalized.reshape(
                [shape[0], 1, shape[1], shape[2],
                 shape[3]]).astype(np.float32)
        output = predict(raw_normalized, model)

        for k in range(np.shape(output)[0]):
            loc_k = locs[k]
            out_k = output[k, :]

            data_synapse = {
                "prediction": out_k.tolist(),
                "z": loc_k[0],
                "y": loc_k[1],
                "x": loc_k[2]
            }

            prediction_queue.put(data_synapse)

    logger.info("Wait for write...")
    prediction_queue.join()
Ejemplo n.º 10
0
 def setUp(self):
     """Create the database handle used by the test fixtures."""
     # Credentials .ini is expected two directories above this test file.
     # NOTE(review): os.path.join is called with a single pre-concatenated
     # argument, so it is a no-op wrapper around plain string concat —
     # confirm intent before simplifying.
     self.db_credentials = os.path.join(
         os.path.abspath(os.path.dirname(__file__)) +
         "/../../db_credentials.ini")
     self.db = SynisterDb(self.db_credentials, "synister_v2_refactor")