Code Example #1
    def __init__(self, io_cfg, chain_cfg, verbose=False):
        '''
        Initializes the chain from the configuration file
        '''
        # Initialize the data loader
        io_cfg = yaml.load(io_cfg, Loader=yaml.Loader)

        # Save config, initialize output
        self.cfg = chain_cfg
        self.verbose = verbose
        self.output = {}

        # Initialize log
        log_path = chain_cfg['name'] + '_log.csv'
        print('Initialized Pi0 mass chain, log path:', log_path)
        self._log = CSVData(log_path)
        self._keys = ['event_id', 'pion_id', 'pion_mass']

        # If a network is specified, initialize the network
        self.network = False
        if chain_cfg['segment'] == 'uresnet' or chain_cfg[
                'shower_start'] == 'ppn':
            self.network = True
            with open(chain_cfg['net_cfg']) as cfg_file:
                net_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
            io_cfg['model'] = net_cfg['model']
            io_cfg['trainval'] = net_cfg['trainval']

        # Pre-process configuration
        process_config(io_cfg)

        # Instantiate "handlers" (IO tools)
        self.hs = prepare(io_cfg)
        self.data_set = iter(self.hs.data_io)
Code Example #2
    def __init__(self, io_cfg, chain_cfg, verbose=False):
        '''
        Initializes the chain from the configuration file
        '''
        # Initialize the data loader
        io_cfg = yaml.load(io_cfg, Loader=yaml.Loader)

        # Save config, initialize output
        self.cfg = chain_cfg
        self.verbose = verbose
        self.event = None
        self.output = {}

        # Initialize log
        log_path = chain_cfg['name'] + '_log.csv'
        print('Initialized Pi0 mass chain, log path:', log_path)
        self._log = CSVData(log_path)
        self._keys = ['event_id', 'pion_id', 'pion_mass']

        # If a network is specified, initialize the network
        self.network = False
        if chain_cfg['segment'] == 'uresnet' or chain_cfg[
                'shower_start'] == 'ppn':
            self.network = True
            with open(chain_cfg['net_cfg']) as cfg_file:
                net_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
            io_cfg['model'] = net_cfg['model']
            io_cfg['trainval'] = net_cfg['trainval']

        # Initialize the fragment identifier
        self.frag_est = FragmentEstimator()

        # If a direction estimator is requested, initialize it
        if chain_cfg['shower_dir'] != 'truth':
            self.dir_est = DirectionEstimator()

        # If a clusterer is requested, initialize it
        if chain_cfg['shower_energy'] == 'cone':
            self.clusterer = ConeClusterer()

        # If a pi0 identifier is requested, initialize it
        if chain_cfg['shower_match'] == 'proximity':
            self.matcher = Pi0Matcher()

        # Pre-process configuration
        process_config(io_cfg)

        # Instantiate "handlers" (IO tools)
        self.hs = prepare(io_cfg)
        self.data_set = iter(self.hs.data_io)
Code Example #3
class Pi0Chain():
    def __init__(self, io_cfg, chain_cfg, verbose=False):
        '''
        Initializes the chain from the configuration file
        '''
        # Initialize the data loader
        io_cfg = yaml.load(io_cfg, Loader=yaml.Loader)

        # Save config, initialize output
        self.cfg = chain_cfg
        self.verbose = verbose
        self.output = {}

        # Initialize log
        log_path = chain_cfg['name'] + '_log.csv'
        print('Initialized Pi0 mass chain, log path:', log_path)
        self._log = CSVData(log_path)
        self._keys = ['event_id', 'pion_id', 'pion_mass']

        # If a network is specified, initialize the network
        self.network = False
        if chain_cfg['segment'] == 'uresnet' or chain_cfg[
                'shower_start'] == 'ppn':
            self.network = True
            with open(chain_cfg['net_cfg']) as cfg_file:
                net_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
            io_cfg['model'] = net_cfg['model']
            io_cfg['trainval'] = net_cfg['trainval']

        # Pre-process configuration
        process_config(io_cfg)

        # Instantiate "handlers" (IO tools)
        self.hs = prepare(io_cfg)
        self.data_set = iter(self.hs.data_io)

    def get_hs(self):
        # Accessor renamed: __init__ assigns self.hs and self.data_set as
        # attributes, which would otherwise shadow methods of the same name.
        return self.hs

    def get_data_set(self):
        return self.data_set

    def log(self, eid, pion_id, pion_mass):
        self._log.record(self._keys, [eid, pion_id, pion_mass])
        self._log.write()
        self._log.flush()

    def run(self):
        '''
        Runs the full Pi0 reconstruction chain, from 3D charge
        information to Pi0 masses for events that contain one
        or more Pi0 decays.
        '''
        for i in range(len(self.hs.data_io)):
            self.run_loop()

    def run_loop(self):
        '''
        Runs the full Pi0 reconstruction chain on a single event,
        from 3D charge information to Pi0 masses for events that
        contain one or more Pi0 decays.
        '''
        # Reset output
        self.output = {}

        # Load data
        if not self.network:
            event = next(self.data_set)
            event_id = event['index'][0]
        else:
            event, self.output['forward'] = self.hs.trainer.forward(
                self.data_set)
            for key in event.keys():
                if key != 'particles':
                    event[key] = event[key][0]
            event_id = event['index']

        # Filter out ghosts
        self.filter_ghosts(event)

        # Reconstruct energy
        self.reconstruct_energy(event)

        # Identify shower starting points
        self.find_shower_starts(event)
        if not len(self.output['showers']):
            if self.verbose:
                print('No shower start point found in event', event_id)
            return []

        # Reconstruct shower direction vectors
        self.reconstruct_shower_directions(event)

        # Reconstruct shower energy
        self.reconstruct_shower_energy(event)

        # Identify pi0 decays
        self.identify_pi0(event)
        if not len(self.output['matches']):
            if self.verbose:
                print('No pi0 found in event', event_id)
            return []

        # Compute masses
        masses = self.pi0_mass()

        # Log masses
        for i, m in enumerate(masses):
            self.log(event_id, i, m)

    def filter_ghosts(self, event):
        '''
        Removes ghost points from the charge tensor
        '''
        if self.cfg['input'] == 'energy':
            self.output['segment'] = event['segment_label_true']
            self.output['group'] = event['group_label_true']
            self.output['dbscan'] = event['dbscan_label_true']

        elif self.cfg['segment'] == 'mask':
            self.output['segment'] = event['segment_label_reco']

            mask = np.where(self.output['segment'][:, 4] != 5)
            self.output['charge'] = event['charge'][mask]
            self.output['group'] = event['group_label_reco']
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        elif self.cfg['segment'] == 'uresnet':
            # Get the segmentation output of the network
            res = self.output['forward']['segmentation'][0]

            # Argmax to determine most probable label
            pred_labels = np.argmax(res, axis=1)
            mask = np.where(pred_labels != 5)
            self.output['charge'] = event['charge'][mask]
            self.output['segment'] = copy(event['segment_label_reco'])
            self.output['segment'][:, 4] = pred_labels
            self.output['group'] = event['group_label_reco']
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        else:
            raise ValueError('Semantic segmentation method not recognized:',
                             self.cfg['segment'])

    def reconstruct_energy(self, event):
        '''
        Reconstructs energy deposition from charge
        '''
        if self.cfg['input'] == 'energy':
            self.output['energy'] = event['energy']

        elif self.cfg['response'] == 'constant':
            reco = self.cfg['response_cst'] * event['charge'][:, 4]
            self.output['energy'] = copy(event['charge'])
            self.output['energy'][:, 4] = reco

        elif self.cfg['response'] == 'full':
            raise NotImplementedError(
                'Proper energy reconstruction not implemented yet')

        elif self.cfg['response'] == 'enet':
            raise NotImplementedError('ENet not implemented yet')

        else:
            raise ValueError('Energy reconstruction method not recognized:',
                             self.cfg['response'])

    def find_shower_starts(self, event):
        '''
        Identify starting points of showers
        '''
        if self.cfg['shower_start'] == 'truth':
            # Get the true shower starting points from the particle information
            self.output['showers'] = []
            for i, part in enumerate(event['particles'][0]):
                if self.is_shower(part):
                    new_shower = Shower(start=[
                        part.first_step().x(),
                        part.first_step().y(),
                        part.first_step().z()
                    ],
                                        pid=i)
                    self.output['showers'].append(new_shower)

        elif self.cfg['shower_start'] == 'ppn':
            raise NotImplementedError('PPN not implemented yet')

        else:
            raise ValueError(
                'EM shower primary identification method not recognized:',
                self.cfg['shower_start'])

    def reconstruct_shower_directions(self, event):
        '''
        Reconstructs the direction of the showers
        '''
        if self.cfg['shower_dir'] == 'truth':
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                mom = [part.px(), part.py(), part.pz()]
                shower.direction = list(np.array(mom) / np.linalg.norm(mom))

        elif self.cfg['shower_dir'] == 'pca':
            # Apply DBSCAN, PCA on the touching cluster to get angles
            points = np.array([
                s.start + [0., s.pid] + [0., 0., 0.]
                for s in self.output['showers']
            ])
            res, _, _ = gamma_direction.do_calculation(self.output['segment'],
                                                       points)
            for i, shower in enumerate(self.output['showers']):
                if np.linalg.norm(res[i][-3:]) == 0.:
                    shower.direction = [0., 0., 0.]
                    continue
                shower.direction = list(res[i][-3:] /
                                        np.linalg.norm(res[i][-3:]))

        else:
            raise ValueError(
                'Shower direction reconstruction method not recognized:',
                self.cfg['shower_dir'])

    def reconstruct_shower_energy(self, event):
        '''
        Clusters the different showers, reconstruct energy of each shower
        '''
        if self.cfg['shower_energy'] == 'truth':
            # Gets the true energy information from Geant4
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                shower.energy = part.energy_init()
                pid = shower.pid
                mask = np.where(self.output['group'][:, -1] == pid)[0]
                shower.voxels = mask

        elif self.cfg['shower_energy'] == 'group':
            # Gets all the voxels in the group corresponding to the pid, adds up energy
            for shower in self.output['showers']:
                pid = shower.pid
                mask = np.where(self.output['group'][:, -1] == pid)[0]
                shower.voxels = mask
                shower.energy = np.sum(self.output['energy'][mask][:, -1])

        elif self.cfg['shower_energy'] == 'cone':
            # Fits cones to each shower, adds energies within that cone
            points = np.array(
                [s.start + [0., s.pid] for s in self.output['showers']])
            res = cone_clusterer.find_shower_cone(
                self.output['dbscan'], self.output['group'], points,
                self.output['energy'], self.output['segment'])[
                    0]  # This returns one array of voxel ids per primary
            for i, shower in enumerate(self.output['showers']):
                if not len(res[i]):
                    shower.energy = 0.
                    continue
                shower.voxels = res[i]
                shower.energy = np.sum(self.output['energy'][res[i]][:, 4])

        else:
            raise ValueError(
                'Shower energy reconstruction method not recognized:',
                self.cfg['shower_energy'])

    def identify_pi0(self, event):
        '''
        Proposes pi0 candidates (match two showers)
        '''
        self.output['matches'] = []
        self.output['vertices'] = []
        n_showers = len(self.output['showers'])
        if self.cfg['shower_match'] == 'truth':
            # Get the creation point of each particle. If two gammas originate from the same point,
            # It is most likely a pi0 decay.
            creations = []
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                creations.append([
                    part.position().x(),
                    part.position().y(),
                    part.position().z()
                ])

            for i, ci in enumerate(creations):
                for j in range(i + 1, n_showers):
                    if (np.array(ci) == np.array(creations[j])).all():
                        self.output['matches'].append([i, j])
                        self.output['vertices'].append(ci)

            return self.output['matches']

        elif self.cfg['shower_match'] == 'proximity':
            # Pair closest shower vectors
            points = np.array([
                s.start + [0, s.pid] + s.direction
                for s in self.output['showers']
            ])
            event['segment_label'] = self.output['segment']
            event['group_label'] = self.output['group']
            res, vertices = pi0_pi_selection.generate_pair_labels(
                event, points, predict=False)
            for i, v in enumerate(vertices):
                self.output['matches'].append([0, 1])  # TODO, must ask DHK
                self.output['vertices'].append(v)

        else:
            raise ValueError('Shower matching method not recognized:',
                             self.cfg['shower_match'])

    def pi0_mass(self):
        '''
        Reconstructs the pi0 mass
        '''
        from math import sqrt
        masses = []
        for match in self.output['matches']:
            s1, s2 = self.output['showers'][match[0]], self.output['showers'][
                match[1]]
            e1, e2 = s1.energy, s2.energy
            t1, t2 = s1.direction, s2.direction
            costheta = np.dot(t1, t2)
            if abs(costheta) > 1.:
                masses.append(0.)
                continue
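            # Two-photon invariant mass for massless daughters:
            # m^2 = 2 * E1 * E2 * (1 - cos(theta)), with theta the opening
            # angle between the two reconstructed shower directions.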
            masses.append(sqrt(2 * e1 * e2 * (1 - costheta)))
        return masses

    def draw(self):
        from mlreco.visualization import plotly_layout3d
        from mlreco.visualization.voxels import scatter_voxels, scatter_label
        import plotly.plotly as py
        import plotly.graph_objs as go
        from plotly.offline import init_notebook_mode, iplot
        init_notebook_mode(connected=False)

        # Create labels for the voxels
        # Use a different color for each cluster
        labels = np.full(len(self.output['energy'][:, 4]), -1)
        for i, s in enumerate(self.output['showers']):
            labels[s.voxels] = i

        # Draw voxels with cluster labels
        voxels = self.output['energy'][:, :3]
        graph_voxels = scatter_label(voxels, labels, 2)[0]
        graph_voxels.name = 'Shower ID'
        graph_data = [graph_voxels]

        if len(self.output['showers']):
            # Add EM primary points
            points = np.array([s.start for s in self.output['showers']])
            graph_start = scatter_voxels(points)[0]
            graph_start.name = 'Shower starts'
            graph_data.append(graph_start)

            # Add a vertex if matches, join vertex to start points
            for i, m in enumerate(self.output['matches']):
                v = self.output['vertices'][i]
                s1, s2 = self.output['showers'][
                    m[0]].start, self.output['showers'][m[1]].start
                points = [v, s1, v, s2]
                line = scatter_voxels(np.array(points))[0]
                line.name = 'Pi0 Decay'
                line.mode = 'lines+markers'
                graph_data.append(line)

        # Draw
        iplot(go.Figure(data=graph_data, layout=plotly_layout3d()))

    @staticmethod
    def is_shower(particle):
        '''
        Check if the particle is a shower
        '''
        pdg_code = abs(particle.pdg_code())
        if pdg_code == 22 or pdg_code == 11:
            return True
        return False
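The Pi0Chain class above is configured entirely through the two dictionaries passed to its constructor. Below is a minimal, hypothetical driver sketch showing how it might be instantiated and run; the chain_cfg keys are the ones read by the code above, but the io_cfg YAML string and all concrete values are illustrative assumptions, not a verified configuration.

# Hypothetical usage sketch -- io_cfg must follow the iotools schema expected
# by process_config()/prepare(); the string below is a placeholder assumption.
io_cfg = '''
iotool:
  batch_size: 1
  dataset:
    name: LArCVDataset
    data_keys: [input.root]
'''

chain_cfg = {
    'name': 'pi0_mass_demo',   # log goes to pi0_mass_demo_log.csv
    'input': 'energy',         # use true energy depositions directly
    'segment': 'mask',         # no network forward pass needed
    'shower_start': 'truth',
    'shower_dir': 'truth',
    'shower_energy': 'group',
    'shower_match': 'truth',
}

chain = Pi0Chain(io_cfg, chain_cfg, verbose=True)
chain.run()  # calls run_loop() once per event, logging one mass per pi0 candidate
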
Code Example #4
def instance_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Simple DBSCAN on uresnet clustering output for instance segmentation

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output directory
    iteration: int
        Iteration number

    Input
    -----
    Requires the following analysis keys
    - `segmentation`: output of UResNet segmentation (scores)
    - `clustering`: coordinates in hyperspace, also output of the network
    Requires the following data blob keys
    - `input_data`
    - `segment_label` UResNet 5 classes label
    - `cluster_label`

    Output
    ------
    Writes 2 CSV files:
    - `instance-clustering-*` with the clustering predictions (point type 0 =
    event data, point type 1 = predictions, point type 2 = t-SNE visualizations)
    - `instance-clustering-metrics-*` with event-wise metrics such as AMI and ARI.
    """

    method_cfg = cfg['post_processing']['instance_clustering']
    model_cfg  = cfg['model']['modules']['uresnet_clustering']

    tsne = TSNE(n_components=2 if method_cfg is None else method_cfg.get('tsne_dim', 2))
    compute_tsne = False if method_cfg is None else method_cfg.get('compute_tsne', False)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',None) is not None:
        assert(method_cfg['store_method'] in ['per-iteration','per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout_cluster, fout_metrics = None, None
    if store_per_iteration:
        fout_cluster=CSVData(os.path.join(logdir, 'instance-clustering-iter-%07d.csv' % iteration))
        fout_metrics=CSVData(os.path.join(logdir, 'instance-clustering-metrics-iter-%07d.csv' % iteration))

    model_cfg = cfg['model']['modules']['uresnet_clustering']
    data_dim = model_cfg.get('data_dim', 3)
    depth = model_cfg.get('num_strides', 5)
    num_classes = model_cfg.get('num_classes', 5)

    # Loop over batch index
    for batch_index, event_data in enumerate(data_blob['input_data']):
        
        event_index = data_blob['index'][batch_index]
        
        if not store_per_iteration:
            fout_cluster = CSVData(os.path.join(logdir, 'instance-clustering-event-%07d.csv' % event_index))
            fout_metrics = CSVData(os.path.join(logdir, 'instance-clustering-metrics-event-%07d.csv' % event_index))
            
        event_segmentation = res['segmentation'][batch_index]
        event_label = data_blob['segment_label'][batch_index]
        event_cluster_label = data_blob['cluster_label'][batch_index]
        max_depth = len(event_cluster_label)
        for d, event_feature_map in enumerate(res['cluster_feature'][batch_index]):
            coords = event_feature_map[:, :data_dim]
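            # np.lexsort uses its last key as the primary key, so this orders
            # voxels by (x, y, z), presumably to match the label ordering.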
            perm = np.lexsort((coords[:, 2], coords[:, 1], coords[:, 0]))
            coords = coords[perm]
            class_label = event_label[-(d+1+max_depth-depth)]
            cluster_count = 0
            for class_ in range(num_classes):
                class_index = class_label[:, -1] == class_
                if np.count_nonzero(class_index) == 0:
                    continue
                clusters_label = event_cluster_label[-(d+1+max_depth-depth)][class_index]
                embedding = event_feature_map[perm][class_index]
                # DBSCAN in high dimension embedding
                predicted_clusters = DBSCAN(eps=20, min_samples=1).fit(embedding).labels_
                predicted_clusters += cluster_count  # To avoid overlapping id
                cluster_count += len(np.unique(predicted_clusters))

                # Cluster similarity metrics
                ARI = metrics.adjusted_rand_score(clusters_label[:, -1], predicted_clusters)
                AMI = metrics.adjusted_mutual_info_score(clusters_label[:, -1], predicted_clusters)
                fout_metrics.record(('class', 'batch_id', 'AMI', 'ARI', 'idx'),
                                    (class_, batch_index, AMI, ARI, event_index))
                fout_metrics.write()

                for i, point in enumerate(clusters_label):
                    fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                        (1, point[0], point[1], point[2], batch_index, d, -1, class_label[class_index][i, -1], clusters_label[i, -1], predicted_clusters[i], event_index))
                    fout_cluster.write()
                # TSNE to visualize embedding
                if compute_tsne and embedding.shape[0] > 1:
                    new_embedding = tsne.fit_transform(embedding)
                    for i, point in enumerate(new_embedding):
                        fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                            (2, point[0], point[1], -1, batch_index, d, -1, class_label[class_index][i, -1], clusters_label[i, -1], predicted_clusters[i], event_index))
                        fout_cluster.write()

        # Record in CSV everything
        perm = np.lexsort((event_data[:, 2], event_data[:, 1], event_data[:, 0]))
        event_data = event_data[perm]
        event_segmentation = event_segmentation[perm]
        # Point in data and semantic class predictions/true information
        for i, point in enumerate(event_data):
            fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class', 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                (0, point[0], point[1], point[2], batch_index, point[4], np.argmax(event_segmentation[i]), event_label[0][:,-1][i], -1, -1, event_index))
            fout_cluster.write()

        if not store_per_iteration:
            fout_cluster.close()
            fout_metrics.close()
    if store_per_iteration:
        fout_cluster.close()
        fout_metrics.close()
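At its core, the loop above clusters the network's high-dimensional embedding with DBSCAN, one semantic class at a time, and scores the result against the true cluster ids with ARI and AMI. The stripped-down sketch below isolates just that step using the same scikit-learn calls; the function name and the random test data are illustrative, not part of the original code.

import numpy as np
from sklearn.cluster import DBSCAN
from sklearn import metrics

def score_embedding(embedding, true_cluster_ids, eps=20, min_samples=1):
    """Cluster an (N, D) embedding with DBSCAN and compare to true ids."""
    predicted = DBSCAN(eps=eps, min_samples=min_samples).fit(embedding).labels_
    ari = metrics.adjusted_rand_score(true_cluster_ids, predicted)
    ami = metrics.adjusted_mutual_info_score(true_cluster_ids, predicted)
    return predicted, ari, ami

# Purely illustrative check with two well-separated blobs
rng = np.random.default_rng(0)
emb = np.concatenate([rng.normal(0, 1, (50, 8)), rng.normal(100, 1, (50, 8))])
truth = np.concatenate([np.zeros(50, dtype=int), np.ones(50, dtype=int)])
_, ari, ami = score_embedding(emb, truth)
print('ARI', ari, 'AMI', ami)
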
Code Example #5
def michel_reconstruction_2d(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output directory
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 2D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet
    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel-reconstruction-reco-*`
    - `michel-reconstruction-true-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction_2d']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):

        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        data = data_blob['input_data'][batch_id]
        label = data_blob['segment_label'][batch_id][:, -1]
        meta = data_blob['meta'][batch_id]

        # clusters    = data_blob['clusters_label' ][batch_id]
        # particles   = data_blob['particles_label'][batch_id]

        # Michel_particles = particles[particles[:, 2] == 4]  # FIXME 3 or 4 in 2D? Also not sure if type is registered for Michel

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)
        Michel_label = 3
        MIP_label = 0

        data_dim = 2
        # 0. Retrieve coordinates of true and predicted Michels
        MIP_coords = data[(label == MIP_label).reshape((-1, )),
                          ...][:, :data_dim]
        Michel_coords = data[(label == Michel_label).reshape((-1, )),
                             ...][:, :data_dim]
        # MIP_coords = clusters[clusters[:, -1] == 1][:, :3]
        # Michel_coords = clusters[clusters[:, -1] == 4][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data[(predictions == MIP_label).reshape((-1, )),
                               ...][:, :data_dim]
        Michel_coords_pred = data[(predictions == Michel_label).reshape(
            (-1, )), ...][:, :data_dim]

        # DBSCAN epsilon used for many things... TODO list here
        #one_pixel = 15#2.8284271247461903
        one_pixel_dbscan = 5
        one_pixel_is_attached = 2
        # 1. Find true particle information matching the true Michel cluster
        Michel_true_clusters = DBSCAN(eps=one_pixel_dbscan,
                                      min_samples=5).fit(Michel_coords).labels_
        MIP_true_clusters = DBSCAN(eps=one_pixel_dbscan,
                                   min_samples=5).fit(MIP_coords).labels_

        # compute all edges of true MIP clusters
        MIP_edges = []
        for cluster in np.unique(MIP_true_clusters[MIP_true_clusters > -1]):
            touching_idx = find_edges(MIP_coords[MIP_true_clusters == cluster])
            MIP_edges.append(
                MIP_coords[MIP_true_clusters == cluster][touching_idx[0]])
            MIP_edges.append(
                MIP_coords[MIP_true_clusters == cluster][touching_idx[1]])
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        # Michel_true_clusters = clusters[clusters[:, -1] == 4][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :data_dim]

        true_Michel_is_attached = {}
        true_Michel_is_edge = {}
        true_Michel_is_too_close = {}
        for cluster in np.unique(Michel_true_clusters):
            min_y = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 1].min()  # * meta[-1] + meta[1]
            max_y = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 1].max()  # * meta[-1] + meta[1]
            min_x = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 0].min()  # * meta[-2] + meta[0]
            max_x = Michel_coords[Michel_true_clusters ==
                                  cluster][:, 0].max()  # * meta[-2] + meta[0]

            # Find coordinates of Michel pixel touching MIP edge
            Michel_edges_idx = find_edges(
                Michel_coords[Michel_true_clusters == cluster])
            distances = cdist(
                Michel_coords[Michel_true_clusters == cluster]
                [Michel_edges_idx], MIP_coords[MIP_true_clusters > -1])

            # Make sure true Michel is attached at edge of MIP
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                   distances.shape)
            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_true_clusters[MIP_true_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords[MIP_true_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords[MIP_true_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords,
                                     MIP_min_coords,
                                     one_pixel=one_pixel_dbscan,
                                     radius=15.0)
            true_Michel_is_attached[cluster] = is_attached
            true_Michel_is_edge[cluster] = is_edge
            true_Michel_is_too_close[cluster] = is_too_close

            # these are the coordinates of Michel edges
            edge1_x = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[0], 0]
            edge1_y = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[0], 1]
            edge2_x = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[1], 0]
            edge2_y = Michel_coords[Michel_true_clusters == cluster][
                Michel_edges_idx[1], 1]

            # Find for each Michel edge the closest MIP pixels
            # Check for each of these whether they are at the edge of MIP
            # FIXME what happens if both are at the edge of a MIP?? unlikely
            closest_MIP_pixels = np.argmin(distances, axis=1)
            clusters_idx = MIP_true_clusters[closest_MIP_pixels]
            edge0 = is_at_edge(
                MIP_coords[MIP_true_clusters == clusters_idx[0]],
                MIP_coords[closest_MIP_pixels[0]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            edge1 = is_at_edge(
                MIP_coords[MIP_true_clusters == clusters_idx[1]],
                MIP_coords[closest_MIP_pixels[1]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
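            # Prefer the Michel edge whose closest MIP pixel is itself at a
            # MIP edge; if both or neither qualify, fall back to the edge
            # with the smaller Michel-to-MIP distance.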
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[
                        1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 1]
            #
            # if touching_x not in [edge1_x, edge2_x] or touching_y not in [edge1_y, edge2_y]:
            #     print('true', event_idx, touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)
            #if event_idx == 127:
            #    print('true', touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)
            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix',
                 'min_y', 'max_y', 'min_x', 'max_x', 'pixel_width',
                 'pixel_height', 'meta_min_x', 'meta_min_y', 'touching_x',
                 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x', 'edge2_y',
                 'edge0', 'edge1', 'is_attached', 'is_edge', 'is_too_close',
                 'cluster_id'),
                (
                    batch_id,
                    iteration,
                    event_idx,
                    np.count_nonzero(Michel_true_clusters == cluster),
                    data[(label == Michel_label).reshape((-1, )),
                         ...][Michel_true_clusters == cluster][:, -1].sum(),
                    # clusters[clusters[:, -1] == 4][Michel_true_clusters == cluster][:, -3].sum()
                    min_y,
                    max_y,
                    min_x,
                    max_x,
                    meta[-2],
                    meta[-1],
                    meta[0],
                    meta[1],
                    touching_x,
                    touching_y,
                    edge1_x,
                    edge1_y,
                    edge2_x,
                    edge2_y,
                    edge0,
                    edge1,
                    is_attached,
                    is_edge,
                    is_too_close,
                    cluster))
            fout_true.write()
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue

        #
        # 2. Compute true and predicted clusters
        #
        MIP_clusters = DBSCAN(eps=one_pixel_dbscan,
                              min_samples=10).fit(MIP_coords_pred).labels_
        MIP_clusters_id = np.unique(MIP_clusters[MIP_clusters > -1])

        # If no predicted MIP then continue TODO how do we count this?
        if MIP_coords_pred[MIP_clusters > -1].shape[0] == 0:
            continue

        # MIP_edges = []
        # for cluster in MIP_clusters_id:
        #     touching_idx = find_edges(MIP_coords_pred[MIP_clusters == cluster])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[0]])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[1]])

        Michel_pred_clusters = DBSCAN(
            eps=one_pixel_dbscan,
            min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(
            Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))

        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            Michel_edges_idx = find_edges(Michel_coords_pred[current_index])
            # distances_edges = cdist(Michel_coords_pred[current_index][Michel_edges_idx], MIP_edges)
            # distances = cdist(Michel_coords_pred[current_index], MIP_coords_pred[MIP_clusters>-1])
            distances = cdist(
                Michel_coords_pred[current_index][Michel_edges_idx],
                MIP_coords_pred[MIP_clusters > -1])
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                   distances.shape)
            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords,
                                     MIP_min_coords,
                                     one_pixel=one_pixel_dbscan,
                                     radius=15.0)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            touching_x, touching_y = -1, -1
            edge1_x, edge1_y, edge2_x, edge2_y = -1, -1, -1, -1
            true_is_attached, true_is_edge, true_is_too_close = -1, -1, -1
            closest_true_id = -1

            # Find point where MIP and Michel touches
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 1]
            edge1_x = Michel_coords_pred[current_index][Michel_edges_idx[0], 0]
            edge1_y = Michel_coords_pred[current_index][Michel_edges_idx[0], 1]
            edge2_x = Michel_coords_pred[current_index][Michel_edges_idx[1], 0]
            edge2_y = Michel_coords_pred[current_index][Michel_edges_idx[1], 1]

            closest_MIP_pixels = np.argmin(distances, axis=1)
            clusters_idx = MIP_clusters[MIP_clusters > -1][np.argmin(distances,
                                                                     axis=1)]
            edge0 = is_at_edge(
                MIP_coords_pred[MIP_clusters == clusters_idx[0]],
                MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[0]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            edge1 = is_at_edge(
                MIP_coords_pred[MIP_clusters == clusters_idx[1]],
                MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[1]],
                one_pixel=one_pixel_dbscan,
                radius=10.0)
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[
                        1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y

            if is_attached and is_edge:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index],
                                  Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances,
                                                                  axis=1)]
                closest_clusters_final = closest_clusters[
                    (closest_clusters > -1)
                    & (np.min(distances, axis=1) < one_pixel_dbscan)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(
                        closest_clusters_final).argmax()
                    overlap_pixels_index = (
                        closest_clusters == closest_true_id) & (np.min(
                            distances, axis=1) < one_pixel_dbscan)
                    if closest_true_id > -1:
                        closest_true_index = label[
                            predictions ==
                            Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data[(predictions == Michel_label).reshape(
                            (-1, )), ...][current_index]:
                            count = int(
                                np.any(
                                    np.all(v[:data_dim] ==
                                           Michel_coords[Michel_true_clusters
                                                         == closest_true_id],
                                           axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(
                            Michel_true_clusters == closest_true_id)
                        # michel_true_sum_pix = clusters[clusters[:, -1] == 4][Michel_true_clusters == closest_true_id][:, -3].sum()
                        michel_true_sum_pix = data[(
                            label == Michel_label).reshape(
                                (-1, )), ...][Michel_true_clusters ==
                                              closest_true_id][:, -1].sum()

                        # Check whether true Michel is attached to MIP, otherwise exclude
                        true_is_attached = true_Michel_is_attached[
                            closest_true_id]
                        true_is_edge = true_Michel_is_edge[closest_true_id]
                        true_is_too_close = true_Michel_is_too_close[
                            closest_true_id]
                        # Register true energy
                        # Match to MC Michel
                        # FIXME in 2D Michel_start is no good
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = -1

            # Record every predicted Michel cluster in CSV
            # Record min and max x in real coordinates
            min_y = Michel_coords_pred[current_index][:, 1].min(
            )  # * meta[-1] + meta[1]
            max_y = Michel_coords_pred[current_index][:, 1].max(
            )  # * meta[-1] + meta[1]
            min_x = Michel_coords_pred[current_index][:, 0].min(
            )  # * meta[-2] + meta[0]
            max_x = Michel_coords_pred[current_index][:, 0].max(
            )  # * meta[-2] + meta[0]
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy', 'min_y', 'max_y', 'min_x', 'max_x',
                 'pixel_width', 'pixel_height', 'meta_min_x', 'meta_min_y',
                 'touching_x', 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x',
                 'edge2_y', 'edge0', 'edge1', 'true_is_attached',
                 'true_is_edge', 'true_is_too_close', 'is_too_close',
                 'closest_true_index'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data[(predictions == Michel_label).reshape(
                     (-1, )), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix, is_attached,
                 is_edge, michel_true_energy, min_y, max_y, min_x, max_x,
                 meta[-2], meta[-1], meta[0], meta[1], touching_x, touching_y,
                 edge1_x, edge1_y, edge2_x, edge2_y, edge0, edge1,
                 true_is_attached, true_is_edge, true_is_too_close,
                 is_too_close, closest_true_id))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
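The helpers find_edges and is_at_edge are used above but not defined in this snippet. Based on the in-line comments ("from the MIP pixel closest to the Michel, remove all pixels in a radius of 15px, DBSCAN what is left and make sure it is all in one single piece") and the equivalent inline logic in Code Example #7, a sketch of what is_at_edge could look like follows; treat it as a guess at the helper's behaviour under those assumptions, not its actual implementation.

import numpy as np
from sklearn.cluster import DBSCAN

def is_at_edge(cluster_coords, point_coords, one_pixel=5, radius=15.0):
    """Guess whether point_coords sits at the edge of cluster_coords.

    Removes every pixel within `radius` of the point, re-clusters what is
    left with DBSCAN(eps=one_pixel), and declares an edge if at most one
    connected piece remains (an interior point would split the cluster).
    """
    ablated = cluster_coords[
        np.linalg.norm(cluster_coords - point_coords, axis=1) > radius]
    if ablated.shape[0] == 0:
        # Nothing left after ablation: the whole cluster lies within the radius.
        return True
    labels = DBSCAN(eps=one_pixel, min_samples=5).fit(ablated).labels_
    return len(np.unique(labels[labels > -1])) <= 1
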
Code Example #6
def store_uresnet(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res: return

    method_cfg = cfg['post_processing']['store_uresnet']

    index = data_blob['index']
    segment_data = res['segmentation']
    input_data = data_blob.get(
        'input_data' if method_cfg is None else method_cfg.get(
            'input_data', 'input_data'), None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir,
                         'uresnet-segmentation-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'uresnet-segmentation-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        for i, row in enumerate(predictions):
            event = input_data[data_idx][i]
            fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                        (tree_idx, event[0], event[1], event[2], 4, row))
            fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
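All of these post-processing routines write their output through a small CSVData helper with a record / write / flush / close interface. Its implementation is not part of these snippets; a minimal stand-in that matches the calls used above might look like the following. This is an assumption for illustration, not the library's actual class.

import csv

class CSVData:
    """Minimal stand-in matching the record/write/flush/close calls above."""

    def __init__(self, path):
        self._file = open(path, 'w', newline='')
        self._writer = csv.writer(self._file)
        self._header_written = False
        self._pending = None

    def record(self, keys, values):
        # Buffer one row; the header is emitted once, before the first row.
        if not self._header_written:
            self._writer.writerow(keys)
            self._header_written = True
        self._pending = list(values)

    def write(self):
        if self._pending is not None:
            self._writer.writerow(self._pending)
            self._pending = None

    def flush(self):
        self._file.flush()

    def close(self):
        self._file.close()
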
Code Example #7
def michel_reconstruction(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output directory
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 3D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet
    - `ghost` predictions of GhostNet
    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel-reconstruction-reco-*`
    - `michel-reconstruction-true-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(
            os.path.join(
                logdir,
                'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):

        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(
                os.path.join(
                    logdir,
                    'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        label = data_blob['segment_label'][batch_id][:, -1]
        label_raw = data_blob['sparse3d_pcluster_semantics'][batch_id]
        clusters = data_blob['clusters_label'][batch_id]
        particles = data_blob['particles_label'][batch_id]
        true_ghost_mask = label < 5
        data_masked = data[true_ghost_mask]
        label_masked = label[true_ghost_mask]

        one_pixel = 5  #2.8284271247461903

        # Retrieve semantic labels corresponding to clusters
        clusters_semantics = np.zeros((clusters.shape[0])) - 1
        for cluster_id in np.unique(clusters[:, -2]):
            cluster_idx = clusters[:, -2] == cluster_id
            coords = clusters[cluster_idx][:, :3]
            d = cdist(coords, label_raw[:, :3])
            semantic_id = np.bincount(label_raw[d.argmin(
                axis=1)[d.min(axis=1) < one_pixel]][:,
                                                    -1].astype(int)).argmax()
            clusters_semantics[cluster_idx] = semantic_id

        # Find cluster id for semantics_reco
        # clusters_new = np.ones((label_masked.shape[0],))*-1
        # clusters_E = np.ones((label_masked.shape[0],))
        # for cluster_id in np.unique(clusters[:, -2]):
        #     cluster_idx = clusters[:, -2] == cluster_id
        #     coords = clusters[cluster_idx][:, :3]
        #     d = cdist(coords, data_masked[:, :3])
        #     overlap_idx = d.argmin(axis=0)[d.min(axis=0)<one_pixel]
        #     clusters_new[overlap_idx] = np.bincount(clusters[cluster_idx][d.argmin(axis=0)[d.min(axis=0)<one_pixel]][:, -2].astype(int)).argmax()
        #     clusters_E[overlap_idx] = clusters[cluster_idx][overlap_idx][:, -1]
        #     #clusters_new[overlap_idx][:, :3] = data_masked[overlap_idx][:, :3]
        # print('clusters new', np.unique(clusters_new, return_counts=True))

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)
        ghost_mask = (np.argmax(res['ghost'][batch_id], axis=1) == 0)

        data_pred = data[ghost_mask]  # coords
        label_pred = label[ghost_mask]  # labels
        predictions = (np.argmax(segmentation, axis=1))[ghost_mask]
        segmentation = segmentation[ghost_mask]

        Michel_label = 2
        MIP_label = 1

        # 0. Retrieve coordinates of true and predicted Michels
        # MIP_coords = data[(label == 1).reshape((-1,)), ...][:, :3]
        # Michel_coords = data[(label == 4).reshape((-1,)), ...][:, :3]
        # Michel_particles = particles[particles[:, 4] == Michel_label]
        MIP_coords = data[label == MIP_label][:, :3]
        # Michel_coords = data[label == Michel_label][:, :3]
        Michel_coords = clusters[clusters_semantics == Michel_label][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data_pred[(predictions == MIP_label).reshape((-1, )),
                                    ...][:, :3]
        Michel_coords_pred = data_pred[(predictions == Michel_label).reshape(
            (-1, )), ...][:, :3]

        # 1. Find true particle information matching the true Michel cluster
        # Michel_true_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(Michel_coords).labels_
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        #print(clusters.shape, label.shape)
        Michel_true_clusters = clusters[clusters_semantics ==
                                        Michel_label][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :3]
        for cluster in np.unique(Michel_true_clusters):
            # print("True", np.count_nonzero(Michel_true_clusters == cluster))
            # TODO sum_pix
            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(Michel_true_clusters == cluster),
                 clusters[clusters_semantics == Michel_label][
                     Michel_true_clusters == cluster][:, -1].sum()))
            fout_true.write()
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue
        # print("Also predicted!")
        # 2. Compute true and predicted clusters
        MIP_clusters = DBSCAN(eps=one_pixel,
                              min_samples=10).fit(MIP_coords_pred).labels_
        Michel_pred_clusters = DBSCAN(
            eps=one_pixel, min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(
            Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))
        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            distances = cdist(Michel_coords_pred[current_index],
                              MIP_coords_pred[MIP_clusters > -1])
            # is_attached = np.min(distances) < 2.8284271247461903
            is_attached = np.min(distances) < 5
            is_edge = False  # default
            # print("Min distance:", np.min(distances))
            if is_attached:
                Michel_min, MIP_min = np.unravel_index(np.argmin(distances),
                                                       distances.shape)
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                ablated_cluster = MIP_cluster_coords[np.linalg.norm(
                    MIP_cluster_coords - MIP_min_coords, axis=1) > 15.0]
                if ablated_cluster.shape[0] > 0:
                    new_cluster = DBSCAN(
                        eps=one_pixel,
                        min_samples=5).fit(ablated_cluster).labels_
                    # The Michel is at the track's edge if ablating around the
                    # contact point leaves the MIP in a single cluster
                    is_edge = len(np.unique(
                        new_cluster[new_cluster > -1])) == 1
                else:
                    is_edge = True
            # print(is_attached, is_edge)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            if is_attached and is_edge and Michel_coords.shape[0] > 0:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index],
                                  Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances,
                                                                  axis=1)]
                closest_clusters_final = closest_clusters[
                    (closest_clusters > -1)
                    & (np.min(distances, axis=1) < one_pixel)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(
                        closest_clusters_final).argmax()
                    overlap_pixels_index = (closest_clusters
                                            == closest_true_id) & (np.min(
                                                distances, axis=1) < one_pixel)
                    if closest_true_id > -1:
                        closest_true_index = label_pred[
                            predictions ==
                            Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data_pred[(
                                predictions == Michel_label).reshape((-1, )),
                                           ...][current_index]:
                            count = int(
                                np.any(
                                    np.all(v[:3] ==
                                           Michel_coords[Michel_true_clusters
                                                         == closest_true_id],
                                           axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(
                            Michel_true_clusters == closest_true_id)
                        michel_true_sum_pix = clusters[
                            clusters_semantics == Michel_label][
                                Michel_true_clusters ==
                                closest_true_id][:, -1].sum()
                        # Register true energy
                        # Match to MC Michel
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = particles[
                            closest_true_id].energy_init()
                        #print('michel true energy', particles[closest_true_id].energy_init(), particles[closest_true_id].pdg_code(), particles[closest_true_id].energy_deposit())
            # Record every predicted Michel cluster in CSV
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data_pred[(predictions == Michel_label).reshape(
                     (-1, )), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix, is_attached,
                 is_edge, michel_true_energy))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
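# The attachment/edge test above works by ablating the predicted MIP voxels
# around the contact point and re-clustering what is left: if at most one
# DBSCAN cluster survives, the Michel candidate sits at the end of the track
# rather than in its middle. A minimal standalone sketch of that test (the
# function name, radii and DBSCAN parameters below are illustrative, not
# taken verbatim from the snippet above):
import numpy as np
from sklearn.cluster import DBSCAN

def is_track_edge(track_voxels, contact_point, ablation_radius=15.0,
                  eps=1.0, min_samples=5):
    """Return True if `contact_point` lies at an end of the (N, 3) track."""
    # Remove voxels close to the contact point
    kept = track_voxels[np.linalg.norm(track_voxels - contact_point, axis=1)
                        > ablation_radius]
    if kept.shape[0] == 0:
        return True  # the whole track fits inside the ablation sphere
    # Re-cluster the remainder: a mid-track contact point would split the
    # track into two clusters, an end point leaves at most one
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(kept).labels_
    return len(np.unique(labels[labels > -1])) <= 1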
Code example #8
0
File: chain.py Project: peter-madigan/pi0_reco
class Pi0Chain():

    def __init__(self, io_cfg, chain_cfg, verbose=False):
        '''
        Initializes the chain from the configuration file
        '''
        # Initialize the data loader
        io_cfg = yaml.load(io_cfg,Loader=yaml.Loader)

        # Save config, initialize output
        self.cfg = chain_cfg
        self.verbose = verbose
        self.event = None
        self.output = {}

        # Initialize log
        log_path = chain_cfg['name']+'_log.csv'
        print('Initialized Pi0 mass chain, log path:', log_path)
        self._log = CSVData(log_path)
        self._keys = ['event_id', 'pion_id', 'pion_mass']

        # If a network is specified, initialize the network
        self.network = False
        if chain_cfg['segment'] == 'uresnet' or chain_cfg['shower_start'] == 'ppn':
            self.network = True
            with open(chain_cfg['net_cfg']) as cfg_file:
                net_cfg = yaml.load(cfg_file,Loader=yaml.Loader)
            io_cfg['model'] = net_cfg['model']
            io_cfg['trainval'] = net_cfg['trainval']
            
        # Initialize the fragment identifier
        self.frag_est = FragmentEstimator()
            
        # If a direction estimator is requested, initialize it
        if chain_cfg['shower_dir'] != 'truth':
            self.dir_est = DirectionEstimator()
            
        # If a clusterer is requested, initialize it
        if chain_cfg['shower_energy'] == 'cone':
            self.clusterer = ConeClusterer()
            
        # If a pi0 identifier is requested, initialize it
        if chain_cfg['shower_match'] == 'proximity':
            self.matcher = Pi0Matcher()

        # Pre-process configuration
        process_config(io_cfg)

        # Instantiate "handlers" (IO tools)
        self.hs = prepare(io_cfg)
        self.data_set = iter(self.hs.data_io)

    def hs(self):
        # Note: this accessor is shadowed by the instance attribute `self.hs`
        # assigned in __init__, so it is effectively unreachable.
        return self.hs

    def data_set(self):
        # Same shadowing caveat as `hs` above.
        return self.data_set

    def log(self, eid, pion_id, pion_mass):
        self._log.record(self._keys, [eid, pion_id, pion_mass])
        self._log.write()
        self._log.flush()

    def run(self):
        '''
        Runs the full Pi0 reconstruction chain, from 3D charge
        information to Pi0 masses, for events that contain one
        or more Pi0 decays.
        '''
        n_events = len(self.hs.data_io)
        for i in range(n_events):
            self.run_loop()

    def run_loop(self):
        '''
        Runs the full Pi0 reconstruction chain on a single event,
        from 3D charge information to Pi0 masses, for events that
        contain one or more Pi0 decays.
        '''
        # Reset output
        self.output = {}
        
        # Load data
        if not self.network:
            event = next(self.data_set)
            event_id = event['index'][0]
        else:
            event, self.output['forward'] = self.hs.trainer.forward(self.data_set)
            for key in event.keys():
                if key != 'particles':
                    event[key] = event[key][0]
            event_id = event['index']

        self.event = event

        # Filter out ghosts
        self.filter_ghosts(event)

        # Reconstruct energy
        self.reconstruct_energy(event)

        # Identify shower starting points, skip if there are fewer than 2 (no pi0)
        self.find_shower_starts(event)
        if len(self.output['showers']) < 2:
            if self.verbose:
                print('No shower start point found in event', event_id)
            return []
        
        # Match primary shower fragments with each start point
        if self.cfg['shower_dir'] != 'truth' or self.cfg['shower_energy'] == 'cone':
            self.match_primary_fragments(event)
            if not len(self.output['fragments']):
                if self.verbose:
                    print('Could not find a fragment for each start point in event', event_id)
                return []

        # Reconstruct shower direction vectors
        self.reconstruct_shower_directions(event)

        # Reconstruct shower energy
        self.reconstruct_shower_energy(event)

        # Identify pi0 decays
        self.identify_pi0(event)
        if not len(self.output['matches']):
            if self.verbose:
                print('No pi0 found in event', event_id)
            return []

        # Compute masses
        masses = self.pi0_mass()

        # Log masses
        for i, m in enumerate(masses):
            self.log(event_id, i, m)

    def filter_ghosts(self, event):
        '''
        Removes ghost points from the charge tensor
        '''
        if self.cfg['input'] == 'energy':
            self.output['segment'] = event['segment_label_true']
            self.output['group'] = event['group_label_true']
            self.output['dbscan'] = event['dbscan_label_true']

        elif self.cfg['segment'] == 'mask':
            mask = np.where(event['segment_label_reco'][:,-1] != 5)[0]
            self.output['charge'] = event['charge'][mask]
            self.output['segment'] = event['segment_label_reco'][mask]
            self.output['group'] = event['group_label_reco'] # group_label_reco is wrong, so no masking, TODO
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        elif self.cfg['segment'] == 'uresnet':
            # Get the segmentation output of the network
            res = self.output['forward']

            # Argmax to determine most probable label
            pred_ghost = np.argmax(res['ghost'][0], axis=1)
            pred_labels = np.argmax(res['segmentation'][0], axis=1)
            mask = np.where(pred_ghost == 0)[0]
            self.output['charge'] = event['charge'][mask]
            self.output['segment'] = copy(event['segment_label_reco'])
            self.output['segment'][:,-1] = pred_labels
            self.output['segment'] = self.output['segment'][mask]
            self.output['group'] = event['group_label_reco'] # group_label_reco is wrong, so no masking, TODO
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        else:
            raise ValueError('Semantic segmentation method not recognized:', self.cfg['segment'])

    def reconstruct_energy(self, event):
        '''
        Reconstructs energy deposition from charge
        '''
        if self.cfg['input'] == 'energy':
            self.output['energy'] = event['energy']

        elif self.cfg['response'] == 'constant':
            reco = self.cfg['response_cst']*self.output['charge'][:,-1]
            self.output['energy'] = copy(self.output['charge'])
            self.output['energy'][:,-1] = reco

        elif self.cfg['response'] == 'average':
            self.output['energy'] = copy(self.output['charge'])
            self.output['energy'][:,-1] = self.cfg['response_average']

        elif self.cfg['response'] == 'full':
            raise NotImplementedError('Proper energy reconstruction not implemented yet')

        elif self.cfg['response'] == 'enet':
            raise NotImplementedError('ENet not implemented yet')

        else:
            raise ValueError('Energy reconstruction method not recognized:', self.cfg['response'])

    def find_shower_starts(self, event):
        '''
        Identify starting points of showers
        '''
        if self.cfg['shower_start'] == 'truth':
            # Get the true shower starting points from the particle information
            self.output['showers'] = []
            for i, part in enumerate(event['particles'][0]):
                if self.is_shower(part):
                    new_shower = Shower(start=[part.first_step().x(), part.first_step().y(), part.first_step().z()], pid=i)
                    self.output['showers'].append(new_shower)

        elif self.cfg['shower_start'] == 'ppn':
            raise NotImplementedError('PPN not implemented yet')

        else:
            raise ValueError('EM shower primary identification method not recognized:', self.cfg['shower_start'])
            
    def match_primary_fragments(self, event):
        '''
        For each shower start point, find the closest DBSCAN shower cluster
        '''
        # Mask out points that are not showers
        shower_mask = np.where(self.output['segment'][:,-1] == 2)[0]
        if not len(shower_mask):
            self.output['fragments'] = []
            return
        
        # Assign clusters
        points = np.array([s.start for s in self.output['showers']])
        clusts = self.frag_est.assign_frags_to_primary(self.output['energy'][shower_mask], points)
        if len(clusts) != len(points):
            self.output['fragments'] = []
            return
        
        # Return list of voxel indices for each cluster
        self.output['fragments'] = clusts

    def reconstruct_shower_directions(self, event):
        '''
        Reconstructs the direction of the showers
        '''
        if self.cfg['shower_dir'] == 'truth':
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                mom = [part.px(), part.py(), part.pz()]
                shower.direction = list(np.array(mom)/np.linalg.norm(mom))

        elif self.cfg['shower_dir'] == 'pca' or self.cfg['shower_dir'] == 'cent':
            # Apply DBSCAN, PCA on the touching cluster to get angles
            algo = self.cfg['shower_dir']
            mask = np.where(self.output['segment'][:,-1] == 2)[0]
            points = np.array([s.start for s in self.output['showers']])
            try:
                res = self.dir_est.get_directions(self.output['energy'][mask], 
                    points, self.output['fragments'], max_distance=float('inf'), mode=algo)
            except AssertionError as err: # Cluster was not found for at least one primary
                if self.verbose:
                    print('Error in direction reconstruction:', err)
                res = [[0., 0., 0.] for _ in range(len(self.output['showers']))]
                    
            for i, shower in enumerate(self.output['showers']):
                shower.direction = res[i]

        else:
            raise ValueError('Shower direction reconstruction method not recognized:', self.cfg['shower_dir'])

    def reconstruct_shower_energy(self, event):
        '''
        Clusters the different showers, reconstruct energy of each shower
        '''
        if self.cfg['shower_energy'] == 'truth':
            # Gets the true energy information from Geant4
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                shower.energy = part.energy_init()
                pid = shower.pid
                mask = np.where(self.output['group'][:,-1] == pid)[0]
                shower.voxels = mask

        elif self.cfg['shower_energy'] == 'group':
            # Gets all the voxels in the group corresponding to the pid, adds up energy
            for shower in self.output['showers']:
                pid = shower.pid
                mask = np.where(self.output['group'][:,-1] == pid)[0]
                shower.voxels = mask
                shower.energy = np.sum(self.output['energy'][mask][:,-1])

        elif self.cfg['shower_energy'] == 'cone':
            # Fits cones to each shower, adds energies within that cone
            points = np.array([s.start for s in self.output['showers']])
            dirs = np.array([s.direction for s in self.output['showers']])
            mask = np.where(self.output['segment'][:,-1] == 2)[0]
            try:
                pred = self.clusterer.fit_predict(self.output['energy'][mask,:3], points, self.output['fragments'], dirs)
            except (ValueError, AssertionError):
                for i, shower in enumerate(self.output['showers']):
                    shower.voxels = []
                    shower.energy = 0.
                return
            padded_pred = np.full(len(self.output['segment']), -1)
            padded_pred[mask] = pred
            for i, shower in enumerate(self.output['showers']):
                shower_mask = np.where(padded_pred == i)[0]
                if not len(shower_mask):
                    shower.energy = 0.
                    continue
                shower.voxels = shower_mask
                shower.energy = np.sum(self.output['energy'][shower_mask][:,-1])

        else:
            raise ValueError('Shower energy reconstruction method not recognized:', self.cfg['shower_energy'])

    def identify_pi0(self, event):
        '''
        Proposes pi0 candidates (match two showers)
        '''
        self.output['matches'] = []
        self.output['vertices'] = []
        n_showers = len(self.output['showers'])
        if self.cfg['shower_match'] == 'truth':
            # Get the creation point of each particle. If two gammas originate
            # from the same point, it is most likely a pi0 decay.
            creations = []
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                creations.append([part.position().x(), part.position().y(), part.position().z()])

            for i, ci in enumerate(creations):
                for j in range(i+1,n_showers):
                    if (np.array(ci) == np.array(creations[j])).all():
                        self.output['matches'].append([i,j])
                        self.output['vertices'].append(ci)

            return self.output['matches']

        elif self.cfg['shower_match'] == 'proximity':
            # Pair closest shower vectors
            points = np.array([s.start for s in self.output['showers']])
            dirs = np.array([s.direction for s in self.output['showers']])
            try:
                self.output['matches'], self.output['vertices'], dists =\
                    self.matcher.find_matches(points, dirs, self.output['segment'])
            except ValueError as err:
                if self.verbose:
                    print('Error in PID:', err)
                return

            if self.cfg['refit_dir']:
                for i, m in enumerate(self.output['matches']):
                    v = np.array(self.output['vertices'][i])
                    for j in m:
                        new_dir = np.array(points[j]) - v
                        self.output['showers'][j].direction = new_dir/np.linalg.norm(new_dir)

            if self.cfg['shower_energy'] == 'cone' and self.cfg['refit_cone']:
                self.reconstruct_shower_energy(event)

        else:
            raise ValueError('Shower matching method not recognized:', self.cfg['shower_match'])

    def pi0_mass(self):
        '''
        Reconstructs the pi0 mass
        '''
        from math import sqrt
        masses = []
        for match in self.output['matches']:
            s1, s2 = self.output['showers'][match[0]], self.output['showers'][match[1]]
            e1, e2 = s1.energy, s2.energy
            t1, t2 = s1.direction, s2.direction
            costheta = np.dot(t1, t2)
            if abs(costheta) > 1.:
                masses.append(0.)
                continue
            masses.append(sqrt(2*e1*e2*(1-costheta)))
        self.output['masses'] = masses
        return masses

    def draw(self):
        from mlreco.visualization import plotly_layout3d
        from mlreco.visualization.voxels import scatter_voxels, scatter_label
        import plotly.graph_objs as go
        from plotly.offline import init_notebook_mode, iplot
        init_notebook_mode(connected=False)

        # Create labels for the voxels
        # Use a different color for each cluster
        labels = np.full(len(self.output['energy'][:,-1]), -1)
        for i, s in enumerate(self.output['showers']):
            labels[s.voxels] = i

        # Draw voxels with cluster labels
        voxels = self.output['energy'][:,:3]
        graph_voxels = scatter_label(voxels, labels, 2)[0]
        graph_voxels.name = 'Shower ID'
        graph_data = [graph_voxels]

        if len(self.output['showers']):
            # Add EM primary points
            points = np.array([s.start for s in self.output['showers']])
            graph_start = scatter_voxels(points)[0]
            graph_start.name = 'Shower starts'
            graph_data.append(graph_start)
            
            # Add EM primary directions
            dirs = np.array([s.direction for s in self.output['showers']])
            arrows = go.Cone(x=points[:,0], y=points[:,1], z=points[:,2], 
                             u=dirs[:,0], v=dirs[:,1], w=dirs[:,2],
                             sizemode='absolute', sizeref=0.25, anchor='tail',
                             showscale=False, opacity=0.4)
            graph_data.append(arrows)

            # Add a vertex if matches, join vertex to start points
            for i, m in enumerate(self.output['matches']):
                v = self.output['vertices'][i]
                s1, s2 = self.output['showers'][m[0]].start, self.output['showers'][m[1]].start
                points = [v, s1, v, s2]
                line = scatter_voxels(np.array(points))[0]
                line.name = 'Pi0 Decay'
                line.mode = 'lines+markers'
                graph_data.append(line)

        # Draw
        iplot(go.Figure(data=graph_data,layout=plotly_layout3d()))

    @staticmethod
    def is_shower(particle):
        '''
        Check if the particle is a shower
        '''
        pdg_code = abs(particle.pdg_code())
        if pdg_code == 22 or pdg_code == 11:
            return True
        return False
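# For reference, pi0_mass() above applies the two-photon invariant-mass
# formula m = sqrt(2 * E1 * E2 * (1 - cos(theta))). Below is a minimal driver
# sketch for the class, using only configuration keys the class itself reads;
# the YAML path and the chosen values are placeholders, not a prescribed
# configuration.
chain_cfg = {
    'name':           'pi0_truth',
    'segment':        'mask',      # 'uresnet' would also require 'net_cfg'
    'shower_start':   'truth',     # or 'ppn'
    'shower_dir':     'truth',     # or 'pca' / 'cent'
    'shower_energy':  'group',     # or 'truth' / 'cone'
    'shower_match':   'truth',     # or 'proximity'
    'input':          'energy',
    'response':       'constant',
    'response_cst':   1.0,
    'refit_dir':      False,
    'refit_cone':     False,
}

with open('io_config.yaml') as f:  # placeholder io configuration
    chain = Pi0Chain(f.read(), chain_cfg, verbose=True)
chain.run()  # appends (event_id, pion_id, pion_mass) rows to the CSV log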
Code example #9
0
def store_input(cfg, data_blob, res, logdir, iteration):
    """
    Store input data blob.

    Configuration
    -------------
    threshold: float, optional
        Default: 0.
    data_dim: int, optional
        Default: 3
    input_data: str, optional
    particles_label: str, optional
    segment_label: str, optional
    clusters_label: str, optional
    cluster3d_mcst_true: str, optional
    store_method: str, optional
        Can be `per-iteration` or `per-event`
    """
    method_cfg = cfg['post_processing']['store_input']

    if (method_cfg is not None
            and not method_cfg.get('input_data', 'input_data') in data_blob
        ) or (method_cfg is None and 'input_data' not in data_blob):
        return

    threshold = 0. if method_cfg is None else method_cfg.get('threshold', 0.)
    data_dim = 3 if method_cfg is None else method_cfg.get('data_dim', 3)

    index = data_blob.get('index', None)
    input_dat = data_blob.get(
        'input_data' if method_cfg is None else method_cfg.get(
            'input_data', 'input_data'), None)
    label_ppn = data_blob.get(
        'particles_label' if method_cfg is None else method_cfg.get(
            'particles_label', 'particles_label'), None)
    label_seg = data_blob.get(
        'segment_label' if method_cfg is None else method_cfg.get(
            'segment_label', 'segment_label'), None)
    label_cls = data_blob.get(
        'clusters_label' if method_cfg is None else method_cfg.get(
            'clusters_label', 'clusters_label'), None)
    label_mcst = data_blob.get(
        'cluster3d_mcst_true' if method_cfg is None else method_cfg.get(
            'cluster3d_mcst_true', 'cluster3d_mcst_true'), None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'input-iter-%07d.csv' % iteration))

    if input_dat is None: return

    for data_index, tree_index in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir, 'input-event-%07d.csv' % tree_index))

        mask = input_dat[data_index][:, -1] > threshold

        # type 0 = input data
        for row in input_dat[data_index][mask]:
            coords_labels, coords = get_coords(row, data_dim, tree_index)
            fout.record(coords_labels + ('type', 'value'),
                        coords + (0, row[data_dim + 1]))
            fout.write()

        # type 1 = Labels for PPN
        if label_ppn is not None:
            for row in label_ppn[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 1, row[4]))
                fout.write()
        # 2 = UResNet labels
        if label_seg is not None:
            for row in label_seg[data_index][mask]:
                coords_labels, coords = get_coords(row, data_dim, tree_index)
                fout.record(coords_labels + ('type', 'value'),
                            coords + (2, row[data_dim + 1]))
                fout.write()
        # type 15 = group id, 16 = semantic labels, 17 = energy
        if label_cls is not None:
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 15, row[5]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 16, row[6]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 17, row[4]))
                fout.write()
        # type 19 = cluster3d_mcst_true
        if label_mcst is not None:
            for row in label_mcst[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 19, row[4]))
                fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
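# A configuration sketch for the store_input hook above, restricted to the
# keys its docstring and body actually read; all values below are
# illustrative defaults, not required settings.
cfg_sketch = {
    'post_processing': {
        'store_input': {
            'threshold': 0.,                   # keep voxels with value > threshold
            'data_dim': 3,
            'input_data': 'input_data',        # names of entries in data_blob
            'particles_label': 'particles_label',
            'segment_label': 'segment_label',
            'clusters_label': 'clusters_label',
            'cluster3d_mcst_true': 'cluster3d_mcst_true',
            'store_method': 'per-event',       # or 'per-iteration'
        }
    }
}
# Usage: store_input(cfg_sketch, data_blob, res, logdir='logs', iteration=0)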
Code example #10
0
def uresnet_metrics(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res: return

    method_cfg = cfg['post_processing']['uresnet_metrics']

    index = data_blob['index']
    segment_data = res['segmentation']
    # input_data   = data_blob.get('input_data' if method_cfg is None else method_cfg.get('input_data', 'input_data'), None)
    segment_label = data_blob.get(
        'segment_label' if method_cfg is None else method_cfg.get(
            'segment_label', 'segment_label'), None)
    num_classes = 5 if method_cfg is None else method_cfg.get('num_classes', 5)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir, 'uresnet-metrics-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'uresnet-metrics-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        label = segment_label[data_idx][:, -1]

        acc = (predictions == label).sum() / float(len(label))
        class_acc = []
        pix = []
        for c1 in range(num_classes):
            for c2 in range(num_classes):
                class_mask = label == c1
                class_acc.append((predictions[class_mask] == c2).sum() /
                                 float(np.count_nonzero(class_mask)))
                pix.append(
                    np.count_nonzero((label == c1) & (predictions == c2)))
        fout.record(('idx', 'acc') + tuple([
            'confusion_%d_%d' % (c1, c2) for c1 in range(num_classes)
            for c2 in range(num_classes)
        ]) + tuple([
            'num_pix_%d_%d' % (c1, c2) for c1 in range(num_classes)
            for c2 in range(num_classes)
        ]), (tree_idx, acc) + tuple(class_acc) + tuple(pix))
        fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
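# The nested loop above fills a flattened confusion matrix: num_pix_c1_c2
# counts voxels with true class c1 predicted as class c2, and confusion_c1_c2
# is that count divided by the number of true-c1 voxels. An equivalent
# standalone sketch (the function name is illustrative), with a guard against
# classes that are absent from the event:
import numpy as np

def segmentation_confusion(label, predictions, num_classes=5):
    """Return (counts, fractions), rows = true class, columns = predicted."""
    label = label.astype(np.int64)
    predictions = predictions.astype(np.int64)
    counts = np.bincount(label * num_classes + predictions,
                         minlength=num_classes ** 2).reshape(num_classes,
                                                             num_classes)
    # fractions[c1, c2] = counts[c1, c2] / (# of true-c1 voxels), avoiding
    # division by zero when a class does not appear in this event
    totals = np.maximum(counts.sum(axis=1, keepdims=True), 1)
    return counts, counts / totals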
Code example #11
0
def track_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Track clustering on PPN+UResNet output.

    Parameters
    ----------
    cfg: dict
        Configuration. The optional `debug` flag of the method configuration
        controls whether a few stats are printed to stdout.
    data_blob: dict
        The input data dictionary from iotools.
    res: dict
        The output of the network, formatted using `analysis_keys`.
    logdir: str
        Directory where the output CSV files are written.
    iteration: int
        Current iteration, used to name per-iteration CSV files.

    Notes
    -----
    Based on
    - semantic segmentation output
    - point position and type predictions
    In addition, includes a break algorithm to break track clusters into
    smaller clusters based on predicted points.
    Stores all points and information in a CSV file.
    """
    method_cfg = cfg['post_processing']['track_clustering']
    dbscan_cfg = cfg['model']['modules']['dbscan']
    data_dim = int(dbscan_cfg['data_dim'])
    min_samples = int(dbscan_cfg['minPoints'])

    debug = bool(method_cfg.get('debug', None))
    score_threshold = float(method_cfg.get('score_threshold', 0.6))
    threshold_association = float(method_cfg.get('threshold_association', 3))
    exclusion_radius = float(method_cfg.get('exclusion_radius', 5))
    type_threshold = float(method_cfg.get('type_threshold', 2))

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method',
                                                 None) is not None:
        assert (method_cfg['store_method'] in ['per-iteration', 'per-event'])
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(
            os.path.join(logdir, 'track-clustering-iter-%07d.csv' % iteration))

    # Loop over batch index
    #for b in batch_ids:
    for batch_index, data in enumerate(data_blob['input_data']):

        event_index = data_blob['index'][batch_index]

        if not store_per_iteration:
            fout = CSVData(
                os.path.join(logdir,
                             'track-clustering-event-%07d.csv' % event_index))

        event_clusters = res['final'][batch_index]
        event_data = data[:, :data_dim]
        event_clusters_label = data_blob['clusters_label'][batch_index]
        event_particles_label = data_blob['particles_label'][batch_index]
        event_segmentation = res['segmentation'][batch_index]
        event_segmentation_label = data_blob['segment_label'][batch_index]
        points = res['points'][batch_index]
        event_xyz = points[:, :data_dim]
        event_scores = points[:, data_dim:data_dim + 2]
        event_mask = res['mask_ppn2'][batch_index]

        anchors = (event_data + 0.5)
        event_xyz = event_xyz + anchors

        dbscan_points = []
        predicted_points = []

        # 0) Postprocessing on predicted pixels
        # Apply selection mask from PPN2 + score thresholding
        scores = scipy.special.softmax(event_scores, axis=1)
        event_mask = ((~(event_mask == 0)).any(
            axis=1)) & (scores[:, 1] > score_threshold)
        # Now loop through semantic classes and look at ppn+uresnet predictions
        uresnet_predictions = np.argmax(event_segmentation[event_mask], axis=1)
        num_classes = event_segmentation.shape[1]
        ppn_type_predictions = np.argmax(scipy.special.softmax(
            points[event_mask][:, 5:], axis=1),
                                         axis=1)
        for c in range(num_classes):
            uresnet_points = uresnet_predictions == c
            ppn_points = ppn_type_predictions == c
            # We want to keep only points of type X within 2px of uresnet prediction of type X
            d = scipy.spatial.distance.cdist(
                event_xyz[event_mask][ppn_points],
                event_data[event_mask][uresnet_points])
            ppn_mask = (d < type_threshold).any(axis=1)
            # dbscan_points stores coordinates only
            # predicted_points stores everything for each point
            dbscan_points.append(event_xyz[event_mask][ppn_points][ppn_mask])
            pp = points[event_mask][ppn_points][ppn_mask]
            pp[:, :3] += anchors[event_mask][ppn_points][ppn_mask]
            predicted_points.append(pp)
        dbscan_points = np.concatenate(dbscan_points, axis=0)
        predicted_points = np.concatenate(predicted_points, axis=0)

        # 0.5) Remove points for delta rays
        point_types = np.argmax(predicted_points[:, -5:], axis=1)
        dbscan_points = dbscan_points[point_types != 3]

        # 1) Break algorithm
        # Using PPN point predictions, the idea is to mask an area around
        # each point associated with a given predicted cluster. Dbscan
        # then tells us whether this predicted cluster should be broken in
        # two or more smaller clusters. Pixel that were masked are then
        # assigned to the closest cluster among the newly formed clusters.
        if dbscan_points.shape[0] > 0:  # If PPN predicted some points
            cluster_ids = np.unique(event_clusters[:, -1])
            final_clusters = []
            # Loop over predicted clusters
            for c in cluster_ids:
                # Find predicted points associated to this predicted cluster
                cluster = event_clusters[event_clusters[:,
                                                        -1] == c][:, :data_dim]
                d = cdist(dbscan_points, cluster)
                index = d.min(axis=1) < threshold_association
                new_d = d[index.reshape((-1, )), :]
                # Now mask around these points
                new_index = (new_d > exclusion_radius).all(axis=0)
                # Main body of the cluster (far way from the points)
                new_cluster = cluster[new_index]
                # Cluster part around the points
                remaining_cluster = cluster[~new_index]
                # FIXME this might eliminate too small clusters?
                # put a threshold here? sometimes clusters with 1px only
                if new_cluster.shape[0] == 0:
                    continue
                # Now dbscan on the main body of the cluster to find if we need
                # to break it or not
                db2 = DBSCAN(eps=exclusion_radius,
                             min_samples=min_samples).fit(new_cluster).labels_
                # All points were garbage
                if (len(new_cluster[db2 == -1]) == len(new_cluster)):
                    continue
                # These are going to be the new bodies of predicted clusters
                new_cluster_ids = np.unique(db2)
                new_clusters = []
                for c2 in new_cluster_ids:
                    if c2 > -1:
                        new_clusters.append([new_cluster[db2 == c2]])
                # If some points were left by dbscan, put them in remaining
                # cluster and assign them to closest cluster
                if len(new_cluster[db2 == -1]) > 0:
                    if debug:
                        print(len(new_cluster[db2 == -1]), len(new_cluster))
                    remaining_cluster = np.concatenate(
                        [remaining_cluster, new_cluster[db2 == -1]], axis=0)
                    # effectively remove them from new_cluster for the argmin
                    new_cluster[db2 == -1] = 100000
                # Now assign remaining pixels in remaining_cluster based on
                # their distance to the new clusters.
                # First we find which point of new_cluster was closest
                d3 = cdist(remaining_cluster, new_cluster)
                # Then we find what is the corresponding new cluster id of this
                # closest pixel
                remaining_db = db2[d3.argmin(axis=1)]
                # Now append each pixel of remaining_cluster to correct new
                # cluster
                for i, c in enumerate(remaining_cluster):
                    new_clusters[remaining_db[i]].append(c[None, :])
                # Turn everything into np arrays
                for i in range(len(new_clusters)):
                    new_clusters[i] = np.concatenate(new_clusters[i], axis=0)
                final_clusters.extend(new_clusters)
        else:  # no predicted points: no need to break, keep predicted clusters
            final_clusters = []
            cluster_idx = np.unique(event_clusters[:, -1])
            for c in cluster_idx:
                final_clusters.append(
                    event_clusters[event_clusters[:, -1] == c][:, :data_dim])

        # 2) Compute cluster efficiency/purity
        # ie associate final clusters after breaking with true clusters
        label_cluster_ids = np.unique(event_clusters_label[:, -1])
        true_clusters = []
        for c in label_cluster_ids:
            true_clusters.append(
                event_clusters_label[event_clusters_label[:, -1] == c][:, :-2])

        # Match each predicted cluster to a true cluster
        matches = []
        overlaps = []
        for predicted_cluster in final_clusters:
            overlap = []
            for true_cluster in true_clusters:
                overlap_pixel_count = np.count_nonzero(
                    (cdist(predicted_cluster, true_cluster) < 1).any(axis=0))
                overlap.append(overlap_pixel_count)
            overlap = np.array(overlap)
            if overlap.max() > 0:
                matches.append(overlap.argmax())
                overlaps.append(overlap.max())
            else:
                matches.append(-1)
                overlaps.append(0)

        # Compute cluster purity/efficiency
        purity, efficiency = [], []
        npix_predicted, npix_true = [], []
        for i, predicted_cluster in enumerate(final_clusters):
            if matches[i] > -1:
                matched_cluster = true_clusters[matches[i]]
                purity.append(overlaps[i] / predicted_cluster.shape[0])
                efficiency.append(overlaps[i] / matched_cluster.shape[0])
                npix_predicted.append(predicted_cluster.shape[0])
                npix_true.append(matched_cluster.shape[0])

        if debug:
            print("Purity: ", purity)
            print("Efficiency: ", efficiency)
            print("Match indices: ", matches)
            print("Overlaps: ", overlaps)
            print("Npix predicted: ", npix_predicted)
            print("Npix true: ", npix_true)

        # Record in CSV everything
        # Point in data and semantic class predictions/true information
        for i, point in enumerate(data):
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                 'true_class', 'cluster_id', 'point_type', 'idx'),
                (0, point[0], point[1], point[2], batch_index, point[4],
                 np.argmax(event_segmentation[i]),
                 event_segmentation_label[i, -1], -1, -1, event_index))
            fout.write()
        # Predicted clusters
        for c, cluster in enumerate(final_clusters):
            for point in cluster:
                fout.record(
                    ('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                     'predicted_class', 'true_class', 'point_type', 'idx'),
                    (1, point[0], point[1], point[2], batch_index, c, -1, -1,
                     -1, -1, event_index))
                fout.write()
        # True clusters
        for c, cluster in enumerate(true_clusters):
            for point in cluster:
                fout.record(
                    ('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                     'predicted_class', 'true_class', 'point_type', 'idx'),
                    (2, point[0], point[1], point[2], batch_index, c, -1, -1,
                     -1, -1, event_index))
                fout.write()
        # for point in event_xyz:
        #     fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'point_type'),
        #                       (3, point[0], point[1], point[2], batch_index, -1, -1, -1, -1, -1))
        #     fout.write()
        # for point in event_clusters_label:
        #     fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value', 'predicted_class', 'true_class', 'point_type'),
        #                       (4, point[0], point[1], point[2], batch_index, point[4], -1, -1, -1, -1))
        #     fout.write()
        # True PPN points
        for point in event_particles_label:
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'point_type', 'cluster_id',
                 'value', 'predicted_class', 'true_class', 'idx'),
                (5, point[0], point[1], point[2], batch_index, point[4], -1,
                 -1, -1, -1, event_index))
            fout.write()
        # Predicted PPN points
        for point in predicted_points:
            fout.record(
                ('type', 'x', 'y', 'z', 'batch_id', 'predicted_class', 'value',
                 'true_class', 'cluster_id', 'point_type', 'idx'),
                (6, point[0], point[1], point[2], batch_index,
                 np.argmax(point[-5:]), -1, -1, -1, -1, event_index))
            fout.write()

        if not store_per_iteration: fout.close()

    if store_per_iteration: fout.close()
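# The break step above, stripped of its bookkeeping, does three things: mask
# the cluster voxels that lie near predicted points, re-cluster the remaining
# body with DBSCAN, then hand each masked voxel to the nearest surviving
# sub-cluster. A compact standalone sketch of that idea (the function name
# and default parameters are illustrative; the full routine above also
# restricts the points to those within threshold_association of the cluster,
# which this sketch omits):
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

def break_cluster(cluster, break_points, exclusion_radius=5., min_samples=10):
    """Split an (N, 3) cluster at the given (M, 3) break points."""
    if break_points.shape[0] == 0:
        return [cluster]
    d = cdist(break_points, cluster)              # (M, N) point-to-voxel distances
    far = (d > exclusion_radius).all(axis=0)      # voxels away from every point
    body, masked = cluster[far], cluster[~far]
    if body.shape[0] == 0:
        return [cluster]                          # nothing left to re-cluster
    labels = DBSCAN(eps=exclusion_radius,
                    min_samples=min_samples).fit(body).labels_
    keep = labels > -1
    if not keep.any():
        return [cluster]                          # DBSCAN called everything noise
    ids = np.unique(labels[keep])
    pieces = [body[labels == c] for c in ids]
    if masked.shape[0] > 0:
        # Assign each masked voxel to the sub-cluster of its nearest body voxel
        owner = labels[keep][cdist(masked, body[keep]).argmin(axis=1)]
        for i, c in enumerate(ids):
            extra = masked[owner == c]
            if extra.shape[0]:
                pieces[i] = np.concatenate([pieces[i], extra], axis=0)
    return pieces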