Example #1
import numpy as np

# `utils` (CSVData) and `output_formatters` are project-local modules assumed
# to be importable in this snippet's original module.


def output(output_formatters_list, data_blob, res, cfg, idx):
    event_id = 0
    print(res)
    for i in range(len(data_blob['input_data'])):
        for j in range(len(data_blob['input_data'][i])):
            batch_idx = np.unique(data_blob['input_data'][i][j][:, 3])  # batch id column (same as -2 for (N, 5) data)
            for b in batch_idx:
                new_data_blob = {}
                data_index = data_blob['input_data'][i][j][:, 3] == b
                for key in data_blob:
                    if isinstance(data_blob[key][i][j], np.ndarray) and len(
                            data_blob[key][i][j].shape) == 2:
                        new_data_blob[key] = data_blob[key][i][j][
                            data_blob[key][i][j][:, 3] == b]
                    elif isinstance(data_blob[key][i][j], list):
                        new_data_blob[key] = data_blob[key][i][j][int(b)]
                # FIXME with minibatch
                new_res = {}
                if 'analysis_keys' in cfg['model']:
                    for key in cfg['model']['analysis_keys']:
                        if res[key][j].shape[0] == data_index.shape[0]:
                            new_res[key] = res[key][j][data_index]
                        else:  # assumes batch is in column 3
                            new_res[key] = res[key][j][res[key][j][:, 3] == b]

                csv_logger = utils.CSVData(
                    "%s/output-%.07d.csv" %
                    (cfg['training']['log_dir'], event_id))
                for output in output_formatters_list:
                    f = getattr(output_formatters, output)
                    f(csv_logger, new_data_blob, new_res)
                csv_logger.close()
                event_id += 1
Example #2
import numpy as np

# `utils` (CSVData) and `output_formatters` are project-local modules assumed
# to be importable in this snippet's original module.


def output(output_formatters_list, data_blob, res, cfg, idx, **kwargs):
    """
    Break down the data_blob and res dictionary into events.

    Need to account for: multi-gpu, minibatching, multiple outputs, batches.

    Input
    =====
    output_formatters_list: list of strings referring to the output formatter
        functions that will be applied to each event
    data_blob: from I/O
    res: results dictionary, output of trainval
    cfg: configuration
    idx: iteration index (to number events correctly and avoid overwriting)
    kwargs: other keyword arguments that will be passed to formatter functions
    """
    event_id = idx * cfg['iotool']['batch_size']
    num_forward = len(data_blob['input_data'])
    if len(cfg['training']['gpus']) > 0:
        assert num_forward == cfg['iotool']['batch_size'] / (
            cfg['training']['minibatch_size'] * len(cfg['training']['gpus']))
    for i in range(num_forward):
        num_gpus = len(data_blob['input_data'][i])
        for j in range(num_gpus):
            batch_idx = np.unique(data_blob['input_data'][i][j][:, 3])
            for b in batch_idx:
                new_data_blob = {}
                data_index = data_blob['input_data'][i][j][:, 3] == b
                for key in data_blob:
                    # 2D numpy array, assumes batch id is in column 3
                    if isinstance(data_blob[key][i][j], np.ndarray) and len(
                            data_blob[key][i][j].shape) == 2:
                        new_data_blob[key] = data_blob[key][i][j][
                            data_blob[key][i][j][:, 3] == b]
                    elif isinstance(data_blob[key][i][j], list):
                        new_data_blob[key] = data_blob[key][i][j][int(b)]
                # FIXME with minibatch
                new_res = {}
                if 'analysis_keys' in cfg['model']:
                    for key in cfg['model']['analysis_keys']:
                        # Flattened (forward pass, gpu) index into the results;
                        # avoid shadowing the `idx` argument of this function.
                        res_idx = i * num_gpus + j
                        if res[key][res_idx].shape[0] == data_index.shape[0]:
                            new_res[key] = res[key][res_idx][data_index]
                        else:  # FIXME assumes batch is in column 3 otherwise
                            new_res[key] = res[key][res_idx][res[key][res_idx][:, 3] == b]

                csv_logger = utils.CSVData(
                    "%s/output-%.07d.csv" %
                    (cfg['training']['log_dir'], event_id))
                for output in output_formatters_list:
                    f = getattr(output_formatters, output)
                    f(csv_logger, new_data_blob, new_res, **kwargs)
                csv_logger.close()
                event_id += 1
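
# --- Illustration (not part of the original snippet) -------------------------
# A minimal sketch of the nesting that output() above assumes:
# data_blob[key][forward_pass][gpu] is an (N, 5) array whose column 3 holds the
# batch id. The shapes and values below are made up purely for illustration.
import numpy as np

num_forward, num_gpus = 2, 1
toy_blob = {
    'input_data': [
        [np.array([[0., 0., 0., b, 1.] for b in (0, 0, 1)])]  # one GPU
        for _ in range(num_forward)                           # one entry per forward pass
    ]
}
for i in range(num_forward):
    for j in range(num_gpus):
        voxels = toy_blob['input_data'][i][j]
        for b in np.unique(voxels[:, 3]):   # batch id assumed to sit in column 3
            event = voxels[voxels[:, 3] == b]
            print(i, j, int(b), event.shape)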
Example #3
import os

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

# `utils` (CSVData) and `nms_numpy` are project-local helpers assumed to be
# importable in this snippet's original module.


def make_directories(cfg, loaded_iteration, handlers=None):
    # Weight save directory
    if cfg['training']['weight_prefix']:
        save_dir = os.path.dirname(cfg['training']['weight_prefix'])
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)

    # Log save directory
    if cfg['training']['log_dir']:
        if not os.path.exists(cfg['training']['log_dir']):
            os.mkdir(cfg['training']['log_dir'])
        logname = '%s/train_log-%07d.csv' % (cfg['training']['log_dir'], loaded_iteration)
        if not cfg['training']['train']:
            logname = '%s/inference_log-%07d.csv' % (cfg['training']['log_dir'], loaded_iteration)
        if handlers is not None:
            handlers.csv_logger = utils.CSVData(logname)
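
# --- Illustration (not part of the original snippet) -------------------------
# Hypothetical minimal configuration for make_directories() above; only the
# keys the function actually reads are shown, and the paths are made up.
example_cfg = {
    'training': {
        'weight_prefix': 'weights/snapshot',  # checkpoints saved under weights/
        'log_dir': 'logs',                    # CSV logs written here
        'train': True,                        # False -> inference_log-*.csv name
    }
}
# make_directories(example_cfg, loaded_iteration=0)  # would create weights/ and logs/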
def michel_reconstruction(data_blob, res, cfg, idx):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Assumptions
    ===========
    3D
    """
    # Create output CSV
    csv_logger = utils.CSVData("%s/michel_reconstruction-%.07d.csv" %
                               (cfg['training']['log_dir'], idx))

    model_cfg = cfg['model']

    segmentation_all = res['segmentation'][0]  # (N, 5)
    predictions_all = np.argmax(segmentation_all, axis=1)
    ghost_all = res['ghost'][0]  # (N, 2)
    data_all = data_blob['input_data'][0][0]
    label_all = data_blob['segment_label'][0][0][:, -1]
    particles_all = data_blob['particles_label'][0][0]  # (N_particles, 4+C)

    # First mask ghost points in predictions
    ghost_predictions = np.argmax(ghost_all, axis=1)
    mask = ghost_predictions == 0
    # data_all = data_all[mask]  # (M, 5)
    # label_all = label_all[mask]  # (M,)
    # predictions_all = predictions_all[mask]  # (M, )
    # segmentation_all = segmentation_all[mask]  # (M, 5)
    # particles_all = particles_all[mask]

    # Loop over events
    batch_ids = np.unique(data_all[:, 3])
    for b in batch_ids:
        batch_index = data_all[:, 3] == b
        data = data_all[batch_index]
        label = label_all[batch_index]

        data_pred = data_all[mask & batch_index]  # coords
        label_pred = label_all[mask & batch_index]  # labels
        predictions = predictions_all[mask & batch_index]
        segmentation = segmentation_all[mask & batch_index]
        particles = particles_all[particles_all[:, 3] == b]
        Michel_particles = particles[particles[:, 4] == 4]

        # 0. Retrieve coordinates of true and predicted Michels
        MIP_coords = data[(label == 1).reshape((-1, )), ...][:, :3]
        Michel_coords = data[(label == 4).reshape((-1, )), ...][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        # print("Michel in true labels")
        MIP_coords_pred = data_pred[(predictions == 1).reshape((-1, )),
                                    ...][:, :3]
        Michel_coords_pred = data_pred[(predictions == 4).reshape((-1, )),
                                       ...][:, :3]

        # 1. Find true particle information matching the true Michel cluster
        Michel_true_clusters = DBSCAN(eps=2.8284271247461903,
                                      min_samples=5).fit(Michel_coords).labels_
        Michel_start = Michel_particles[:, :3]
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue
        # print("Also predicted!")
        # 2. Compute true and predicted clusters
        MIP_clusters = DBSCAN(eps=2.8284271247461903,
                              min_samples=10).fit(MIP_coords_pred).labels_
        Michel_pred_clusters = DBSCAN(
            eps=2.8284271247461903,
            min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(
            Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))
        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            distances = cdist(Michel_coords_pred[current_index],
                              MIP_coords_pred[MIP_clusters > -1])
            is_attached = np.min(distances) < 2.8284271247461903
            is_edge = False  # default
            # print("Min distance:", np.min(distances))
            if is_attached:
                Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                       distances.shape)
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                ablated_cluster = MIP_cluster_coords[np.linalg.norm(
                    MIP_cluster_coords - MIP_min_coords, axis=1) > 15.0]
                if ablated_cluster.shape[0] > 0:
                    new_cluster = DBSCAN(
                        eps=2.8284271247461903,
                        min_samples=5).fit(ablated_cluster).labels_
                    is_edge = len(np.unique(
                        new_cluster[new_cluster > -1])) == 1
                else:
                    is_edge = True
            # print(is_attached, is_edge)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            if is_attached and is_edge and Michel_coords.shape[0] > 0:
                distances = cdist(Michel_coords_pred[current_index],
                                  Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances,
                                                                  axis=1)]
                closest_clusters_final = closest_clusters[
                    (closest_clusters > -1)
                    & (np.min(distances, axis=1) < 2.8284271247461903)]
                # print(closest_clusters, np.min(distances, axis=1))
                if len(closest_clusters_final) > 0:
                    closest_true_id = closest_clusters_final[np.bincount(
                        closest_clusters_final).argmax()]
                    overlap_pixels_index = (
                        closest_clusters == closest_true_id) & (np.min(
                            distances, axis=1) < 2.8284271247461903)
                    if closest_true_id > -1:
                        closest_true_index = label_pred[predictions ==
                                                        4][current_index] == 4
                        michel_pred_num_pix_true = np.count_nonzero(
                            closest_true_index)
                        michel_pred_sum_pix_true = data_pred[(
                            predictions == 4).reshape(
                                (-1, )), ...][current_index][(
                                    closest_true_index).reshape((-1, )),
                                                             ...][:, -1].sum()
                        michel_true_num_pix = np.count_nonzero(
                            Michel_true_clusters == closest_true_id)
                        michel_true_sum_pix = data[(label == 4).reshape(
                            (-1, )), ...][Michel_true_clusters ==
                                          closest_true_id][:, -1].sum()
                        # Register true energy
                        # Match to MC Michel
                        distances2 = cdist(
                            Michel_coords[Michel_true_clusters ==
                                          closest_true_id], Michel_start)
                        closest_mc = np.argmin(distances2, axis=1)
                        closest_mc_id = closest_mc[np.bincount(
                            closest_mc).argmax()]
                        michel_true_energy = Michel_particles[closest_mc_id, 7]
            # Record every predicted Michel cluster in CSV
            csv_logger.record(
                ('pred_num_pix', 'pred_sum_pix', 'pred_num_pix_true',
                 'pred_sum_pix_true', 'true_num_pix', 'true_sum_pix',
                 'is_attached', 'is_edge', 'michel_true_energy'),
                (np.count_nonzero(current_index),
                 data_pred[(predictions == 4).reshape(
                     (-1, )), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix, is_attached,
                 is_edge, michel_true_energy))
            csv_logger.write()
    csv_logger.close()
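
# --- Illustration (not part of the original snippet) -------------------------
# Standalone sketch of the "is the Michel candidate at the end of the track?"
# test used above: drop the MIP points within an ablation radius of the contact
# point and check that what remains forms a single DBSCAN cluster (if the
# contact sat in the middle of the track, two clusters would survive). The
# radii and the toy track below are illustrative only.
import numpy as np
from sklearn.cluster import DBSCAN

def attached_at_track_end(track_xyz, contact_xyz,
                          ablation_radius=15.0, eps=2 * np.sqrt(2)):
    kept = track_xyz[np.linalg.norm(track_xyz - contact_xyz, axis=1) > ablation_radius]
    if kept.shape[0] == 0:
        return True  # entire track ablated -> treat the contact as an end point
    labels = DBSCAN(eps=eps, min_samples=5).fit(kept).labels_
    return len(np.unique(labels[labels > -1])) == 1

toy_track = np.stack([np.arange(60.), np.zeros(60), np.zeros(60)], axis=1)
print(attached_at_track_end(toy_track, toy_track[0]))   # contact at the tip    -> True
print(attached_at_track_end(toy_track, toy_track[30]))  # contact in the middle -> False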
def track_clustering(data_blob, res, cfg, idx):
    # Create output CSV
    csv_logger = utils.CSVData("%s/track_clustering-%.07d.csv" % (cfg['training']['log_dir'], idx))

    model_cfg = cfg['model']
    clusters = res['clusters'][0]  # (N1, 6)
    points = res['points'][0]  # (N, 5)
    segmentation = res['segmentation'][0]  # (N, 5)
    # FIXME N1 >= N because some points might belong to several clusters?
    clusters_label = data_blob['clusters_label'][0][0]  # (N1, 5)
    particles_label = data_blob['particles_label'][0][0]  # (N_gt, 5)
    data = data_blob['input_data'][0][0]
    segmentation_label = data_blob['segment_label'][0][0]

    print("Predicted points: ", points)

    data_dim = 3  # model_cfg['data_dim']
    batch_ids = np.unique(data[:, data_dim])
    # print(segmentation[: 10], batch_ids)
    score_threshold = 0.6
    threshold_association = 3
    exclusion_radius = 5
    for b in batch_ids:
        event_clusters = clusters[clusters[:, data_dim] == b]
        batch_index = points[:, data_dim] == b
        event_points = points[batch_index][:, :-2]
        event_scores = points[batch_index][:, -2:]
        event_data = data[:, :data_dim][batch_index]
        event_segmentation = segmentation[batch_index]
        event_clusters_label = clusters_label[clusters_label[:, data_dim] == b]
        event_particles_label = particles_label[particles_label[:, data_dim] == b]

        anchors = (event_data + 0.5)
        event_points = event_points + anchors

        if event_points.shape[0] > 0:
            score_index = event_scores[:, 1] > score_threshold
            event_points = event_points[score_index]
            event_scores = event_scores[score_index]
        # 0) DBScan predicted pixels
        # print(event_points[:10])
        if event_points.shape[0] > 0:
            # db = DBSCAN(eps=1.0, min_samples=5).fit(event_points).labels_
            # print(np.unique(db))
            # dbscan_points = []
            # for label in np.unique(db):
            #     dbscan_points.append(event_points[db == label].mean(axis=0))
            # dbscan_points = np.stack(dbscan_points)
            # print(dbscan_points.shape)
            print("Predicted points: ", event_points.shape)
            keep = nms_numpy(event_points, event_scores, 0.1, 5)
            dbscan_points = event_points[keep]
            print("Remaining predicted points: ", dbscan_points.shape)

            # 1) Break algorithm
            print(len(event_clusters), np.unique(event_clusters[:, -1]))
            cluster_ids = np.unique(event_clusters[:, -1])
            final_clusters = []
            for c in cluster_ids:
                # print("Cluster ", c)
                cluster = event_clusters[event_clusters[:, -1] == c][:, :data_dim]
                d = cdist(dbscan_points, cluster)
                # print(d.shape)
                index = d.min(axis=1) < threshold_association
                cluster_points = dbscan_points[index]
                # print(cluster_points)
                new_d = d[index.reshape((-1,)), :]
                # print(new_d.shape)
                new_index = (new_d > exclusion_radius).all(axis=0)
                new_cluster = cluster[new_index]
                remaining_cluster = cluster[~new_index]
                # print(new_cluster.shape)
                db2 = DBSCAN(eps=1.0, min_samples=1).fit(new_cluster).labels_
                # print(db2)
                new_cluster_ids = np.unique(db2)
                new_clusters = []
                for c2 in new_cluster_ids:
                    new_clusters.append([new_cluster[db2 == c2]])
                d3 = cdist(remaining_cluster, new_cluster)
                remaining_db = db2[d3.argmin(axis=1)]
                for i, c in enumerate(remaining_cluster):
                    new_clusters[remaining_db[i]].append(c[None, :])
                for i in range(len(new_clusters)):
                    new_clusters[i] = np.concatenate(new_clusters[i], axis=0)
                # print(new_clusters)
                final_clusters.extend(new_clusters)
        else:
            final_clusters = []
            cluster_idx = np.unique(event_clusters[:, -1])
            for c in cluster_idx:
                final_clusters.append(event_clusters[event_clusters[:, -1] == c][:, :data_dim])
            # FIXME is this right?


        # 2) Compute cluster efficiency/purity
        # ie associate final clusters after breaking with true clusters
        label_cluster_ids = np.unique(event_clusters_label[:, -1])
        true_clusters = []
        for c in label_cluster_ids:
            true_clusters.append(event_clusters_label[event_clusters_label[:, -1] == c][:, :-2])

        # Match each predicted cluster to a true cluster
        matches = []
        overlaps = []
        for predicted_cluster in final_clusters:
            overlap = []
            for true_cluster in true_clusters:
                overlap_pixel_count = np.count_nonzero((cdist(predicted_cluster, true_cluster)<1).any(axis=0))
                overlap.append(overlap_pixel_count)
            overlap = np.array(overlap)
            if overlap.max() > 0:
                matches.append(overlap.argmax())
                overlaps.append(overlap.max())
            else:
                matches.append(-1)
                overlaps.append(0)

        # Compute cluster purity/efficiency
        purity, efficiency = [], []
        npix_predicted, npix_true = [], []
        for i, predicted_cluster in enumerate(final_clusters):
            if matches[i] > -1:
                matched_cluster = true_clusters[matches[i]]
                purity.append(overlaps[i] / predicted_cluster.shape[0])
                efficiency.append(overlaps[i] / matched_cluster.shape[0])
                npix_predicted.append(predicted_cluster.shape[0])
                npix_true.append(matched_cluster.shape[0])

        print("Purity: ", purity)
        print("Efficiency: ", efficiency)
        print("Match indices: ", matches)
        print("Overlaps: ", overlaps)
        print("Npix predicted: ", npix_predicted)
        print("Npix true: ", npix_true)

        # Record in CSV everything
        for i, point in enumerate(data[batch_index]):
            csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'cluster_id', 'type'),
                              (0, point[0], point[1], point[2], point[3], point[4], np.argmax(event_segmentation[i]), segmentation_label[segmentation_label[:, data_dim] == b][i, -1], -1, -1 ))
            csv_logger.write()
        for c, cluster in enumerate(final_clusters):
            for point in cluster:
                csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'type'),
                                  (1, point[0], point[1], point[2], b, c, -1, -1, -1, -1))
                csv_logger.write()
        for c, cluster in enumerate(true_clusters):
            for point in cluster:
                csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'type'),
                                  (2, point[0], point[1], point[2], b, c, -1, -1, -1, -1))
                csv_logger.write()
        for point in event_points:
            csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'type'),
                              (3, point[0], point[1], point[2], b, -1, -1, -1, -1, -1))
            csv_logger.write()
        for point in event_clusters_label:
            csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'type'),
                              (4, point[0], point[1], point[2], b, point[4], -1, -1, -1, -1))
            csv_logger.write()
        for point in event_particles_label:
            csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'type', 'cluster_id', 'value', 'predicted_class', 'true_class'),
                              (5, point[0], point[1], point[2], b, point[4], -1, -1, -1, -1))
            csv_logger.write()
        # for point in segmentation:
        #     csv_logger.record(('point_type', 'x', 'y', 'z', 'batch_id', 'class_id'),
        #                       (6, point[0], point[1], point[2], b, np.argmax()))
        #     csv_logger.write()
    csv_logger.close()
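
# --- Illustration (not part of the original snippet) -------------------------
# Toy version of the overlap bookkeeping used above: a predicted cluster is
# matched to the true cluster with the largest pixel overlap (true pixels that
# lie within one voxel of a predicted pixel); purity = overlap / |predicted|
# and efficiency = overlap / |matched true cluster|. Coordinates are made up.
import numpy as np
from scipy.spatial.distance import cdist

toy_true_clusters = [np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]]),
                     np.array([[10., 0., 0.], [11., 0., 0.]])]
toy_predicted = np.array([[0., 0., 0.], [1., 0., 0.], [5., 0., 0.]])  # one stray pixel

toy_overlaps = [np.count_nonzero((cdist(toy_predicted, t) < 1).any(axis=0))
                for t in toy_true_clusters]
best = int(np.argmax(toy_overlaps))
toy_purity = toy_overlaps[best] / toy_predicted.shape[0]                  # 2 / 3
toy_efficiency = toy_overlaps[best] / toy_true_clusters[best].shape[0]    # 2 / 3
print(best, toy_purity, toy_efficiency)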
def deghosting_metrics(cfg, data_blob, res, logdir, iteration):
    """
    Some useful metrics to measure deghosting performance

    Parameters
    ----------
    cfg: dict
        Configuration
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    logdir: str
        Directory where the output CSV file is written
    iteration: int
        Iteration number

    Input
    -----
    Requires the following analysis keys:
    - `segmentation`
    - `ghost` (only for the 5+2-types GhostNet architecture)
    Requires the following input keys:
    - `input_data`
    - `segment_label`
    Assumes no minibatching

    Output
    ------
    Writes to a CSV file `deghosting_metrics-*`
    """

    method_cfg = cfg['post_processing']['deghosting_metrics']

    csv_logger = utils.CSVData(
        os.path.join(logdir, "deghosting_metrics-iter-%.07d.csv" % iteration))
    for data_idx, tree_idx in enumerate(data_blob['index']):

        deghosting_type = method_cfg['method']
        assert (deghosting_type in ['5+2', '6', '2'])

        pcluster = None
        if 'pcluster' in data_blob:
            pcluster = data_blob['pcluster'][data_idx][:, -1]
        label = data_blob['segment_label'][data_idx][:, -1]
        segmentation = res['segmentation'][data_idx]  # (N, 5)
        predictions = np.argmax(segmentation, axis=1)

        num_classes = segmentation.shape[1]
        num_ghost_points = np.count_nonzero(label == 5)
        num_nonghost_points = np.count_nonzero(label < 5)

        csv_logger.record(('num_ghost_points', 'num_nonghost_points', 'idx'),
                          (num_ghost_points, num_nonghost_points, tree_idx))

        if deghosting_type == '5+2':
            # Accuracy for ghost prediction for 5+2
            ghost_predictions = np.argmax(res['ghost'][data_idx], axis=1)
            mask = ghost_predictions == 0
            # 0 = non-ghost, 1 = ghost
            # Fraction of all points whose ghost/non-ghost prediction is correct
            ghost_acc = ((ghost_predictions == 1)
                         == (label == 5)).sum() / float(label.shape[0])
            # Fraction of ghost points predicted as ghost points
            ghost2ghost = (ghost_predictions[label == 5]
                           == 1).sum() / float(num_ghost_points)
            # Fraction of true non-ghost points predicted as true non-ghost points
            nonghost2nonghost = (ghost_predictions[label < 5]
                                 == 0).sum() / float(num_nonghost_points)
            csv_logger.record(("ghost2ghost", "nonghost2nonghost"),
                              (ghost2ghost, nonghost2nonghost))

            # Accuracy for 5 types, global
            uresnet_acc = (label[label < 5]
                           == predictions[label < 5]).sum() / float(
                               np.count_nonzero(label < 5))
            csv_logger.record(('ghost_acc', 'uresnet_acc'),
                              (ghost_acc, uresnet_acc))
            # Class-wise nonzero accuracy for 5 types, based on true mask
            acc, num_true_pix, num_pred_pix = [], [], []
            num_pred_pix_true = []
            num_true_deghost_pix, num_original_pix = [], []
            ghost_false_positives, ghost_true_positives = [], []
            for c in range(num_classes):
                class_mask = label == c
                class_predictions = predictions[class_mask]
                # Fraction of pixels in this class predicted correctly
                acc.append((class_predictions == c).sum() /
                           float(class_predictions.shape[0]))
                # Pixel counts
                # Pixels in sparse3d_semantics_reco
                num_true_pix.append(np.count_nonzero(class_mask))
                # Pixels in sparse3d_semantics_reco predicted as nonghost
                num_true_deghost_pix.append(np.count_nonzero(class_mask
                                                             & mask))
                # Pixels in original pcluster
                if pcluster is not None:
                    num_original_pix.append(np.count_nonzero(pcluster == c))
                # Pixels in predictions + nonghost
                num_pred_pix.append(np.count_nonzero(predictions[mask] == c))
                # Pixels in predictions + nonghost that are correctly classified
                num_pred_pix_true.append(
                    np.count_nonzero(class_predictions == c))
                # Number of pixels in this class (wrongly) predicted as ghost
                ghost_false_positives.append(
                    np.count_nonzero(ghost_predictions[class_mask] == 1))
                # Number of pixels in this class (correctly) predicted as non-ghost
                ghost_true_positives.append(
                    np.count_nonzero(ghost_predictions[class_mask] == 0))
                # confusion matrix
                # pixels predicted as nonghost + should be in class c, but predicted as c2
                for c2 in range(num_classes):
                    csv_logger.record(
                        ('confusion_%d_%d' % (c, c2), ),
                        (((class_predictions == c2) &
                          (ghost_predictions[class_mask] == 0)).sum(), ))
            csv_logger.record(['acc_class%d' % c for c in range(num_classes)],
                              acc)
            csv_logger.record(
                ['num_true_pix_class%d' % c for c in range(num_classes)],
                num_true_pix)
            csv_logger.record([
                'num_true_deghost_pix_class%d' % c for c in range(num_classes)
            ], num_true_deghost_pix)
            if pcluster is not None:
                csv_logger.record([
                    'num_original_pix_class%d' % c for c in range(num_classes)
                ], num_original_pix)
            csv_logger.record(
                ['num_pred_pix_class%d' % c for c in range(num_classes)],
                num_pred_pix)
            csv_logger.record(
                ['num_pred_pix_true_class%d' % c for c in range(num_classes)],
                num_pred_pix_true)
            csv_logger.record([
                'ghost_false_positives_class%d' % c for c in range(num_classes)
            ], ghost_false_positives)
            csv_logger.record([
                'ghost_true_positives_class%d' % c for c in range(num_classes)
            ], ghost_true_positives)

        elif deghosting_type == '6':
            ghost2ghost = (predictions[label == 5]
                           == 5).sum() / float(num_ghost_points)
            nonghost2nonghost = (predictions[label < 5] <
                                 5).sum() / float(num_nonghost_points)
            csv_logger.record(("ghost2ghost", "nonghost2nonghost"),
                              (ghost2ghost, nonghost2nonghost))
            # 6 types confusion matrix
            for c in range(num_classes):
                for c2 in range(num_classes):
                    # Fraction of points of class c, predicted as c2
                    x = (predictions[label == c] == c2).sum() / float(
                        np.count_nonzero(label == c))
                    csv_logger.record(('confusion_%d_%d' % (c, c2), ), (x, ))
        elif deghosting_type == '2':
            ghost2ghost = (predictions[label == 5]
                           == 1).sum() / float(num_ghost_points)
            nonghost2nonghost = (predictions[label < 5]
                                 == 0).sum() / float(num_nonghost_points)
            csv_logger.record(("ghost2ghost", "nonghost2nonghost"),
                              (ghost2ghost, nonghost2nonghost))
        else:
            print('Invalid "deghosting_type" config parameter value:',
                  deghosting_type)
            raise ValueError
        csv_logger.write()
    csv_logger.close()
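
# --- Illustration (not part of the original snippet) -------------------------
# Tiny worked example of the ghost2ghost / nonghost2nonghost fractions computed
# above, under the same label convention (classes 0-4 are real, 5 is ghost) and
# a binary ghost prediction (0 = non-ghost, 1 = ghost). The values are made up.
import numpy as np

toy_label = np.array([0, 1, 4, 5, 5, 5])       # true semantic labels
toy_ghost_pred = np.array([0, 0, 1, 1, 1, 0])  # network's ghost decision

toy_g2g = (toy_ghost_pred[toy_label == 5] == 1).sum() / float(np.count_nonzero(toy_label == 5))
toy_ng2ng = (toy_ghost_pred[toy_label < 5] == 0).sum() / float(np.count_nonzero(toy_label < 5))
print(toy_g2g, toy_ng2ng)  # 2/3 of ghost points tagged as ghost, 2/3 of real points kept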
def deghosting_metrics(data_blob, res, cfg, idx):
    """
    Some useful metrics to measure deghosting performance
    """
    csv_logger = utils.CSVData("%s/deghosting_metrics-%.07d.csv" %
                               (cfg['training']['log_dir'], idx))

    model_cfg = cfg['model']

    segmentation_all = res['segmentation'][0]  # (N, 5)
    predictions_all = np.argmax(segmentation_all, axis=1)
    ghost_all = res['ghost'][0]  # (N, 2)
    data_all = data_blob['input_data'][0][0]
    label_all = data_blob['segment_label'][0][0][:, -1]

    # First mask ghost points in predictions
    ghost_predictions = np.argmax(ghost_all, axis=1)
    mask = ghost_predictions == 0

    batch_ids = np.unique(data_all[:, 3])
    num_classes = segmentation_all.shape[1]
    for b in batch_ids:
        batch_index = data_all[:, 3] == b
        label = label_all[batch_index]

        # Accuracy for ghost prediction
        ghost_acc = ((ghost_predictions[batch_index] == 1)
                     == (label == 5)).sum() / float(label.shape[0])

        # Accuracy for 5 types, global
        uresnet_acc = (label[label < 5] == predictions_all[batch_index][
            label < 5]).sum() / float(np.count_nonzero(label < 5))

        # Class-wise nonzero accuracy for 5 types, based on true mask
        acc, num_true_pix, num_pred_pix = [], [], []
        num_pred_pix_true = []
        ghost_false_positives, ghost_true_positives = [], []
        for c in range(num_classes):
            class_mask = label == c
            class_predictions = predictions_all[batch_index][mask[batch_index]
                                                             & class_mask]
            acc.append((class_predictions == c).sum() /
                       float(class_predictions.shape[0]))
            num_true_pix.append(np.count_nonzero(class_mask))
            num_pred_pix.append(
                np.count_nonzero(predictions_all[batch_index & mask] == c))
            num_pred_pix_true.append(np.count_nonzero(class_predictions == c))
            ghost_false_positives.append(
                np.count_nonzero(
                    ghost_predictions[batch_index][class_mask] == 1))
            ghost_true_positives.append(
                np.count_nonzero(
                    ghost_predictions[batch_index][class_mask] == 0))
        csv_logger.record(['acc_class%d' % c for c in range(num_classes)], acc)
        csv_logger.record(
            ['num_true_pix_class%d' % c for c in range(num_classes)],
            num_true_pix)
        csv_logger.record(
            ['num_pred_pix_class%d' % c for c in range(num_classes)],
            num_pred_pix)
        csv_logger.record(
            ['num_pred_pix_true_class%d' % c for c in range(num_classes)],
            num_pred_pix_true)
        csv_logger.record(
            ['ghost_false_positives_class%d' % c for c in range(num_classes)],
            ghost_false_positives)
        csv_logger.record(
            ['ghost_true_positives_class%d' % c for c in range(num_classes)],
            ghost_true_positives)
        csv_logger.record(('ghost_acc', 'uresnet_acc'),
                          (ghost_acc, uresnet_acc))
        csv_logger.write()
    csv_logger.close()