Example #1
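All of these snippets are post-processing callbacks excerpted from a larger module and carry no imports of their own. A minimal sketch of what they appear to assume is in scope follows; the real CSVData helper lives in the parent repository, and this stub only mimics the record/write/close calls the examples use:

import os
import numpy as np
import scipy.special
import scipy.spatial.distance
from scipy.spatial.distance import cdist
from sklearn import metrics
from sklearn.cluster import DBSCAN
from sklearn.manifold import TSNE

class CSVData:
    # Minimal stand-in: buffers one row at a time and appends it to a CSV file.
    def __init__(self, fout):
        self._fout = open(fout, 'w')
        self._header_written = False
        self._keys, self._vals = None, None
    def record(self, keys, vals):
        self._keys, self._vals = keys, vals
    def write(self):
        if not self._header_written:
            self._fout.write(','.join(self._keys) + '\n')
            self._header_written = True
        self._fout.write(','.join(str(v) for v in self._vals) + '\n')
    def close(self):
        self._fout.close()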
def uresnet_metrics(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res: return

    method_cfg = cfg['post_processing']['uresnet_metrics']

    index = data_blob['index']
    segment_data = res['segmentation']
    # input_data   = data_blob.get('input_data' if method_cfg is None else method_cfg.get('input_data', 'input_data'), None)
    segment_label = data_blob.get(
        'segment_label' if method_cfg is None else method_cfg.get(
            'segment_label', 'segment_label'), None)
    num_classes = 5 if method_cfg is None else method_cfg.get('num_classes', 5)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir, 'uresnet-metrics-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'uresnet-metrics-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        label = segment_label[data_idx][:, -1]

        acc = (predictions == label).sum() / float(len(label))
        class_acc = []
        pix = []
        for c1 in range(num_classes):
            for c2 in range(num_classes):
                # Fraction of true-class c1 voxels predicted as class c2
                # (row-normalized confusion matrix; NaN if class c1 is absent)
                class_mask = label == c1
                class_acc.append((predictions[class_mask] == c2).sum() /
                                 float(np.count_nonzero(class_mask)))
                # Raw voxel count for the same (c1, c2) confusion cell
                pix.append(
                    np.count_nonzero((label == c1) & (predictions == c2)))
        fout.record(('idx', 'acc') + tuple([
            'confusion_%d_%d' % (c1, c2) for c1 in range(num_classes)
            for c2 in range(num_classes)
        ]) + tuple([
            'num_pix_%d_%d' % (c1, c2) for c1 in range(num_classes)
            for c2 in range(num_classes)
        ]), (tree_idx, acc) + tuple(class_acc) + tuple(pix))
        fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
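The nested c1/c2 loop above flattens a row-normalized confusion matrix into the confusion_%d_%d columns, alongside raw counts in the num_pix_%d_%d columns. For illustration, an equivalent standalone computation (a sketch, not part of the original module):

import numpy as np

def confusion(label, predictions, num_classes):
    # counts[c1, c2] = number of voxels of true class c1 predicted as c2
    counts = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(counts, (label.astype(int), predictions.astype(int)), 1)
    # Row-normalize to obtain the per-class fractions written to the CSV
    fractions = counts / np.maximum(counts.sum(axis=1, keepdims=True), 1)
    return counts, fractions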
Example #2
def store_uresnet(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res: return

    method_cfg = cfg['post_processing']['store_uresnet']

    index = data_blob['index']
    segment_data = res['segmentation']
    input_data = data_blob.get(
        'input_data' if method_cfg is None else method_cfg.get(
            'input_data', 'input_data'), None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir,
                         'uresnet-segmentation-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'uresnet-segmentation-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        for i, row in enumerate(predictions):
            event = input_data[data_idx][i]
            fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                        (tree_idx, event[0], event[1], event[2], 4, row))
            fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
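Since CSVData appears to emit a standard header row followed by comma-separated values, the per-voxel files written here can be inspected with pandas, e.g. (file name assumed from the pattern above):

import pandas as pd

df = pd.read_csv('log/uresnet-segmentation-iter-0000000.csv')
print(df.groupby('type').size())  # one row per voxel, keyed by point type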
Example #3
def instance_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Simple DBSCAN on uresnet clustering output for instance segmentation

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the log directory where CSV files are written
    iteration: int
        Iteration number

    Input
    -----
    Requires the following analysis keys
    - `segmentation`: output of UResNet segmentation (scores)
    - `cluster_feature`: coordinates in the embedding space, also output of the network
    Requires the following data blob keys
    - `input_data`
    - `segment_label` UResNet 5 classes label
    - `cluster_label`

    Output
    ------
    Writes 2 CSV files:
    - `instance_clustering-*` with the clustering predictions (point type 0 =
    event data, point type 1 = predictions, point type 2 = T-SNE visualizations)
    - `instance_clustering_metrics-*` with some event-wise metrics such as AMI and ARI.
    """

    method_cfg = cfg['post_processing']['instance_clustering']
    model_cfg  = cfg['model']['modules']['uresnet_clustering']

    tsne = TSNE(n_components = 2 if method_cfg is None else method_cfg.get('tsne_dim',2))
    compute_tsne = False if method_cfg is None else method_cfg.get('compute_tsne',False)
    
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',None) is not None:
        assert(method_cfg['store_method'] in ['per-iteration','per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout_cluster, fout_metrics = None, None
    if store_per_iteration:
        fout_cluster=CSVData(os.path.join(logdir, 'instance-clustering-iter-%07d.csv' % iteration))
        fout_metrics=CSVData(os.path.join(logdir, 'instance-clustering-metrics-iter-%07d.csv' % iteration))

    data_dim = model_cfg.get('data_dim', 3)
    depth = model_cfg.get('num_strides', 5)
    num_classes = model_cfg.get('num_classes', 5)

    # Loop over batch index
    for batch_index, event_data in enumerate(data_blob['input_data']):
        
        event_index = data_blob['index'][batch_index]
        
        if not store_per_iteration:
            fout_cluster=CSVData(os.path.join(logdir, 'instance-clustering-event-%07d.csv' % event_index))
            fout_metrics=CSVData(os.path.join(logdir, 'instance-clustering-metrics-event-%07d.csv' % event_index))
            
        event_segmentation = res['segmentation'][batch_index]
        event_label = data_blob['segment_label'][batch_index]
        event_cluster_label = data_blob['cluster_label'][batch_index]
        max_depth = len(event_cluster_label)
        for d, event_feature_map in enumerate(res['cluster_feature'][batch_index]):
            coords = event_feature_map[:, :data_dim]
            perm = np.lexsort((coords[:, 2], coords[:, 1], coords[:, 0]))
            coords = coords[perm]
            class_label = event_label[-(d+1+max_depth-depth)]
            cluster_count = 0
            for class_ in range(num_classes):
                class_index = class_label[:, -1] == class_
                if np.count_nonzero(class_index) == 0:
                    continue
                clusters_label = event_cluster_label[-(d+1+max_depth-depth)][class_index]
                embedding = event_feature_map[perm][class_index]
                # DBSCAN in high dimension embedding
                predicted_clusters = DBSCAN(eps=20, min_samples=1).fit(embedding).labels_
                predicted_clusters += cluster_count  # To avoid overlapping id
                cluster_count += len(np.unique(predicted_clusters))

                # Cluster similarity metrics
                ARI = metrics.adjusted_rand_score(clusters_label[:, -1], predicted_clusters)
                AMI = metrics.adjusted_mutual_info_score(clusters_label[:, -1], predicted_clusters)
                fout_metrics.record(('class', 'batch_id', 'AMI', 'ARI', 'idx'),
                                    (class_, batch_index, AMI, ARI, event_index))
                fout_metrics.write()

                for i, point in enumerate(clusters_label):
                    fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                        (1, point[0], point[1], point[2], batch_index, d, -1, class_label[class_index][i, -1], clusters_label[i, -1], predicted_clusters[i], event_index))
                    fout_cluster.write()
                # TSNE to visualize embedding
                if compute_tsne and embedding.shape[0] > 1:
                    new_embedding = tsne.fit_transform(embedding)
                    for i, point in enumerate(new_embedding):
                        fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                            (2, point[0], point[1], -1, batch_index, d, -1, class_label[class_index][i, -1], clusters_label[i, -1], predicted_clusters[i], event_index))
                        fout_cluster.write()

        # Record in CSV everything
        perm = np.lexsort((event_data[:, 2], event_data[:, 1], event_data[:, 0]))
        event_data = event_data[perm]
        event_segmentation = event_segmentation[perm]
        # Point in data and semantic class predictions/true information
        for i, point in enumerate(event_data):
            fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                (0, point[0], point[1], point[2], batch_index, point[4], np.argmax(event_segmentation[i]), event_label[0][:,-1][i], -1, -1, event_index))
            fout_cluster.write()

        if not store_per_iteration:
            fout_cluster.close()
            fout_metrics.close()
    if store_per_iteration:
        fout_cluster.close()
        fout_metrics.close()
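The heart of instance_clustering is DBSCAN run in the learned embedding space, scored against true cluster ids with ARI/AMI. A self-contained toy version of that step, for illustration only:

import numpy as np
from sklearn import metrics
from sklearn.cluster import DBSCAN

rng = np.random.RandomState(0)
# Two well-separated blobs standing in for the network's embedding output
embedding = np.concatenate([rng.normal(0, 1, (50, 4)),
                            rng.normal(40, 1, (50, 4))])
true_ids = np.repeat([0, 1], 50)

predicted = DBSCAN(eps=20, min_samples=1).fit(embedding).labels_
print(metrics.adjusted_rand_score(true_ids, predicted))         # ~1.0
print(metrics.adjusted_mutual_info_score(true_ids, predicted))  # ~1.0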
Example #4
def michel_reconstruction_2d(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the log directory where CSV files are written
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 2D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet
    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel_reconstruction-*`
    - `michel_reconstruction2-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction_2d']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):

        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        label = data_blob['segment_label'][batch_id][:, -1]
        meta = data_blob['meta'][batch_id]

        # clusters    = data_blob['clusters_label' ][batch_id]
        # particles   = data_blob['particles_label'][batch_id]

        # Michel_particles = particles[particles[:, 2] == 4]  # FIXME 3 or 4 in 2D? Also not sure if type is registered for Michel

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)
        Michel_label = 3
        MIP_label = 0

        data_dim = 2
        # 0. Retrieve coordinates of true and predicted Michels
        MIP_coords = data[(label == MIP_label).reshape((-1, )),
                          ...][:, :data_dim]
        Michel_coords = data[(label == Michel_label).reshape((-1, )),
                             ...][:, :data_dim]
        # MIP_coords = clusters[clusters[:, -1] == 1][:, :3]
        # Michel_coords = clusters[clusters[:, -1] == 4][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data[(predictions == MIP_label).reshape((-1, )),
                               ...][:, :data_dim]
        Michel_coords_pred = data[(predictions == Michel_label).reshape(
            (-1, )), ...][:, :data_dim]

        # DBSCAN epsilon used for many things... TODO list here
        #one_pixel = 15#2.8284271247461903
        one_pixel_dbscan = 5
        one_pixel_is_attached = 2
        # 1. Find true particle information matching the true Michel cluster
        Michel_true_clusters = DBSCAN(eps=one_pixel_dbscan,
                                      min_samples=5).fit(Michel_coords).labels_
        MIP_true_clusters = DBSCAN(eps=one_pixel_dbscan,
                                   min_samples=5).fit(MIP_coords).labels_

        # compute all edges of true MIP clusters
        MIP_edges = []
        for cluster in np.unique(MIP_true_clusters[MIP_true_clusters > -1]):
            touching_idx = find_edges(MIP_coords[MIP_true_clusters == cluster])
            MIP_edges.append(
                MIP_coords[MIP_true_clusters == cluster][touching_idx[0]])
            MIP_edges.append(
                MIP_coords[MIP_true_clusters == cluster][touching_idx[1]])
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        # Michel_true_clusters = clusters[clusters[:, -1] == 4][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :data_dim]

        true_Michel_is_attached = {}
        true_Michel_is_edge = {}
        true_Michel_is_too_close = {}
        for cluster in np.unique(Michel_true_clusters):
            min_y = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 1].min()  # * meta[-1] + meta[1]
            max_y = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 1].max()  # * meta[-1] + meta[1]
            min_x = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 0].min()  # * meta[-2] + meta[0]
            max_x = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 0].max()  # * meta[-2] + meta[0]

            # Find coordinates of Michel pixel touching MIP edge
            Michel_edges_idx = find_edges(
                Michel_coords[Michel_true_clusters == cluster])
            distances = cdist(
                Michel_coords[Michel_true_clusters == cluster]
                [Michel_edges_idx], MIP_coords[MIP_true_clusters > -1])

            # Make sure true Michel is attached at edge of MIP
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                   distances.shape)
            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_true_clusters[MIP_true_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords[MIP_true_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords[MIP_true_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords,
                                     MIP_min_coords,
                                     one_pixel=one_pixel_dbscan,
                                     radius=15.0)
            true_Michel_is_attached[cluster] = is_attached
            true_Michel_is_edge[cluster] = is_edge
            true_Michel_is_too_close[cluster] = is_too_close

            # these are the coordinates of Michel edges
            edge1_x = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[0], 0]
            edge1_y = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[0], 1]
            edge2_x = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[1], 0]
            edge2_y = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[1], 1]

            # Find for each Michel edge the closest MIP pixels
            # Check for each of these whether they are at the edge of MIP
            # FIXME what happens if both are at the edge of a MIP?? unlikely
            closest_MIP_pixels = np.argmin(distances, axis=1)
            # NB: these indices refer to the masked array
            # MIP_coords[MIP_true_clusters > -1] used to build `distances`
            clusters_idx = MIP_true_clusters[MIP_true_clusters > -1][closest_MIP_pixels]
            edge0 = is_at_edge(
                MIP_coords[MIP_true_clusters == clusters_idx[0]],
                MIP_coords[MIP_true_clusters > -1][closest_MIP_pixels[0]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            edge1 = is_at_edge(
                MIP_coords[MIP_true_clusters == clusters_idx[1]],
                MIP_coords[MIP_true_clusters > -1][closest_MIP_pixels[1]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[
                        1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 1]
            #
            # if touching_x not in [edge1_x, edge2_x] or touching_y not in [edge1_y, edge2_y]:
            #     print('true', event_idx, touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)
            #if event_idx == 127:
            #    print('true', touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)
            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix',
                 'min_y', 'max_y', 'min_x', 'max_x', 'pixel_width',
                 'pixel_height', 'meta_min_x', 'meta_min_y', 'touching_x',
                 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x', 'edge2_y',
                 'edge0', 'edge1', 'is_attached', 'is_edge', 'is_too_close',
                 'cluster_id'),
                (
                    batch_id,
                    iteration,
                    event_idx,
                    np.count_nonzero(Michel_true_clusters == cluster),
                    data[(label == Michel_label).reshape((-1, )),
                         ...][Michel_true_clusters == cluster][:, -1].sum(),
                    # clusters[clusters[:, -1] == 4][Michel_true_clusters == cluster][:, -3].sum()
                    min_y,
                    max_y,
                    min_x,
                    max_x,
                    meta[-2],
                    meta[-1],
                    meta[0],
                    meta[1],
                    touching_x,
                    touching_y,
                    edge1_x,
                    edge1_y,
                    edge2_x,
                    edge2_y,
                    edge0,
                    edge1,
                    is_attached,
                    is_edge,
                    is_too_close,
                    cluster))
            fout_true.write()
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue

        #
        # 2. Compute true and predicted clusters
        #
        MIP_clusters = DBSCAN(eps=one_pixel_dbscan,
                              min_samples=10).fit(MIP_coords_pred).labels_
        MIP_clusters_id = np.unique(MIP_clusters[MIP_clusters > -1])

        # If no predicted MIP then continue TODO how do we count this?
        if MIP_coords_pred[MIP_clusters > -1].shape[0] == 0:
            continue

        # MIP_edges = []
        # for cluster in MIP_clusters_id:
        #     touching_idx = find_edges(MIP_coords_pred[MIP_clusters == cluster])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[0]])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[1]])

        Michel_pred_clusters = DBSCAN(
            eps=one_pixel_dbscan,
            min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(
            Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))

        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            Michel_edges_idx = find_edges(Michel_coords_pred[current_index])
            # distances_edges = cdist(Michel_coords_pred[current_index][Michel_edges_idx], MIP_edges)
            # distances = cdist(Michel_coords_pred[current_index], MIP_coords_pred[MIP_clusters>-1])
            distances = cdist(
                Michel_coords_pred[current_index][Michel_edges_idx],
                MIP_coords_pred[MIP_clusters > -1])
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                   distances.shape)
            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords,
                                     MIP_min_coords,
                                     one_pixel=one_pixel_dbscan,
                                     radius=15.0)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            touching_x, touching_y = -1, -1
            edge1_x, edge1_y, edge2_x, edge2_y = -1, -1, -1, -1
            true_is_attached, true_is_edge, true_is_too_close = -1, -1, -1
            closest_true_id = -1

            # Find point where MIP and Michel touches
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 1]
            edge1_x = Michel_coords_pred[current_index][Michel_edges_idx[0], 0]
            edge1_y = Michel_coords_pred[current_index][Michel_edges_idx[0], 1]
            edge2_x = Michel_coords_pred[current_index][Michel_edges_idx[1], 0]
            edge2_y = Michel_coords_pred[current_index][Michel_edges_idx[1], 1]

            closest_MIP_pixels = np.argmin(distances, axis=1)
            clusters_idx = MIP_clusters[MIP_clusters > -1][np.argmin(distances,
                                                                     axis=1)]
            edge0 = is_at_edge(
                MIP_coords_pred[MIP_clusters == clusters_idx[0]],
                MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[0]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            edge1 = is_at_edge(
                MIP_coords_pred[MIP_clusters == clusters_idx[1]],
                MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[1]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[
                        1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y

            if is_attached and is_edge:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index],
                                  Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances,
                                                                  axis=1)]
                closest_clusters_final = closest_clusters[
                    (closest_clusters > -1)
                    & (np.min(distances, axis=1) < one_pixel_dbscan)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(
                        closest_clusters_final).argmax()
                    overlap_pixels_index = (
                        closest_clusters == closest_true_id) & (np.min(
                            distances, axis=1) < one_pixel_dbscan)
                    if closest_true_id > -1:
                        closest_true_index = label[
                            predictions ==
                            Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data[(predictions == Michel_label).reshape(
                            (-1, )), ...][current_index]:
                            count = int(
                                np.any(
                                    np.all(v[:data_dim] ==
                                           Michel_coords[Michel_true_clusters
                                                         == closest_true_id],
                                           axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(
                            Michel_true_clusters == closest_true_id)
                        # michel_true_sum_pix = clusters[clusters[:, -1] == 4][Michel_true_clusters == closest_true_id][:, -3].sum()
                        michel_true_sum_pix = data[(
                            label == Michel_label).reshape(
                                (-1, )), ...][Michel_true_clusters ==
                                              closest_true_id][:, -1].sum()

                        # Check whether true Michel is attached to MIP, otherwise exclude
                        true_is_attached = true_Michel_is_attached[
                            closest_true_id]
                        true_is_edge = true_Michel_is_edge[closest_true_id]
                        true_is_too_close = true_Michel_is_too_close[
                            closest_true_id]
                        # Register true energy
                        # Match to MC Michel
                        # FIXME in 2D Michel_start is no good
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = -1

            # Record every predicted Michel cluster in CSV
            # Record min and max x in real coordinates
            min_y = Michel_coords_pred[current_index][:, 1].min(
            )  # * meta[-1] + meta[1]
            max_y = Michel_coords_pred[current_index][:, 1].max(
            )  # * meta[-1] + meta[1]
            min_x = Michel_coords_pred[current_index][:, 0].min(
            )  # * meta[-2] + meta[0]
            max_x = Michel_coords_pred[current_index][:, 0].max(
            )  # * meta[-2] + meta[0]
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy', 'min_y', 'max_y', 'min_x', 'max_x',
                 'pixel_width', 'pixel_height', 'meta_min_x', 'meta_min_y',
                 'touching_x', 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x',
                 'edge2_y', 'edge0', 'edge1', 'true_is_attached',
                 'true_is_edge', 'true_is_too_close', 'is_too_close',
                 'closest_true_id'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data[(predictions == Michel_label).reshape(
                     (-1, )), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix, is_attached,
                 is_edge, michel_true_energy, min_y, max_y, min_x, max_x,
                 meta[-2], meta[-1], meta[0], meta[1], touching_x, touching_y,
                 edge1_x, edge1_y, edge2_x, edge2_y, edge0, edge1,
                 true_is_attached, true_is_edge, true_is_too_close,
                 is_too_close, closest_true_id))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
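find_edges and is_at_edge are not defined in this snippet. Plausible reconstructions, inferred from the call sites here and from the inline ablation logic in the 3D variant below (assumptions, not the repository's actual helpers):

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

def find_edges(coords):
    # Indices of the two mutually farthest pixels of a cluster
    d = cdist(coords, coords)
    i, j = np.unravel_index(np.argmax(d), d.shape)
    return i, j

def is_at_edge(cluster_coords, point_coords, one_pixel=5, radius=15.0):
    # Remove all pixels within `radius` of the candidate point, then DBSCAN
    # the remainder: if it stays in one single piece, the point sat at an edge.
    keep = np.linalg.norm(cluster_coords - point_coords, axis=1) > radius
    ablated = cluster_coords[keep]
    if ablated.shape[0] == 0:
        return True
    labels = DBSCAN(eps=one_pixel, min_samples=5).fit(ablated).labels_
    return len(np.unique(labels[labels > -1])) == 1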
Example #5
def michel_reconstruction(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the log directory where CSV files are written
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 3D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet
    - `ghost` predictions of GhostNet
    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel_reconstruction-*`
    - `michel_reconstruction2-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):

        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        label = data_blob['segment_label'][batch_id][:, -1]
        label_raw = data_blob['sparse3d_pcluster_semantics'][batch_id]
        clusters = data_blob['clusters_label'][batch_id]
        particles = data_blob['particles_label'][batch_id]
        true_ghost_mask = label < 5
        data_masked = data[true_ghost_mask]
        label_masked = label[true_ghost_mask]

        one_pixel = 5  #2.8284271247461903

        # Retrieve semantic labels corresponding to clusters
        clusters_semantics = np.zeros((clusters.shape[0])) - 1
        for cluster_id in np.unique(clusters[:, -2]):
            cluster_idx = clusters[:, -2] == cluster_id
            coords = clusters[cluster_idx][:, :3]
            d = cdist(coords, label_raw[:, :3])
            semantic_id = np.bincount(label_raw[d.argmin(
                axis=1)[d.min(axis=1) < one_pixel]][:,
                                                    -1].astype(int)).argmax()
            clusters_semantics[cluster_idx] = semantic_id

        # Find cluster id for semantics_reco
        # clusters_new = np.ones((label_masked.shape[0],))*-1
        # clusters_E = np.ones((label_masked.shape[0],))
        # for cluster_id in np.unique(clusters[:, -2]):
        #     cluster_idx = clusters[:, -2] == cluster_id
        #     coords = clusters[cluster_idx][:, :3]
        #     d = cdist(coords, data_masked[:, :3])
        #     overlap_idx = d.argmin(axis=0)[d.min(axis=0)<one_pixel]
        #     clusters_new[overlap_idx] = np.bincount(clusters[cluster_idx][d.argmin(axis=0)[d.min(axis=0)<one_pixel]][:, -2].astype(int)).argmax()
        #     clusters_E[overlap_idx] = clusters[cluster_idx][overlap_idx][:, -1]
        #     #clusters_new[overlap_idx][:, :3] = data_masked[overlap_idx][:, :3]
        # print('clusters new', np.unique(clusters_new, return_counts=True))

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)
        ghost_mask = (np.argmax(res['ghost'][batch_id], axis=1) == 0)

        data_pred = data[ghost_mask]  # coords
        label_pred = label[ghost_mask]  # labels
        predictions = (np.argmax(segmentation, axis=1))[ghost_mask]
        segmentation = segmentation[ghost_mask]

        Michel_label = 2
        MIP_label = 1

        # 0. Retrieve coordinates of true and predicted Michels
        # MIP_coords = data[(label == 1).reshape((-1,)), ...][:, :3]
        # Michel_coords = data[(label == 4).reshape((-1,)), ...][:, :3]
        # Michel_particles = particles[particles[:, 4] == Michel_label]
        MIP_coords = data[label == MIP_label][:, :3]
        # Michel_coords = data[label == Michel_label][:, :3]
        Michel_coords = clusters[clusters_semantics == Michel_label][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data_pred[(predictions == MIP_label).reshape((-1, )),
                                    ...][:, :3]
        Michel_coords_pred = data_pred[(predictions == Michel_label).reshape(
            (-1, )), ...][:, :3]

        # 1. Find true particle information matching the true Michel cluster
        # Michel_true_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(Michel_coords).labels_
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        #print(clusters.shape, label.shape)
        Michel_true_clusters = clusters[clusters_semantics ==
                                        Michel_label][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :3]
        for cluster in np.unique(Michel_true_clusters):
            # print("True", np.count_nonzero(Michel_true_clusters == cluster))
            # TODO sum_pix
            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(Michel_true_clusters == cluster),
                 clusters[clusters_semantics == Michel_label][
                     Michel_true_clusters == cluster][:, -1].sum()))
            fout_true.write()
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue
        # print("Also predicted!")
        # 2. Compute true and predicted clusters
        MIP_clusters = DBSCAN(eps=one_pixel,
                              min_samples=10).fit(MIP_coords_pred).labels_
        Michel_pred_clusters = DBSCAN(
            eps=one_pixel, min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(
            Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))
        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            distances = cdist(Michel_coords_pred[current_index],
                              MIP_coords_pred[MIP_clusters > -1])
            # is_attached = np.min(distances) < 2.8284271247461903
            is_attached = np.min(distances) < 5
            is_edge = False  # default
            # print("Min distance:", np.min(distances))
            if is_attached:
                Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                       distances.shape)
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                ablated_cluster = MIP_cluster_coords[np.linalg.norm(
                    MIP_cluster_coords - MIP_min_coords, axis=1) > 15.0]
                if ablated_cluster.shape[0] > 0:
                    new_cluster = DBSCAN(
                        eps=one_pixel,
                        min_samples=5).fit(ablated_cluster).labels_
                    is_edge = len(np.unique(
                        new_cluster[new_cluster > -1])) == 1  # one single piece left
                else:
                    is_edge = True
            # print(is_attached, is_edge)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            if is_attached and is_edge and Michel_coords.shape[0] > 0:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index],
                                  Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances,
                                                                  axis=1)]
                closest_clusters_final = closest_clusters[
                    (closest_clusters > -1)
                    & (np.min(distances, axis=1) < one_pixel)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(
                        closest_clusters_final).argmax()
                    overlap_pixels_index = (closest_clusters
                                            == closest_true_id) & (np.min(
                                                distances, axis=1) < one_pixel)
                    if closest_true_id > -1:
                        closest_true_index = label_pred[
                            predictions ==
                            Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data_pred[(
                                predictions == Michel_label).reshape((-1, )),
                                           ...][current_index]:
                            count = int(
                                np.any(
                                    np.all(v[:3] ==
                                           Michel_coords[Michel_true_clusters
                                                         == closest_true_id],
                                           axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(
                            Michel_true_clusters == closest_true_id)
                        michel_true_sum_pix = clusters[
                            clusters_semantics == Michel_label][
                                Michel_true_clusters ==
                                closest_true_id][:, -1].sum()
                        # Register true energy
                        # Match to MC Michel
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = particles[
                            closest_true_id].energy_init()
                        #print('michel true energy', particles[closest_true_id].energy_init(), particles[closest_true_id].pdg_code(), particles[closest_true_id].energy_deposit())
            # Record every predicted Michel cluster in CSV
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data_pred[(predictions == Michel_label).reshape(
                     (-1, )), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix, is_attached,
                 is_edge, michel_true_energy))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
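Both Michel routines match a predicted cluster to its true counterpart by letting every predicted pixel vote for the true cluster owning its nearest labeled pixel, keeping only votes within about one pixel. That matching step as a standalone sketch:

import numpy as np
from scipy.spatial.distance import cdist

def match_to_true(pred_coords, true_coords, true_ids, one_pixel=5):
    d = cdist(pred_coords, true_coords)
    closest = true_ids[d.argmin(axis=1)]
    # Keep only votes from predicted pixels that actually overlap a true cluster
    votes = closest[(closest > -1) & (d.min(axis=1) < one_pixel)]
    return np.bincount(votes).argmax() if len(votes) > 0 else -1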
Example #6
def store_input(cfg, data_blob, res, logdir, iteration):
    """
    Store input data blob.

    Configuration
    -------------
    threshold: float, optional
        Default: 0.
    input_data: str, optional
    particles_label: str, optional
    segment_label: str, optional
    clusters_label: str, optional
    cluster3d_mcst_true: str, optional
    store_method: str, optional
        Can be `per-iteration` or `per-event`
    """
    method_cfg = cfg['post_processing']['store_input']

    if (method_cfg is not None
            and not method_cfg.get('input_data', 'input_data') in data_blob
        ) or (method_cfg is None and 'input_data' not in data_blob):
        return

    threshold = 0. if method_cfg is None else method_cfg.get('threshold', 0.)
    data_dim = 3 if method_cfg is None else method_cfg.get('data_dim', 3)

    index = data_blob.get('index', None)
    input_dat = data_blob.get(
        'input_data' if method_cfg is None else method_cfg.get(
            'input_data', 'input_data'), None)
    label_ppn = data_blob.get(
        'particles_label' if method_cfg is None else method_cfg.get(
            'particles_label', 'particles_label'), None)
    label_seg = data_blob.get(
        'segment_label' if method_cfg is None else method_cfg.get(
            'segment_label', 'segment_label'), None)
    label_cls = data_blob.get(
        'clusters_label' if method_cfg is None else method_cfg.get(
            'clusters_label', 'clusters_label'), None)
    label_mcst = data_blob.get(
        'cluster3d_mcst_true' if method_cfg is None else method_cfg.get(
            'cluster3d_mcst_true', 'cluster3d_mcst_true'), None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    if input_dat is None: return

    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'input-iter-%07d.csv' % iteration))

    for data_index, tree_index in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir, 'input-event-%07d.csv' % tree_index))

        mask = input_dat[data_index][:, -1] > threshold

        # type 0 = input data
        for row in input_dat[data_index][mask]:
            coords_labels, coords = get_coords(row, data_dim, tree_index)
            fout.record(coords_labels + ('type', 'value'),
                        coords + (0, row[data_dim + 1]))
            fout.write()

        # type 1 = Labels for PPN
        if label_ppn is not None:
            for row in label_ppn[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 1, row[4]))
                fout.write()
        # 2 = UResNet labels
        if label_seg is not None:
            for row in label_seg[data_index][mask]:
                coords_labels, coords = get_coords(row, data_dim, tree_index)
                fout.record(coords_labels + ('type', 'value'),
                            coords + (2, row[data_dim + 1]))
                fout.write()
        # type 15 = group id, 16 = semantic labels, 17 = energy
        if label_cls is not None:
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 15, row[5]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 16, row[6]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 17, row[4]))
                fout.write()
        # type 19 = cluster3d_mcst_true
        if label_mcst is not None:
            for row in label_mcst[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 19, row[4]))
                fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
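get_coords is also external to this snippet. A hypothetical reconstruction consistent with its call sites (2D or 3D rows, with the event index prepended), not the repository's definition:

def get_coords(row, data_dim, tree_index):
    # ('idx', 'x', 'y'[, 'z']) labels plus the event index and pixel coordinates
    coords_labels = ('idx',) + ('x', 'y', 'z')[:data_dim]
    coords = (tree_index,) + tuple(row[:data_dim])
    return coords_labels, coords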
Example #7
def track_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Track clustering on PPN+UResNet output.

    Parameters
    ----------
    data_blob: dict
        The input data dictionary from iotools.
    res: dict
        The output of the network, formatted using `analysis_keys`.
    cfg: dict
        Configuration. The optional `debug` flag in the method configuration
        controls whether some stats are printed to stdout.
    logdir: string
        Path to the log directory where the CSV file is written.
    iteration: int
        Iteration number.

    Notes
    -----
    Based on
    - semantic segmentation output
    - point position and type predictions
    In addition, includes a break algorithm to break track clusters into
    smaller clusters based on predicted points.
    Stores all points and their information in a CSV file.
    """
    method_cfg = cfg['post_processing']['track_clustering']
    dbscan_cfg = cfg['model']['modules']['dbscan']
    data_dim = int(dbscan_cfg['data_dim'])
    min_samples = int(dbscan_cfg['minPoints'])

    debug = bool(method_cfg.get('debug', None))
    score_threshold = float(method_cfg.get('score_threshold', 0.6))
    threshold_association = float(method_cfg.get('threshold_association', 3))
    exclusion_radius = float(method_cfg.get('exclusion_radius', 5))
    type_threshold = float(method_cfg.get('type_threshold', 2))

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir, 'track-clustering-iter-%07d.csv' % iteration))

    # Loop over batch index
    #for b in batch_ids:
    for batch_index, data in enumerate(data_blob['input_data']):

        event_index = data_blob['index'][batch_index]

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'track-clustering-event-%07d.csv' % event_index))

        event_clusters = res['final'][batch_index]
        event_data = data[:, :data_dim]
        event_clusters_label = data_blob['clusters_label'][batch_index]
        event_particles_label = data_blob['particles_label'][batch_index]
        event_segmentation = res['segmentation'][batch_index]
        event_segmentation_label = data_blob['segment_label'][batch_index]
        points = res['points'][batch_index]
        event_xyz = points[:, :data_dim]
        event_scores = points[:, data_dim:data_dim + 2]
        event_mask = res['mask_ppn2'][batch_index]

        anchors = (event_data + 0.5)
        event_xyz = event_xyz + anchors

        dbscan_points = []
        predicted_points = []

        # 0) Postprocessing on predicted pixels
        # Apply selection mask from PPN2 + score thresholding
        scores = scipy.special.softmax(event_scores, axis=1)
        event_mask = ((~(event_mask == 0)).any(
            axis=1)) & (scores[:, 1] > score_threshold)
        # Now loop through semantic classes and look at ppn+uresnet predictions
        uresnet_predictions = np.argmax(event_segmentation[event_mask], axis=1)
        num_classes = event_segmentation.shape[1]
        ppn_type_predictions = np.argmax(scipy.special.softmax(
            points[event_mask][:, 5:], axis=1),
                                         axis=1)
        for c in range(num_classes):
            uresnet_points = uresnet_predictions == c
            ppn_points = ppn_type_predictions == c
            # We want to keep only points of type X within 2px of uresnet prediction of type X
            d = scipy.spatial.distance.cdist(
                event_xyz[event_mask][ppn_points],
                event_data[event_mask][uresnet_points])
            ppn_mask = (d < type_threshold).any(axis=1)
            # dbscan_points stores coordinates only
            # predicted_points stores everything for each point
            dbscan_points.append(event_xyz[event_mask][ppn_points][ppn_mask])
            pp = points[event_mask][ppn_points][ppn_mask]
            pp[:, :3] += anchors[event_mask][ppn_points][ppn_mask]
            predicted_points.append(pp)
        dbscan_points = np.concatenate(dbscan_points, axis=0)
        predicted_points = np.concatenate(predicted_points, axis=0)

        # 0.5) Remove points for delta rays
        point_types = np.argmax(predicted_points[:, -5:], axis=1)
        dbscan_points = dbscan_points[point_types != 3]

        # 1) Break algorithm
        # Using PPN point predictions, the idea is to mask an area around
        # each point associated with a given predicted cluster. Dbscan
        # then tells us whether this predicted cluster should be broken in
        # two or more smaller clusters. Pixel that were masked are then
        # assigned to the closest cluster among the newly formed clusters.
        if dbscan_points.shape[0] > 0:  # If PPN predicted some points
            cluster_ids = np.unique(event_clusters[:, -1])
            final_clusters = []
            # Loop over predicted clusters
            for c in cluster_ids:
                # Find predicted points associated to this predicted cluster
                cluster = event_clusters[event_clusters[:, -1] == c][:, :data_dim]
                d = cdist(dbscan_points, cluster)
                index = d.min(axis=1) < threshold_association
                new_d = d[index, :]
                # Now mask around these points
                new_index = (new_d > exclusion_radius).all(axis=0)
                # Main body of the cluster (far away from the points)
                new_cluster = cluster[new_index]
                # Cluster part around the points
                remaining_cluster = cluster[~new_index]
                # FIXME this might eliminate clusters that are too small?
                # Put a threshold here? Some clusters have only 1 px.
                if new_cluster.shape[0] == 0:
                    continue
                # Now dbscan on the main body of the cluster to find if we need
                # to break it or not
                db2 = DBSCAN(eps=exclusion_radius,
                             min_samples=min_samples).fit(new_cluster).labels_
                # DBSCAN labeled every point of this cluster as noise
                if np.all(db2 == -1):
                    continue
                # These are going to be the new bodies of predicted clusters
                new_cluster_ids = np.unique(db2)
                new_clusters = []
                for c2 in new_cluster_ids:
                    if c2 > -1:
                        new_clusters.append([new_cluster[db2 == c2]])
                # If some points were left by dbscan, put them in remaining
                # cluster and assign them to closest cluster
                if len(new_cluster[db2 == -1]) > 0:
                    if debug:
                        print(len(new_cluster[db2 == -1]), len(new_cluster))
                    remaining_cluster = np.concatenate(
                        [remaining_cluster, new_cluster[db2 == -1]], axis=0)
                    # Push the noise points far away so they can never win
                    # the argmin below, effectively removing them from
                    # new_cluster without changing its shape
                    new_cluster[db2 == -1] = 100000
                # Now assign remaining pixels in remaining_cluster based on
                # their distance to the new clusters.
                # First we find which point of new_cluster was closest
                d3 = cdist(remaining_cluster, new_cluster)
                # Then we find what is the corresponding new cluster id of this
                # closest pixel
                remaining_db = db2[d3.argmin(axis=1)]
                # Now append each pixel of remaining_cluster to correct new
                # cluster
                for i, pixel in enumerate(remaining_cluster):
                    new_clusters[remaining_db[i]].append(pixel[None, :])
                # Turn everything into np arrays
                for i in range(len(new_clusters)):
                    new_clusters[i] = np.concatenate(new_clusters[i], axis=0)
                final_clusters.extend(new_clusters)
        else:  # no predicted points: no need to break, keep predicted clusters
            final_clusters = []
            cluster_idx = np.unique(event_clusters[:, -1])
            for c in cluster_idx:
                final_clusters.append(
                    event_clusters[event_clusters[:, -1] == c][:, :data_dim])

        # 2) Compute cluster efficiency/purity,
        # i.e. associate the final clusters (after breaking) with true clusters
        label_cluster_ids = np.unique(event_clusters_label[:, -1])
        true_clusters = []
        for c in label_cluster_ids:
            true_clusters.append(
                event_clusters_label[event_clusters_label[:, -1] == c][:, :-2])

        # Match each predicted cluster to a true cluster
        matches = []
        overlaps = []
        for predicted_cluster in final_clusters:
            overlap = []
            for true_cluster in true_clusters:
                # Overlap = number of true pixels lying within 1 px of any
                # pixel of the predicted cluster
                overlap_pixel_count = np.count_nonzero(
                    (cdist(predicted_cluster, true_cluster) < 1).any(axis=0))
                overlap.append(overlap_pixel_count)
            overlap = np.array(overlap)
            if overlap.max() > 0:
                matches.append(overlap.argmax())
                overlaps.append(overlap.max())
            else:
                matches.append(-1)
                overlaps.append(0)

        # Compute cluster purity/efficiency
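        # For a predicted cluster P matched to a true cluster T:
        #   purity     = overlap / |P|  (fraction of P that belongs to T)
        #   efficiency = overlap / |T|  (fraction of T recovered by P)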
        purity, efficiency = [], []
        npix_predicted, npix_true = [], []
        for i, predicted_cluster in enumerate(final_clusters):
            if matches[i] > -1:
                matched_cluster = true_clusters[matches[i]]
                purity.append(overlaps[i] / predicted_cluster.shape[0])
                efficiency.append(overlaps[i] / matched_cluster.shape[0])
                npix_predicted.append(predicted_cluster.shape[0])
                npix_true.append(matched_cluster.shape[0])

        if debug:
            print("Purity: ", purity)
            print("Efficiency: ", efficiency)
            print("Match indices: ", matches)
            print("Overlaps: ", overlaps)
            print("Npix predicted: ", npix_predicted)
            print("Npix true: ", npix_true)

        # Record everything in CSV.
        # Input points with predicted/true semantic class information
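        # The 'type' column encodes what each row describes:
        #   0 = input voxel, 1 = predicted cluster pixel,
        #   2 = true cluster pixel, 5 = true PPN point,
        #   6 = predicted PPN point (types 3 and 4 are unused here)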
        for i, point in enumerate(data):
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                 'true_class', 'cluster_id', 'point_type', 'idx'),
                (0, point[0], point[1], point[2], batch_index, point[4],
                 np.argmax(event_segmentation[i]),
                 event_segmentation_label[i, -1], -1, -1, event_index))
            fout.write()
        # Predicted clusters
        for c, cluster in enumerate(final_clusters):
            for point in cluster:
                fout.record(
                    ('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                     'predicted_class', 'true_class', 'point_type', 'idx'),
                    (1, point[0], point[1], point[2], batch_index, c, -1, -1,
                     -1, -1, event_index))
                fout.write()
        # True clusters
        for c, cluster in enumerate(true_clusters):
            for point in cluster:
                fout.record(
                    ('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                     'predicted_class', 'true_class', 'point_type', 'idx'),
                    (2, point[0], point[1], point[2], batch_index, c, -1, -1,
                     -1, -1, event_index))
                fout.write()
        # True PPN points
        for point in event_particles_label:
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'point_type', 'cluster_id',
                 'value', 'predicted_class', 'true_class', 'idx'),
                (5, point[0], point[1], point[2], batch_index, point[4], -1,
                 -1, -1, -1, event_index))
            fout.write()
        # Predicted PPN points
        for point in predicted_points:
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'predicted_class', 'value',
                 'true_class', 'cluster_id', 'point_type', 'idx'),
                (6, point[0], point[1], point[2], batch_index,
                 np.argmax(point[-5:]), -1, -1, -1, -1, event_index))
            fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
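For reference, here is a minimal sketch of how a post-processing routine with this signature might be invoked. The function name `track_clustering` and the exact config keys are assumptions patterned on the `cfg['post_processing'][...]` and `store_method` lookups above, not a confirmed API.

# Hypothetical driver sketch -- names are assumed, not confirmed.
cfg = {
    'post_processing': {
        'track_clustering': {
            'store_method': 'per-iteration',  # or 'per-event'
        }
    }
}

# `data_blob` (inputs and labels) and `res` (network outputs) come from
# the reconstruction chain; given those, the call would look like:
# track_clustering(cfg, data_blob, res, logdir='./logs', iteration=0)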