Example #1
    def get_local_segmentation(self, roi: daisy.Roi, threshold: float):
        # open fragments
        fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

        # open RAG DB
        rag_provider = MongoDbRagProvider(
            self.fragments_db,
            host=self.fragments_host,
            mode="r",
            edges_collection=self.edges_collection,
        )

        segmentation = fragments[roi]
        segmentation.materialize()
        ids = [int(i) for i in np.unique(segmentation.data)]
        rag = rag_provider.read_rag(ids)

        if len(rag.nodes()) == 0:
            raise Exception('RAG is empty')

        components = rag.get_connected_components(threshold)

        values_map = np.array(
            [[fragment, i]
             for i in range(1, len(components) + 1)
             for fragment in components[i - 1]],
            dtype=np.uint64,
        )
        old_values = values_map[:, 0]
        new_values = values_map[:, 1]
        replace_values(segmentation.data, old_values, new_values, inplace=True)

        return segmentation
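A minimal standalone sketch of the relabelling idiom above, assuming replace_values comes from funlib.segment.arrays (the snippets on this page do not show their imports; the positional arguments and the inplace option below match the call shown in this example):

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

# toy fragment array with three fragment IDs
fragments = np.array([[10, 10, 20],
                      [20, 30, 30]], dtype=np.uint64)

# connected components as they would come out of rag.get_connected_components
components = [[10, 20], [30]]

# every fragment in components[i - 1] is mapped to the new label i
values_map = np.array(
    [[fragment, i]
     for i in range(1, len(components) + 1)
     for fragment in components[i - 1]],
    dtype=np.uint64)

replace_values(fragments, values_map[:, 0], values_map[:, 1], inplace=True)
# fragments is now [[1, 1, 1], [1, 2, 2]]: 10 and 20 became segment 1, 30 became segment 2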
Example #2
def get_segmentation(
        fragments,
        fragments_file,
        lut_fragment_segment,
        edges_collection,
        threshold):

    logging.info(
        "Loading fragment-segment lookup table for threshold %s...",
        threshold)
    fragment_segment_lut_dir = os.path.join(
        fragments_file,
        lut_fragment_segment)

    fragment_segment_lut_file = os.path.join(
        fragment_segment_lut_dir,
        'seg_%s_%d.npz' % (edges_collection, int(threshold * 100)))

    fragment_segment_lut = np.load(
        fragment_segment_lut_file)['fragment_segment_lut']

    assert fragment_segment_lut.dtype == np.uint64

    # fragments = fragments.to_ndarray(block.write_roi)

    logging.info("Relabeling fragment ids with segment ids...")

    segment_ids = replace_values(
        fragments, fragment_segment_lut[0], fragment_segment_lut[1])

    return segment_ids
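The indexing lut[0] / lut[1] implies that the stored 'fragment_segment_lut' is a 2 x N uint64 array: row 0 holds fragment IDs, row 1 the segment ID each fragment maps to. A short sketch of writing and reading such a file; the file name is made up but follows the seg_<edges_collection>_<threshold*100>.npz pattern used above:

import numpy as np

# hypothetical LUT: fragments 7, 8 and 9 merge into segment 7, fragment 12 stays on its own
fragment_segment_lut = np.array(
    [[7, 8, 9, 12],
     [7, 7, 7, 12]], dtype=np.uint64)

np.savez('seg_example_50.npz', fragment_segment_lut=fragment_segment_lut)  # hypothetical file name

lut = np.load('seg_example_50.npz')['fragment_segment_lut']
assert lut.dtype == np.uint64 and lut.shape[0] == 2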
Example #3
def get_segmentation(fragments, fragments_file, edges_collection, threshold,
                     run_type):

    logging.info(f"Loading fragment - segment lookup table for threshold \
            {threshold}...")

    fragment_segment_lut_dir = os.path.join(fragments_file, 'luts',
                                            'fragment_segment')

    if run_type:
        logging.info(f"Run type set, evaluating on {run_type} dataset")

        fragment_segment_lut_dir = os.path.join(fragment_segment_lut_dir,
                                                run_type)

    fragment_segment_lut_file = os.path.join(
        fragment_segment_lut_dir,
        f'seg_{edges_collection}_{int(threshold*100)}.npz')

    fragment_segment_lut = np.load(
        fragment_segment_lut_file)['fragment_segment_lut']

    assert fragment_segment_lut.dtype == np.uint64

    logging.info("Relabeling fragment ids with segment ids...")

    segment_ids = replace_values(fragments, fragment_segment_lut[0],
                                 fragment_segment_lut[1])

    return segment_ids
Example #4
    def get_site_segment_ids(self, threshold):

        # get fragment-segment LUT
        logging.info("Reading fragment-segment LUT...")
        start = time.time()

        fragment_segment_lut_dir = os.path.join(self.fragments_file,
                                                'luts/fragment_segment')

        if self.run_type:
            logging.info(f"Using lookup tables for {self.run_type} data")
            fragment_segment_lut_dir = os.path.join(fragment_segment_lut_dir,
                                                    self.run_type)

        logging.info("Reading fragment segment luts from: "
                     f"{fragment_segment_lut_dir}")

        fragment_segment_lut_file = os.path.join(
            fragment_segment_lut_dir,
            'seg_%s_%d.npz' % (self.edges_collection, int(threshold * 100)))

        fragment_segment_lut = np.load(
            fragment_segment_lut_file)['fragment_segment_lut']

        assert fragment_segment_lut.dtype == np.uint64

        # get the segment ID for each site
        logging.info("Mapping sites to segments...")

        site_mask = np.isin(fragment_segment_lut[0], self.site_fragment_ids)
        site_segment_ids = replace_values(self.site_fragment_ids,
                                          fragment_segment_lut[0][site_mask],
                                          fragment_segment_lut[1][site_mask])

        return site_segment_ids, fragment_segment_lut
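The np.isin mask above restricts the lookup table to columns whose fragment ID actually occurs among the site fragments, so replace_values only has to search a small set of old values. A toy illustration with made-up IDs, assuming the funlib.segment.arrays import used throughout these notes:

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

fragment_segment_lut = np.array(
    [[5, 6, 7, 8],      # fragment IDs
     [1, 1, 2, 2]],     # segment IDs
    dtype=np.uint64)
site_fragment_ids = np.array([6, 8], dtype=np.uint64)

site_mask = np.isin(fragment_segment_lut[0], site_fragment_ids)
site_segment_ids = replace_values(site_fragment_ids,
                                  fragment_segment_lut[0][site_mask],
                                  fragment_segment_lut[1][site_mask])
# site_segment_ids == [1, 2]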
Example #5
    def __relabel(self, array, components, component_labels):

        old_values = []
        new_values = []

        for component, label in zip(components, component_labels):
            for c in component:
                old_values.append(c)
                new_values.append(label)

        array[:] = replace_values(array, old_values, new_values)
Example #6
def segment_in_block(block, fragments_file, segmentation, fragments, lut):

    logging.info("Copying fragments to memory...")

    # load fragments
    fragments = fragments.to_ndarray(block.write_roi)

    # replace values, write to empty array
    relabelled = np.zeros_like(fragments)
    relabelled = replace_values(fragments,
                                lut[0],
                                lut[1],
                                out_array=relabelled)

    segmentation[block.write_roi] = relabelled
Example #7
def segment_in_block(block, fragments_file, segmentation, fragments, lut):

    logging.info("Copying fragments to memory...")
    start = time.time()
    fragments = fragments.to_ndarray(block.write_roi)
    logging.info("%.3fs" % (time.time() - start))

    # get segments

    num_segments = len(np.unique(lut[1]))
    logging.info("Relabelling fragments to %d segments", num_segments)
    start = time.time()
    relabelled = replace_values(fragments, lut[0], lut[1])
    logging.info("%.3fs" % (time.time() - start))

    segmentation[block.write_roi] = relabelled
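The two segment_in_block variants above differ only in how the relabelled block is produced: one preallocates an output array and passes it as out_array, the other takes the array that replace_values returns. A toy comparison under the same funlib.segment.arrays assumption:

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

fragments = np.array([3, 3, 5, 9], dtype=np.uint64)
lut = np.array([[3, 5, 9],
                [1, 1, 2]], dtype=np.uint64)

# variant 1: write into a preallocated array
out = np.zeros_like(fragments)
out = replace_values(fragments, lut[0], lut[1], out_array=out)

# variant 2: let replace_values allocate the result
relabelled = replace_values(fragments, lut[0], lut[1])

assert np.array_equal(out, relabelled)  # both are [1, 1, 1, 2]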
Example #8
def watershed_in_block(
        affs,
        block,
        context,
        rag_provider,
        fragments_out,
        num_voxels_in_block,
        mask=None,
        fragments_in_xy=False,
        epsilon_agglomerate=0.0,
        filter_fragments=0.0,
        min_seed_distance=10,
        replace_sections=None):
    '''

    Args:

        filter_fragments (float):

            Filter fragments that have an average affinity lower than this
            value.

        min_seed_distance (int):

            Controls distance between seeds in the initial watershed. Reducing
            this value improves downsampled segmentation.
    '''

    total_roi = affs.roi

    logger.debug("reading affs from %s", block.read_roi)

    affs = affs.intersect(block.read_roi)
    affs.materialize()

    if affs.dtype == np.uint8:
        logger.info("Assuming affinities are in [0,255]")
        max_affinity_value = 255.0
        affs.data = affs.data.astype(np.float32)
    else:
        max_affinity_value = 1.0

    if mask is not None:

        logger.debug("reading mask from %s", block.read_roi)
        mask_data = get_mask_data_in_roi(mask, affs.roi, affs.voxel_size)
        logger.debug("masking affinities")
        affs.data *= mask_data

    # extract fragments
    fragments_data, _ = watershed_from_affinities(
        affs.data,
        max_affinity_value,
        fragments_in_xy=fragments_in_xy,
        min_seed_distance=min_seed_distance)

    if mask is not None:
        fragments_data *= mask_data.astype(np.uint64)

    if filter_fragments > 0:

        if fragments_in_xy:
            average_affs = np.mean(affs.data[0:2]/max_affinity_value, axis=0)
        else:
            average_affs = np.mean(affs.data/max_affinity_value, axis=0)

        filtered_fragments = []

        fragment_ids = np.unique(fragments_data)

        for fragment, mean in zip(
                fragment_ids,
                measurements.mean(
                    average_affs,
                    fragments_data,
                    fragment_ids)):
            if mean < filter_fragments:
                filtered_fragments.append(fragment)

        filtered_fragments = np.array(
            filtered_fragments,
            dtype=fragments_data.dtype)
        replace = np.zeros_like(filtered_fragments)
        replace_values(fragments_data, filtered_fragments, replace, inplace=True)

    if epsilon_agglomerate > 0:

        logger.info(
            "Performing initial fragment agglomeration until %f",
            epsilon_agglomerate)

        generator = waterz.agglomerate(
                affs=affs.data/max_affinity_value,
                thresholds=[epsilon_agglomerate],
                fragments=fragments_data,
                scoring_function='OneMinus<HistogramQuantileAffinity<RegionGraphType, 25, ScoreValue, 256, false>>',
                discretize_queue=256,
                return_merge_history=False,
                return_region_graph=False)
        fragments_data[:] = next(generator)

        # cleanup generator
        for _ in generator:
            pass

    if replace_sections:

        logger.info("Replacing sections...")

        block_begin = block.write_roi.get_begin()
        shape = block.write_roi.get_shape()

        z_context = context[0]/affs.voxel_size[0]
        logger.info("Z context: %i",z_context)

        mapping = {}

        voxel_offset = block_begin[0]/affs.voxel_size[0]

        for i,j in zip(
                range(fragments_data.shape[0]),
                range(shape[0])):
            mapping[i] = i
            mapping[j] = int(voxel_offset + i) \
                    if block_begin[0] == total_roi.get_begin()[0] \
                    else int(voxel_offset + (i - z_context))

        logger.info('Mapping: %s', mapping)

        replace = [k for k,v in mapping.items() if v in replace_sections]

        for r in replace:
            logger.info("Replacing mapped section %i with zero", r)
            fragments_data[r] = 0

    # TODO: add key-value replacement option

    fragments = daisy.Array(fragments_data, affs.roi, affs.voxel_size)

    # crop fragments to write_roi
    fragments = fragments[block.write_roi]
    fragments.materialize()
    max_id = fragments.data.max()

    # ensure we don't have IDs larger than the number of voxels (that would
    # break uniqueness of IDs below)
    if max_id > num_voxels_in_block:
        logger.warning(
            "fragments in %s have max ID %d, relabelling...",
            block.write_roi, max_id)
        fragments.data, max_id = relabel(fragments.data)

        assert max_id < num_voxels_in_block

    # ensure unique IDs
    id_bump = block.block_id[1]*num_voxels_in_block
    logger.debug("bumping fragment IDs by %i", id_bump)
    fragments.data[fragments.data>0] += id_bump
    fragment_ids = range(id_bump + 1, id_bump + 1 + int(max_id))

    # store fragments
    logger.debug("writing fragments to %s", block.write_roi)
    fragments_out[block.write_roi] = fragments

    # following only makes a difference if fragments were found
    if max_id == 0:
        return

    # get fragment centers
    fragment_centers = {
        fragment: block.write_roi.get_offset() + affs.voxel_size*daisy.Coordinate(center)
        for fragment, center in zip(
            fragment_ids,
            measurements.center_of_mass(fragments.data, fragments.data, fragment_ids))
        if not np.isnan(center[0])
    }

    # store nodes
    rag = rag_provider[block.write_roi]
    rag.add_nodes_from([
        (node, {
            'center_z': c[0],
            'center_y': c[1],
            'center_x': c[2]
            }
        )
        for node, c in fragment_centers.items()
    ])
    rag.write_nodes(block.write_roi)
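The ID bump near the end of watershed_in_block is what keeps fragment IDs unique across blocks: after relabel() the IDs lie in 1..max_id <= num_voxels_in_block, so shifting every non-zero ID by block.block_id[1] * num_voxels_in_block gives each block its own non-overlapping ID range. A small numeric sketch with made-up numbers:

import numpy as np

num_voxels_in_block = 1000   # e.g. a 10 x 10 x 10 write ROI
block_index = 3              # stands in for block.block_id[1]

# fragments already relabelled to 1..max_id within the block
fragments = np.array([0, 1, 1, 2, 4], dtype=np.uint64)
max_id = 4

id_bump = block_index * num_voxels_in_block
fragments[fragments > 0] += id_bump
# fragments is now [0, 3001, 3001, 3002, 3004]; background stays 0
fragment_ids = range(id_bump + 1, id_bump + 1 + max_id)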
Example #9
    def parse_rag_excerpt(self, nodes_list, edges_list):

        # TODO parametrize the used names
        id_field = 'id'
        node1_field = 'u'
        node2_field = 'v'
        merge_score_field = 'merge_score'
        gt_merge_score_field = 'gt_merge_score'
        merge_labeled_field = 'merge_labeled'

        # TODO remove duplicate code, this is also used in hemibrain_graph
        def to_np_arrays(inp):
            d = {}
            for i in inp:
                for k, v in i.items():
                    d.setdefault(k, []).append(v)
            for k, v in d.items():
                d[k] = np.array(v)
            return d

        node_attrs = to_np_arrays(nodes_list)
        # TODO maybe port to numpy, but generally fast
        # Drop edges for which one of the incident nodes is not in the
        # extracted node set
        start = time.time()
        for e in reversed(edges_list):
            if e[node1_field] not in node_attrs[id_field] or e[
                    node2_field] not in node_attrs[id_field]:
                edges_list.remove(e)
        logger.debug(f'drop edges at the border in {time.time() - start}s')

        # If all edges were removed in the step above, raise a ValueError
        # that is caught later on
        if len(edges_list) == 0:
            raise ValueError(
                'Removed all edges in ROI, as each had an incident node outside of the ROI')

        edges_attrs = to_np_arrays(edges_list)

        node_ids_np = node_attrs[id_field].astype(np.int64)
        node_ids = torch.tensor(node_ids_np, dtype=torch.long)

        # By not operating inplace and providing out_array, we always use
        # the C++ implementation of replace_values

        logger.debug(
            f'before: interval {node_ids_np.max() - node_ids_np.min()}, min id {node_ids_np.min()}, max id {node_ids_np.max()}, shape {node_ids_np.shape}'
        )
        start = time.time()
        edges_node1 = np.zeros_like(edges_attrs[node1_field], dtype=np.int64)
        edges_node1 = replace_values(
            in_array=edges_attrs[node1_field].astype(np.int64),
            old_values=node_ids_np,
            new_values=np.arange(len(node_attrs[id_field]), dtype=np.int64),
            inplace=False,
            out_array=edges_node1)
        edges_attrs[node1_field] = edges_node1
        logger.debug(
            f'remapping {len(edges_attrs[node1_field])} edges (u) in {time.time() - start} s'
        )
        logger.debug(
            f'edges after: min id {edges_attrs[node1_field].min()}, max id {edges_attrs[node1_field].max()}'
        )

        start = time.time()
        edges_node2 = np.zeros_like(edges_attrs[node2_field], dtype=np.int64)
        edges_node2 = replace_values(
            in_array=edges_attrs[node2_field].astype(np.int64),
            old_values=node_ids_np,
            new_values=np.arange(len(node_attrs[id_field]), dtype=np.int64),
            inplace=False,
            out_array=edges_node2)
        edges_attrs[node2_field] = edges_node2
        logger.debug(
            f'remapping {len(edges_attrs[node2_field])} edges (v) in {time.time() - start} s'
        )
        logger.debug(
            f'edges after: min id {edges_attrs[node2_field].min()}, max id {edges_attrs[node2_field].max()}'
        )

        # TODO I could potentially avoid transposing twice
        # edge index requires dimensionality of (2,e)
        # pyg works with directed edges, duplicate each edge here
        edge_index_undir = np.array(
            [edges_attrs[node1_field], edges_attrs[node2_field]]).transpose()
        edge_index_dir = np.repeat(edge_index_undir, 2, axis=0)
        edge_index_dir[1::2, :] = np.flip(edge_index_dir[1::2, :], axis=1)
        edge_index = torch.tensor(edge_index_dir.astype(np.int64).transpose(),
                                  dtype=torch.long)

        edge_attr_undir = np.expand_dims(edges_attrs[merge_score_field],
                                         axis=1)
        edge_attr_dir = np.repeat(edge_attr_undir, 2, axis=0)
        edge_attr = torch.tensor(edge_attr_dir, dtype=torch.float)

        pos = torch.transpose(
            torch.tensor(
                [node_attrs['center_z'], node_attrs['center_y'],
                 node_attrs['center_x']],
                dtype=torch.float),
            dim0=0, dim1=1)

        # TODO node features go here
        x = torch.ones(len(node_attrs[id_field]), 1, dtype=torch.float)

        # Targets operate on undirected edges, therefore no duplicate necessary
        mask = torch.tensor(edges_attrs[merge_labeled_field],
                            dtype=torch.float)
        y = torch.tensor(edges_attrs[gt_merge_score_field], dtype=torch.long)

        return edge_index, edge_attr, x, pos, node_ids, mask, y
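The two replace_values calls in parse_rag_excerpt compress arbitrary, possibly very large fragment IDs on the edge list into the contiguous indices 0..N-1 that a PyTorch Geometric edge_index expects. A minimal sketch of that remapping, with made-up IDs and the same assumed import:

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

node_ids = np.array([10001, 10007, 10042], dtype=np.int64)
edges_u = np.array([10001, 10007], dtype=np.int64)  # 'u' column of two edges

remapped = np.zeros_like(edges_u)
remapped = replace_values(
    in_array=edges_u,
    old_values=node_ids,
    new_values=np.arange(len(node_ids), dtype=np.int64),
    inplace=False,
    out_array=remapped)
# remapped == [0, 1]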
Example #10
def simulate_random_cages(volume,
                          segmentation,
                          cages,
                          min_density,
                          max_density,
                          fm_intensity,
                          point_spread_function,
                          return_cage_map=False,
                          return_density_map=False,
                          no_cage_probability=0.0):
    '''Randomly render cages with a range of densities for each segment into a
    volume.

    Args:

        volume (Volume): The volume to render to. The volume is expected to be
        real valued with values between 0 and 1.

        segmentation (Volume): A segmentation of the volume. The segmentation
        is expected to be int valued with values between 1 and n. 0 will be
        treated as background.

        cages (list of Cages): A list of cages to randomly select from.

        min_density, max_density (float): The minimum and maximum density to
        uniformly choose from.

        fm_intensity (float): Render intensity for element 100 (Fermium), to be
        used as reference point for cubic intensity transfer function.

        point_spread_function (PointSpreadFunction): The PSF to use to render
        points.

        return_cage_map (bool): Return a map of which segment contains which
        type of cage (as an integer).

        return_density_map (bool): Return a map of the cage densities per
        segment.

        no_cage_probability (float): The probability of expressing no cage, per
        segment.
    '''
    assert (volume.data.min() >= 0 and volume.data.max() <= 1)

    id_list = np.unique(segmentation.data)
    id_list = id_list[np.nonzero(id_list)]

    random_cages = {}
    random_densities = {}

    for id_element in id_list:
        test = random.random()

        if test > no_cage_probability:
            random_cages[id_element] = random.choice(cages)
            random_densities[id_element] = random.uniform(
                min_density, max_density)
        else:
            random_cages[id_element] = None
            random_densities[id_element] = 0

    simulate_cages(volume, segmentation, random_cages, random_densities,
                   fm_intensity, point_spread_function)

    ret = ()

    if return_cage_map:

        # replace segmentation IDs with cage IDs
        cage_map = replace_values(segmentation.data, id_list, [
            random_cages[i].cage_id if random_cages[i] else 0 for i in id_list
        ])

        ret = ret + (cage_map, )

    if return_density_map:

        densities = np.array([random_densities[i] for i in id_list],
                             dtype=np.float64)

        # (almost) the same for the density map:
        density_map = replace_values(segmentation.data.astype(np.uint64),
                                     id_list.astype(np.uint64),
                                     densities.view(np.uint64)).view(
                                         np.float64)
        density_map = density_map.astype(np.float32)

        ret = ret + (density_map, )

    if len(ret) > 0:
        return ret
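The density map above works around the fact that replace_values only substitutes integer values: the float64 densities are reinterpreted as uint64 bit patterns, substituted for the segment IDs, and the result is reinterpreted back to float64 (background 0 maps to 0.0 by construction). A small sketch of that bit-cast round trip, same assumed import:

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

segmentation = np.array([1, 1, 2, 0], dtype=np.uint64)
id_list = np.array([1, 2], dtype=np.uint64)
densities = np.array([0.25, 0.75], dtype=np.float64)

density_map = replace_values(
    segmentation, id_list, densities.view(np.uint64)).view(np.float64)
# density_map == [0.25, 0.25, 0.75, 0.0]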
Example #11
    def atexit_tasks(model):

        # -----------------------------------------------
        # ---------------- EVALUATION ROUTINE -----------
        # -----------------------------------------------

        _log.info('saving tensorboardx summary files ...')
        # save the tensorboardx summary files
        summary_dir_exit = os.path.join(config.run_abs_path,
                                        config.summary_dir)
        summary_compressed = summary_dir_exit + '.tar.gz'
        # remove old tar file
        if os.path.isfile(summary_compressed):
            os.remove(summary_compressed)

        with tarfile.open(summary_compressed, mode='w:gz') as archive:
            archive.add(summary_dir_exit, arcname='summary', recursive=True)
        _run.add_artifact(filename=summary_compressed, name='summary.tar.gz')

        model.eval()
        model.current_writer = None

        # final print routine
        train_dataset.print_summary()

        _log.info(f'Total number of parameters: {total_params}')

        if config.final_training_pass:
            # train loss
            final_loss_train = 0.0
            final_metric_train = 0.0
            final_nr_nodes_train = 0

            _log.info('final training pass ...')
            start = time.time()
            for data_ft in data_loader_train:
                data_ft = data_ft.to(device)
                out_ft = model(data_ft)
                final_loss_train += model.loss(
                    out_ft, data_ft.y, data_ft.mask).item() * data_ft.num_nodes
                final_metric_train += model.out_to_metric(
                    out_ft, data_ft.y) * data_ft.num_nodes
                final_nr_nodes_train += data_ft.num_nodes
            final_loss_train /= final_nr_nodes_train
            final_metric_train /= final_nr_nodes_train

            _run.log_scalar('loss_train_final', final_loss_train,
                            config.training_epochs)
            _run.log_scalar('accuracy_train_final', final_metric_train,
                            config.training_epochs)
            _log.info(f'final training pass in {time.time() - start:.3f}s')
        else:
            # report training loss of last epoch
            final_loss_train = epoch_loss
            final_metric_train = epoch_metric_train

        _log.info(
            f'Mean train loss ({train_dataset.__len__()} samples): {final_loss_train:.3f}'
        )
        _log.info(f'Mean accuracy on train set: {final_metric_train:.3f}')

        if config.final_test_pass:

            # test loss
            data_loader_test = DataLoader(
                test_dataset,
                batch_size=config.batch_size_eval,
                shuffle=False,
                num_workers=config.num_workers,
                worker_init_fn=lambda idx: np.random.seed())
            test_loss = 0.0
            test_metric = 0.0
            nr_nodes_test = 0
            test_predictions = []
            test_targets = []

            test_1d_outputs = dict()

            _log.info('test pass ...')
            start_test_pass = time.time()
            for data_fe in data_loader_test:
                data_fe = data_fe.to(device)
                out_fe = model(data_fe)

                if config.write_to_db:
                    start = time.time()
                    out_1d = model.out_to_one_dim(out_fe).cpu()
                    # TODO this assumes again that every pair of directed edges is next to each other
                    edges = torch.transpose(data_fe.edge_index, 0, 1)[0::2]
                    edges = edges[data_fe.roi_mask].cpu().numpy().astype(
                        np.int64)

                    edges_orig_labels = np.zeros_like(edges, dtype=np.int64)
                    edges_orig_labels = replace_values(
                        in_array=edges,
                        out_array=edges_orig_labels,
                        old_values=np.arange(data_fe.num_nodes,
                                             dtype=np.int64),
                        new_values=data_fe.node_ids.cpu().numpy().astype(
                            np.int64),
                        inplace=False)

                    # TODO min max might be unnecessary here
                    # convert to tuples, make sure that directedness is not a problem
                    edges_list = [
                        tuple([np.min(i), np.max(i)])
                        for i in edges_orig_labels
                    ]

                    for k, v in zip(edges_list, out_1d):
                        if k not in test_1d_outputs:
                            test_1d_outputs[k] = v
                        else:
                            # TODO adapt strategy here if desired
                            test_1d_outputs[k] = max(test_1d_outputs[k], v)

                    _log.info(
                        f'writing outputs to dict in {time.time() - start}s')

                test_loss += model.loss(
                    out_fe, data_fe.y, data_fe.mask).item() * data_fe.num_nodes
                test_metric += model.out_to_metric(
                    out_fe, data_fe.y) * data_fe.num_nodes
                nr_nodes_test += data_fe.num_nodes
                pred = model.out_to_predictions(out_fe)
                test_predictions.extend(model.predictions_to_list(pred))
                test_targets.extend(data_fe.y.tolist())
            test_loss /= nr_nodes_test
            test_metric /= nr_nodes_test

            _run.log_scalar('loss_test', test_loss, config.training_epochs)
            _run.log_scalar('accuracy_test', test_metric,
                            config.training_epochs)
            _log.info(f'test pass in {time.time() - start_test_pass:.3f}s\n')

            _log.info(
                f'Mean test loss ({test_dataset.__len__()} samples): {test_loss:.3f}'
            )
            _log.info(f'Mean accuracy on test set: {test_metric:.3f}\n')

            if config.write_to_db:
                # timestamp = datetime.datetime.now(
                #     pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M%S.%f%z')
                comment = _run.meta_info['options']['--comment']
                test_dataset.write_outputs_to_db(
                    outputs_dict=test_1d_outputs,
                    collection_name=f'{_run.start_time}_{comment}')

            if config.plot_targets_vs_predictions:
                # TODO fix to run on cluster
                # plot targets vs predictions. default is a confusion matrix
                model.plot_targets_vs_predictions(targets=test_targets,
                                                  predictions=test_predictions)
                _run.add_artifact(filename=os.path.join(
                    config.run_abs_path, config.confusion_matrix_path),
                                  name=config.confusion_matrix_path)

                # if Regression, plot targets vs. continuous outputs
                # if isinstance(model.model_type, RegressionProblem):
                #     test_outputs = []
                #     for data in data_loader_test:
                #         data = data.to(device)
                #         out = torch.squeeze(model(data)).tolist()
                #         test_outputs.extend(out)
                #     model.model_type.plot_targets_vs_outputs(
                #         targets=test_targets, outputs=test_outputs)

            # plot the graphs in the test dataset for visual inspection
            if config.plot_graphs_testset:
                if (config.plot_graphs_testset < 0
                        or config.plot_graphs_testset > len(test_dataset)):
                    plot_limit = len(test_dataset)
                else:
                    plot_limit = config.plot_graphs_testset

                for i in range(plot_limit):
                    g = test_dataset[i]
                    g.to(device)
                    out_p = model(g)
                    g.plot_predictions(config=config,
                                       pred=model.predictions_to_list(
                                           model.out_to_predictions(out_p)),
                                       graph_nr=i,
                                       run=_run,
                                       acc=model.out_to_metric(out_p, g.y),
                                       logger=_log)
        else:
            # report validation loss of last epoch
            test_loss = validation_loss
            test_metric = epoch_metric_val
            _log.info(
                f'Mean validation loss ({test_dataset.__len__()} samples): {test_loss:.3f}'
            )
            _log.info(f'Mean accuracy on validation set: {test_metric:.3f}\n')

        return '\n{0}\ntrain acc: {1:.3f}\ntest acc: {2:.3f}'.format(
            _run.meta_info['options']['--comment'], final_metric_train,
            test_metric)
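The edge remapping inside the test pass (the replace_values call with out_array=edges_orig_labels) is the inverse of the one in parse_rag_excerpt: contiguous node indices 0..N-1 are mapped back to the original fragment IDs before the per-edge outputs are written to the database. A minimal sketch under the same assumptions as the earlier snippets:

import numpy as np
from funlib.segment.arrays import replace_values  # assumed import path

node_ids = np.array([10001, 10007, 10042], dtype=np.int64)
edges = np.array([[0, 1],
                  [1, 2]], dtype=np.int64)  # edges on contiguous indices

edges_orig_labels = np.zeros_like(edges)
edges_orig_labels = replace_values(
    in_array=edges,
    out_array=edges_orig_labels,
    old_values=np.arange(len(node_ids), dtype=np.int64),
    new_values=node_ids,
    inplace=False)
# edges_orig_labels == [[10001, 10007], [10007, 10042]]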