def instance_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Simple DBSCAN on uresnet clustering output for instance segmentation

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output log directory
    iteration: int
        Iteration number

    Input
    -----
    Requires the following analysis keys
    - `segmentation`: output of UResNet segmentation (scores)
    - `clustering`: coordinates in hyperspace, also output of the network
    Requires the following data blob keys
    - `input_data`
    - `segment_label` UResNet 5 classes label
    - `cluster_label`

    Output
    ------
    Writes 2 CSV files:
    - `instance-clustering-*` with the clustering predictions (point type 0 = event data,
      point type 1 = predictions, point type 2 = T-SNE visualizations)
    - `instance-clustering-metrics-*` with some event-wise metrics such as AMI and ARI.
    """
    method_cfg = cfg['post_processing']['instance_clustering']
    model_cfg = cfg['model']['modules']['uresnet_clustering']

    tsne = TSNE(n_components=2 if method_cfg is None else method_cfg.get('tsne_dim', 2))
    compute_tsne = False if method_cfg is None else method_cfg.get('compute_tsne', False)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_cluster, fout_metrics = None, None
    if store_per_iteration:
        fout_cluster = CSVData(os.path.join(logdir, 'instance-clustering-iter-%07d.csv' % iteration))
        fout_metrics = CSVData(os.path.join(logdir, 'instance-clustering-metrics-iter-%07d.csv' % iteration))

    data_dim = model_cfg.get('data_dim', 3)
    depth = model_cfg.get('num_strides', 5)
    num_classes = model_cfg.get('num_classes', 5)

    # Loop over batch index
    for batch_index, event_data in enumerate(data_blob['input_data']):
        event_index = data_blob['index'][batch_index]

        if not store_per_iteration:
            fout_cluster = CSVData(os.path.join(logdir, 'instance-clustering-event-%07d.csv' % event_index))
            fout_metrics = CSVData(os.path.join(logdir, 'instance-clustering-metrics-event-%07d.csv' % event_index))

        event_segmentation = res['segmentation'][batch_index]
        event_label = data_blob['segment_label'][batch_index]
        event_cluster_label = data_blob['cluster_label'][batch_index]
        max_depth = len(event_cluster_label)

        for d, event_feature_map in enumerate(res['cluster_feature'][batch_index]):
            coords = event_feature_map[:, :data_dim]
            perm = np.lexsort((coords[:, 2], coords[:, 1], coords[:, 0]))
            coords = coords[perm]
            class_label = event_label[-(d + 1 + max_depth - depth)]
            cluster_count = 0
            for class_ in range(num_classes):
                class_index = class_label[:, -1] == class_
                if np.count_nonzero(class_index) == 0:
                    continue
                clusters_label = event_cluster_label[-(d + 1 + max_depth - depth)][class_index]
                embedding = event_feature_map[perm][class_index]
                # DBSCAN in high dimension embedding
                predicted_clusters = DBSCAN(eps=20, min_samples=1).fit(embedding).labels_
                predicted_clusters += cluster_count  # To avoid overlapping id
                cluster_count += len(np.unique(predicted_clusters))

                # Cluster similarity metrics
                ARI = metrics.adjusted_rand_score(clusters_label[:, -1], predicted_clusters)
                AMI = metrics.adjusted_mutual_info_score(clusters_label[:, -1], predicted_clusters)
                fout_metrics.record(('class', 'batch_id', 'AMI', 'ARI', 'idx'),
                                    (class_, batch_index, AMI, ARI, event_index))
                fout_metrics.write()

                for i, point in enumerate(clusters_label):
                    fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                                         'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                        (1, point[0], point[1], point[2], batch_index, d, -1,
                                         class_label[class_index][i, -1], clusters_label[i, -1],
                                         predicted_clusters[i], event_index))
                    fout_cluster.write()

                # TSNE to visualize embedding
                if compute_tsne and embedding.shape[0] > 1:
                    new_embedding = tsne.fit_transform(embedding)
                    for i, point in enumerate(new_embedding):
                        fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                                             'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                            (2, point[0], point[1], -1, batch_index, d, -1,
                                             class_label[class_index][i, -1], clusters_label[i, -1],
                                             predicted_clusters[i], event_index))
                        fout_cluster.write()

        # Record in CSV everything
        perm = np.lexsort((event_data[:, 2], event_data[:, 1], event_data[:, 0]))
        event_data = event_data[perm]
        event_segmentation = event_segmentation[perm]

        # Point in data and semantic class predictions/true information
        for i, point in enumerate(event_data):
            fout_cluster.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                                 'true_class', 'true_cluster_id', 'predicted_cluster_id', 'idx'),
                                (0, point[0], point[1], point[2], batch_index, point[4],
                                 np.argmax(event_segmentation[i]), event_label[0][:, -1][i],
                                 -1, -1, event_index))
            fout_cluster.write()

        if not store_per_iteration:
            fout_cluster.close()
            fout_metrics.close()

    if store_per_iteration:
        fout_cluster.close()
        fout_metrics.close()
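# A minimal, self-contained illustration of the two cluster-similarity scores
# recorded by instance_clustering() above (adjusted Rand index and adjusted
# mutual information from scikit-learn). The label lists below are made up
# purely for illustration; both scores are invariant under a relabeling of the
# clusters, so a perfect partition match scores 1.0 even with swapped ids.
from sklearn import metrics as _metrics

_true_ids = [0, 0, 1, 1, 2, 2]
_pred_ids = [1, 1, 0, 0, 2, 2]   # same partition, different cluster names
print(_metrics.adjusted_rand_score(_true_ids, _pred_ids))          # 1.0
print(_metrics.adjusted_mutual_info_score(_true_ids, _pred_ids))   # 1.0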
def michel_reconstruction_2d(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output log directory
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 2D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet

    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel-reconstruction-reco-*`
    - `michel-reconstruction-true-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction_2d']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(os.path.join(logdir, 'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(os.path.join(logdir, 'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):
        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(os.path.join(logdir, 'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(os.path.join(logdir, 'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        data = data_blob['input_data'][batch_id]
        label = data_blob['segment_label'][batch_id][:, -1]
        meta = data_blob['meta'][batch_id]
        # clusters = data_blob['clusters_label'][batch_id]
        # particles = data_blob['particles_label'][batch_id]
        # Michel_particles = particles[particles[:, 2] == 4]  # FIXME 3 or 4 in 2D? Also not sure if type is registered for Michel

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)

        Michel_label = 3
        MIP_label = 0
        data_dim = 2

        # 0. Retrieve coordinates of true and predicted Michels
        MIP_coords = data[(label == MIP_label).reshape((-1,)), ...][:, :data_dim]
        Michel_coords = data[(label == Michel_label).reshape((-1,)), ...][:, :data_dim]
        # MIP_coords = clusters[clusters[:, -1] == 1][:, :3]
        # Michel_coords = clusters[clusters[:, -1] == 4][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data[(predictions == MIP_label).reshape((-1,)), ...][:, :data_dim]
        Michel_coords_pred = data[(predictions == Michel_label).reshape((-1,)), ...][:, :data_dim]

        # DBSCAN epsilon used for many things... TODO list here
        # one_pixel = 15  # 2.8284271247461903
        one_pixel_dbscan = 5
        one_pixel_is_attached = 2

        # 1. Find true particle information matching the true Michel cluster
        Michel_true_clusters = DBSCAN(eps=one_pixel_dbscan, min_samples=5).fit(Michel_coords).labels_
        MIP_true_clusters = DBSCAN(eps=one_pixel_dbscan, min_samples=5).fit(MIP_coords).labels_

        # compute all edges of true MIP clusters
        MIP_edges = []
        for cluster in np.unique(MIP_true_clusters[MIP_true_clusters > -1]):
            touching_idx = find_edges(MIP_coords[MIP_true_clusters == cluster])
            MIP_edges.append(MIP_coords[MIP_true_clusters == cluster][touching_idx[0]])
            MIP_edges.append(MIP_coords[MIP_true_clusters == cluster][touching_idx[1]])
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        # Michel_true_clusters = clusters[clusters[:, -1] == 4][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :data_dim]

        true_Michel_is_attached = {}
        true_Michel_is_edge = {}
        true_Michel_is_too_close = {}
        for cluster in np.unique(Michel_true_clusters):
            min_y = Michel_coords[Michel_true_clusters == cluster][:, 1].min()  # * meta[-1] + meta[1]
            max_y = Michel_coords[Michel_true_clusters == cluster][:, 1].max()  # * meta[-1] + meta[1]
            min_x = Michel_coords[Michel_true_clusters == cluster][:, 0].min()  # * meta[-2] + meta[0]
            max_x = Michel_coords[Michel_true_clusters == cluster][:, 0].max()  # * meta[-2] + meta[0]

            # Find coordinates of Michel pixel touching MIP edge
            Michel_edges_idx = find_edges(Michel_coords[Michel_true_clusters == cluster])
            distances = cdist(Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx],
                              MIP_coords[MIP_true_clusters > -1])

            # Make sure true Michel is attached at edge of MIP
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances), distances.shape)
            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_true_clusters[MIP_true_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords[MIP_true_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords[MIP_true_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords, MIP_min_coords,
                                     one_pixel=one_pixel_dbscan, radius=15.0)
            true_Michel_is_attached[cluster] = is_attached
            true_Michel_is_edge[cluster] = is_edge
            true_Michel_is_too_close[cluster] = is_too_close

            # these are the coordinates of Michel edges
            edge1_x = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx[0], 0]
            edge1_y = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx[0], 1]
            edge2_x = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx[1], 0]
            edge2_y = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx[1], 1]

            # Find for each Michel edge the closest MIP pixels
            # Check for each of these whether they are at the edge of MIP
            # FIXME what happens if both are at the edge of a MIP?? unlikely
            closest_MIP_pixels = np.argmin(distances, axis=1)
            clusters_idx = MIP_true_clusters[closest_MIP_pixels]
            edge0 = is_at_edge(MIP_coords[MIP_true_clusters == clusters_idx[0]],
                               MIP_coords[closest_MIP_pixels[0]],
                               one_pixel=one_pixel_dbscan, radius=10.0)
            edge1 = is_at_edge(MIP_coords[MIP_true_clusters == clusters_idx[1]],
                               MIP_coords[closest_MIP_pixels[1]],
                               one_pixel=one_pixel_dbscan, radius=10.0)
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords[Michel_true_clusters == cluster][Michel_edges_idx][touching_idx[0], 1]
            # if touching_x not in [edge1_x, edge2_x] or touching_y not in [edge1_y, edge2_y]:
            #     print('true', event_idx, touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)
            # if event_idx == 127:
            #     print('true', touching_x, touching_y, edge1_x, edge1_y, edge2_x, edge2_y)

            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix',
                 'min_y', 'max_y', 'min_x', 'max_x', 'pixel_width',
                 'pixel_height', 'meta_min_x', 'meta_min_y', 'touching_x',
                 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x', 'edge2_y',
                 'edge0', 'edge1', 'is_attached', 'is_edge', 'is_too_close',
                 'cluster_id'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(Michel_true_clusters == cluster),
                 data[(label == Michel_label).reshape((-1,)), ...][Michel_true_clusters == cluster][:, -1].sum(),
                 # clusters[clusters[:, -1] == 4][Michel_true_clusters == cluster][:, -3].sum()
                 min_y, max_y, min_x, max_x,
                 meta[-2], meta[-1], meta[0], meta[1],
                 touching_x, touching_y,
                 edge1_x, edge1_y, edge2_x, edge2_y,
                 edge0, edge1,
                 is_attached, is_edge, is_too_close, cluster))
            fout_true.write()
            # e.g. deposited energy, creation energy
            # TODO retrieve particles information
            # if Michel_coords.shape[0] > 0:
            #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
            #     for Michel_id in Michel_clusters_id:
            #         current_index = Michel_true_clusters == Michel_id
            #         distances = cdist(Michel_coords[current_index], MIP_coords)
            #         is_attached = np.min(distances) < 2.8284271247461903
            #         # Match to MC Michel
            #         distances2 = cdist(Michel_coords[current_index], Michel_start)
            #         closest_mc = np.argmin(distances2, axis=1)
            #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue

        # 2. Compute true and predicted clusters
        MIP_clusters = DBSCAN(eps=one_pixel_dbscan, min_samples=10).fit(MIP_coords_pred).labels_
        MIP_clusters_id = np.unique(MIP_clusters[MIP_clusters > -1])
        # If no predicted MIP then continue TODO how do we count this?
        if MIP_coords_pred[MIP_clusters > -1].shape[0] == 0:
            continue

        # MIP_edges = []
        # for cluster in MIP_clusters_id:
        #     touching_idx = find_edges(MIP_coords_pred[MIP_clusters == cluster])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[0]])
        #     MIP_edges.append(MIP_coords_pred[MIP_clusters == cluster][touching_idx[1]])

        Michel_pred_clusters = DBSCAN(eps=one_pixel_dbscan, min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))

        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            Michel_edges_idx = find_edges(Michel_coords_pred[current_index])
            # distances_edges = cdist(Michel_coords_pred[current_index][Michel_edges_idx], MIP_edges)
            # distances = cdist(Michel_coords_pred[current_index], MIP_coords_pred[MIP_clusters>-1])
            distances = cdist(Michel_coords_pred[current_index][Michel_edges_idx],
                              MIP_coords_pred[MIP_clusters > -1])
            Michel_min, MIP_min = np.unravel_index(np.argmin(distances), distances.shape)

            is_attached = np.min(distances) < one_pixel_is_attached
            is_too_close = np.max(distances) < one_pixel_is_attached
            # Check whether the Michel is at the edge of a predicted MIP
            # From the MIP pixel closest to the Michel, remove all pixels in
            # a radius of 15px. DBSCAN what is left and make sure it is all in
            # one single piece.
            is_edge = False  # default
            if is_attached:
                # cluster id of MIP closest
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                # coordinates of closest MIP pixel in this cluster
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                # coordinates of the whole cluster
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                is_edge = is_at_edge(MIP_cluster_coords, MIP_min_coords,
                                     one_pixel=one_pixel_dbscan, radius=15.0)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            touching_x, touching_y = -1, -1
            edge1_x, edge1_y, edge2_x, edge2_y = -1, -1, -1, -1
            true_is_attached, true_is_edge, true_is_too_close = -1, -1, -1
            closest_true_id = -1

            # Find point where MIP and Michel touches
            # touching_idx = np.unravel_index(np.argmin(distances), distances.shape)
            # touching_x = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 0]
            # touching_y = Michel_coords_pred[current_index][Michel_edges_idx][touching_idx[0], 1]
            edge1_x = Michel_coords_pred[current_index][Michel_edges_idx[0], 0]
            edge1_y = Michel_coords_pred[current_index][Michel_edges_idx[0], 1]
            edge2_x = Michel_coords_pred[current_index][Michel_edges_idx[1], 0]
            edge2_y = Michel_coords_pred[current_index][Michel_edges_idx[1], 1]

            closest_MIP_pixels = np.argmin(distances, axis=1)
            clusters_idx = MIP_clusters[MIP_clusters > -1][np.argmin(distances, axis=1)]
            edge0 = is_at_edge(MIP_coords_pred[MIP_clusters == clusters_idx[0]],
                               MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[0]],
                               one_pixel=one_pixel_dbscan, radius=10.0)
            edge1 = is_at_edge(MIP_coords_pred[MIP_clusters == clusters_idx[1]],
                               MIP_coords_pred[MIP_clusters > -1][closest_MIP_pixels[1]],
                               one_pixel=one_pixel_dbscan, radius=10.0)
            if edge0 and not edge1:
                touching_x = edge1_x
                touching_y = edge1_y
            elif not edge0 and edge1:
                touching_x = edge2_x
                touching_y = edge2_y
            else:
                if distances[0, closest_MIP_pixels[0]] < distances[1, closest_MIP_pixels[1]]:
                    touching_x = edge1_x
                    touching_y = edge1_y
                else:
                    touching_x = edge2_x
                    touching_y = edge2_y

            if is_attached and is_edge:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index], Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances, axis=1)]
                closest_clusters_final = closest_clusters[(closest_clusters > -1) &
                                                          (np.min(distances, axis=1) < one_pixel_dbscan)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(closest_clusters_final).argmax()
                    overlap_pixels_index = (closest_clusters == closest_true_id) & \
                                           (np.min(distances, axis=1) < one_pixel_dbscan)
                    if closest_true_id > -1:
                        closest_true_index = label[predictions == Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data[(predictions == Michel_label).reshape((-1,)), ...][current_index]:
                            count = int(np.any(np.all(v[:data_dim] == Michel_coords[Michel_true_clusters == closest_true_id],
                                                      axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(Michel_true_clusters == closest_true_id)
                        # michel_true_sum_pix = clusters[clusters[:, -1] == 4][Michel_true_clusters == closest_true_id][:, -3].sum()
                        michel_true_sum_pix = data[(label == Michel_label).reshape((-1,)), ...][Michel_true_clusters == closest_true_id][:, -1].sum()
                        # Check whether true Michel is attached to MIP, otherwise exclude
                        true_is_attached = true_Michel_is_attached[closest_true_id]
                        true_is_edge = true_Michel_is_edge[closest_true_id]
                        true_is_too_close = true_Michel_is_too_close[closest_true_id]
                        # Register true energy
                        # Match to MC Michel
                        # FIXME in 2D Michel_start is no good
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = -1

            # Record every predicted Michel cluster in CSV
            # Record min and max x in real coordinates
            min_y = Michel_coords_pred[current_index][:, 1].min()  # * meta[-1] + meta[1]
            max_y = Michel_coords_pred[current_index][:, 1].max()  # * meta[-1] + meta[1]
            min_x = Michel_coords_pred[current_index][:, 0].min()  # * meta[-2] + meta[0]
            max_x = Michel_coords_pred[current_index][:, 0].max()  # * meta[-2] + meta[0]
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy', 'min_y', 'max_y', 'min_x', 'max_x',
                 'pixel_width', 'pixel_height', 'meta_min_x', 'meta_min_y',
                 'touching_x', 'touching_y', 'edge1_x', 'edge1_y', 'edge2_x',
                 'edge2_y', 'edge0', 'edge1', 'true_is_attached',
                 'true_is_edge', 'true_is_too_close', 'is_too_close',
                 'closest_true_index'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data[(predictions == Michel_label).reshape((-1,)), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix,
                 is_attached, is_edge, michel_true_energy,
                 min_y, max_y, min_x, max_x,
                 meta[-2], meta[-1], meta[0], meta[1],
                 touching_x, touching_y,
                 edge1_x, edge1_y, edge2_x, edge2_y,
                 edge0, edge1,
                 true_is_attached, true_is_edge, true_is_too_close,
                 is_too_close, closest_true_id))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
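# A small, self-contained sketch of the DBSCAN call used above to group true
# and predicted Michel/MIP pixels (a label of -1 marks noise). The coordinates
# and parameters below are made up for illustration and are not taken from any
# event.
import numpy as np
from sklearn.cluster import DBSCAN

_coords = np.array([[0., 0.], [1., 0.], [0., 1.],    # one compact cluster
                    [50., 50.], [51., 50.],          # a second cluster
                    [200., 200.]])                   # isolated pixel -> noise
_labels = DBSCAN(eps=5, min_samples=2).fit(_coords).labels_
print(_labels)   # e.g. [ 0  0  0  1  1 -1]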
def store_uresnet(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res:
        return

    method_cfg = cfg['post_processing']['store_uresnet']

    index = data_blob['index']
    segment_data = res['segmentation']
    input_data = data_blob.get(
        'input_data' if method_cfg is None else method_cfg.get('input_data', 'input_data'),
        None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'uresnet-segmentation-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):
        if not store_per_iteration:
            fout = CSVData(os.path.join(logdir, 'uresnet-segmentation-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        for i, row in enumerate(predictions):
            event = input_data[data_idx][i]
            fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                        (tree_idx, event[0], event[1], event[2], 4, row))
            fout.write()

        if not store_per_iteration:
            fout.close()

    if store_per_iteration:
        fout.close()
def michel_reconstruction(cfg, data_blob, res, logdir, iteration):
    """
    Very simple algorithm to reconstruct Michel clusters from UResNet semantic
    segmentation output.

    Parameters
    ----------
    data_blob: dict
        Input dictionary returned by iotools
    res: dict
        Results from the network, dictionary using `analysis_keys`
    cfg: dict
        Configuration
    logdir: string
        Path to the output log directory
    iteration: int
        Iteration number

    Notes
    -----
    Assumes 3D

    Input
    -----
    Requires the following analysis keys:
    - `segmentation` output of UResNet
    - `ghost` predictions of GhostNet

    Requires the following input keys:
    - `input_data`
    - `segment_label`
    - `particles_label` to get detailed information such as energy.
    - `clusters_label` from `cluster3d_mcst` for true cluster information

    Output
    ------
    Writes 2 CSV files:
    - `michel-reconstruction-reco-*`
    - `michel-reconstruction-true-*`
    """
    method_cfg = cfg['post_processing']['michel_reconstruction']

    # Create output CSV
    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'

    fout_reco, fout_true = None, None
    if store_per_iteration:
        fout_reco = CSVData(os.path.join(logdir, 'michel-reconstruction-reco-iter-%07d.csv' % iteration))
        fout_true = CSVData(os.path.join(logdir, 'michel-reconstruction-true-iter-%07d.csv' % iteration))

    # Loop over events
    for batch_id, data in enumerate(data_blob['input_data']):
        event_idx = data_blob['index'][batch_id]

        if not store_per_iteration:
            fout_reco = CSVData(os.path.join(logdir, 'michel-reconstruction-reco-event-%07d.csv' % event_idx))
            fout_true = CSVData(os.path.join(logdir, 'michel-reconstruction-true-event-%07d.csv' % event_idx))

        # from input/labels
        label = data_blob['segment_label'][batch_id][:, -1]
        label_raw = data_blob['sparse3d_pcluster_semantics'][batch_id]
        clusters = data_blob['clusters_label'][batch_id]
        particles = data_blob['particles_label'][batch_id]
        true_ghost_mask = label < 5
        data_masked = data[true_ghost_mask]
        label_masked = label[true_ghost_mask]

        one_pixel = 5  # 2.8284271247461903

        # Retrieve semantic labels corresponding to clusters
        clusters_semantics = np.zeros((clusters.shape[0])) - 1
        for cluster_id in np.unique(clusters[:, -2]):
            cluster_idx = clusters[:, -2] == cluster_id
            coords = clusters[cluster_idx][:, :3]
            d = cdist(coords, label_raw[:, :3])
            semantic_id = np.bincount(label_raw[d.argmin(axis=1)[d.min(axis=1) < one_pixel]][:, -1].astype(int)).argmax()
            clusters_semantics[cluster_idx] = semantic_id

        # Find cluster id for semantics_reco
        # clusters_new = np.ones((label_masked.shape[0],))*-1
        # clusters_E = np.ones((label_masked.shape[0],))
        # for cluster_id in np.unique(clusters[:, -2]):
        #     cluster_idx = clusters[:, -2] == cluster_id
        #     coords = clusters[cluster_idx][:, :3]
        #     d = cdist(coords, data_masked[:, :3])
        #     overlap_idx = d.argmin(axis=0)[d.min(axis=0)<one_pixel]
        #     clusters_new[overlap_idx] = np.bincount(clusters[cluster_idx][d.argmin(axis=0)[d.min(axis=0)<one_pixel]][:, -2].astype(int)).argmax()
        #     clusters_E[overlap_idx] = clusters[cluster_idx][overlap_idx][:, -1]
        #     #clusters_new[overlap_idx][:, :3] = data_masked[overlap_idx][:, :3]
        # print('clusters new', np.unique(clusters_new, return_counts=True))

        # from network output
        segmentation = res['segmentation'][batch_id]
        predictions = np.argmax(segmentation, axis=1)
        ghost_mask = (np.argmax(res['ghost'][batch_id], axis=1) == 0)
        data_pred = data[ghost_mask]    # coords
        label_pred = label[ghost_mask]  # labels
        predictions = (np.argmax(segmentation, axis=1))[ghost_mask]
        segmentation = segmentation[ghost_mask]

        Michel_label = 2
        MIP_label = 1

        # 0. Retrieve coordinates of true and predicted Michels
        # MIP_coords = data[(label == 1).reshape((-1,)), ...][:, :3]
        # Michel_coords = data[(label == 4).reshape((-1,)), ...][:, :3]
        # Michel_particles = particles[particles[:, 4] == Michel_label]
        MIP_coords = data[label == MIP_label][:, :3]
        # Michel_coords = data[label == Michel_label][:, :3]
        Michel_coords = clusters[clusters_semantics == Michel_label][:, :3]
        if Michel_coords.shape[0] == 0:  # FIXME
            continue
        MIP_coords_pred = data_pred[(predictions == MIP_label).reshape((-1,)), ...][:, :3]
        Michel_coords_pred = data_pred[(predictions == Michel_label).reshape((-1,)), ...][:, :3]

        # 1. Find true particle information matching the true Michel cluster
        # Michel_true_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(Michel_coords).labels_
        # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])]
        # print(clusters.shape, label.shape)
        Michel_true_clusters = clusters[clusters_semantics == Michel_label][:, -2].astype(np.int64)
        # Michel_start = Michel_particles[:, :3]
        for cluster in np.unique(Michel_true_clusters):
            # print("True", np.count_nonzero(Michel_true_clusters == cluster))
            # TODO sum_pix
            fout_true.record(
                ('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(Michel_true_clusters == cluster),
                 clusters[clusters_semantics == Michel_label][Michel_true_clusters == cluster][:, -1].sum()))
            fout_true.write()
        # e.g. deposited energy, creation energy
        # TODO retrieve particles information
        # if Michel_coords.shape[0] > 0:
        #     Michel_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters>-1])
        #     for Michel_id in Michel_clusters_id:
        #         current_index = Michel_true_clusters == Michel_id
        #         distances = cdist(Michel_coords[current_index], MIP_coords)
        #         is_attached = np.min(distances) < 2.8284271247461903
        #         # Match to MC Michel
        #         distances2 = cdist(Michel_coords[current_index], Michel_start)
        #         closest_mc = np.argmin(distances2, axis=1)
        #         closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]

        # TODO how do we count events where there are no predictions but true?
        if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0:
            continue
        # print("Also predicted!")

        # 2. Compute true and predicted clusters
        MIP_clusters = DBSCAN(eps=one_pixel, min_samples=10).fit(MIP_coords_pred).labels_
        Michel_pred_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(Michel_coords_pred).labels_
        Michel_pred_clusters_id = np.unique(Michel_pred_clusters[Michel_pred_clusters > -1])
        # print(len(Michel_pred_clusters_id))

        # Loop over predicted Michel clusters
        for Michel_id in Michel_pred_clusters_id:
            current_index = Michel_pred_clusters == Michel_id
            # 3. Check whether predicted Michel is attached to a predicted MIP
            # and at the edge of the predicted MIP
            distances = cdist(Michel_coords_pred[current_index], MIP_coords_pred[MIP_clusters > -1])
            # is_attached = np.min(distances) < 2.8284271247461903
            is_attached = np.min(distances) < 5
            is_edge = False  # default
            # print("Min distance:", np.min(distances))
            if is_attached:
                Michel_min, MIP_min = np.unravel_index(np.argmin(distances), distances.shape)
                MIP_id = MIP_clusters[MIP_clusters > -1][MIP_min]
                MIP_min_coords = MIP_coords_pred[MIP_clusters > -1][MIP_min]
                MIP_cluster_coords = MIP_coords_pred[MIP_clusters == MIP_id]
                ablated_cluster = MIP_cluster_coords[np.linalg.norm(MIP_cluster_coords - MIP_min_coords, axis=1) > 15.0]
                if ablated_cluster.shape[0] > 0:
                    new_cluster = DBSCAN(eps=one_pixel, min_samples=5).fit(ablated_cluster).labels_
                    is_edge = len(np.unique(new_cluster[new_cluster > -1])) == MIP_label
                else:
                    is_edge = True
            # print(is_attached, is_edge)

            michel_pred_num_pix_true, michel_pred_sum_pix_true = -1, -1
            michel_true_num_pix, michel_true_sum_pix = -1, -1
            michel_true_energy = -1
            if is_attached and is_edge and Michel_coords.shape[0] > 0:
                # Distance from current Michel pred cluster to all true points
                distances = cdist(Michel_coords_pred[current_index], Michel_coords)
                closest_clusters = Michel_true_clusters[np.argmin(distances, axis=1)]
                closest_clusters_final = closest_clusters[(closest_clusters > -1) &
                                                          (np.min(distances, axis=1) < one_pixel)]
                if len(closest_clusters_final) > 0:
                    # print(closest_clusters_final, np.bincount(closest_clusters_final), np.bincount(closest_clusters_final).argmax())
                    # cluster id of closest true Michel cluster
                    # we take the one that has most overlap
                    # closest_true_id = closest_clusters_final[np.bincount(closest_clusters_final).argmax()]
                    closest_true_id = np.bincount(closest_clusters_final).argmax()
                    overlap_pixels_index = (closest_clusters == closest_true_id) & \
                                           (np.min(distances, axis=1) < one_pixel)
                    if closest_true_id > -1:
                        closest_true_index = label_pred[predictions == Michel_label][current_index] == Michel_label
                        # Intersection
                        michel_pred_num_pix_true = 0
                        michel_pred_sum_pix_true = 0.
                        for v in data_pred[(predictions == Michel_label).reshape((-1,)), ...][current_index]:
                            count = int(np.any(np.all(v[:3] == Michel_coords[Michel_true_clusters == closest_true_id],
                                                      axis=1)))
                            michel_pred_num_pix_true += count
                            if count > 0:
                                michel_pred_sum_pix_true += v[-1]

                        michel_true_num_pix = np.count_nonzero(Michel_true_clusters == closest_true_id)
                        michel_true_sum_pix = clusters[clusters_semantics == Michel_label][Michel_true_clusters == closest_true_id][:, -1].sum()
                        # Register true energy
                        # Match to MC Michel
                        # distances2 = cdist(Michel_coords[Michel_true_clusters == closest_true_id], Michel_start)
                        # closest_mc = np.argmin(distances2, axis=1)
                        # closest_mc_id = closest_mc[np.bincount(closest_mc).argmax()]
                        # michel_true_energy = Michel_particles[closest_mc_id, 7]
                        michel_true_energy = particles[closest_true_id].energy_init()
                        # print('michel true energy', particles[closest_true_id].energy_init(),
                        #       particles[closest_true_id].pdg_code(), particles[closest_true_id].energy_deposit())

            # Record every predicted Michel cluster in CSV
            fout_reco.record(
                ('batch_id', 'iteration', 'event_idx', 'pred_num_pix',
                 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true',
                 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge',
                 'michel_true_energy'),
                (batch_id, iteration, event_idx,
                 np.count_nonzero(current_index),
                 data_pred[(predictions == Michel_label).reshape((-1,)), ...][current_index][:, -1].sum(),
                 michel_pred_num_pix_true, michel_pred_sum_pix_true,
                 michel_true_num_pix, michel_true_sum_pix,
                 is_attached, is_edge, michel_true_energy))
            fout_reco.write()

        if not store_per_iteration:
            fout_reco.close()
            fout_true.close()

    if store_per_iteration:
        fout_reco.close()
        fout_true.close()
class Pi0Chain():

    def __init__(self, io_cfg, chain_cfg, verbose=False):
        '''
        Initializes the chain from the configuration file
        '''
        # Initialize the data loader
        io_cfg = yaml.load(io_cfg, Loader=yaml.Loader)

        # Save config, initialize output
        self.cfg = chain_cfg
        self.verbose = verbose
        self.event = None
        self.output = {}

        # Initialize log
        log_path = chain_cfg['name'] + '_log.csv'
        print('Initialized Pi0 mass chain, log path:', log_path)
        self._log = CSVData(log_path)
        self._keys = ['event_id', 'pion_id', 'pion_mass']

        # If a network is specified, initialize the network
        self.network = False
        if chain_cfg['segment'] == 'uresnet' or chain_cfg['shower_start'] == 'ppn':
            self.network = True
            with open(chain_cfg['net_cfg']) as cfg_file:
                net_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
            io_cfg['model'] = net_cfg['model']
            io_cfg['trainval'] = net_cfg['trainval']

        # Initialize the fragment identifier
        self.frag_est = FragmentEstimator()

        # If a direction estimator is requested, initialize it
        if chain_cfg['shower_dir'] != 'truth':
            self.dir_est = DirectionEstimator()

        # If a clusterer is requested, initialize it
        if chain_cfg['shower_energy'] == 'cone':
            self.clusterer = ConeClusterer()

        # If a pi0 identifier is requested, initialize it
        if chain_cfg['shower_match'] == 'proximity':
            self.matcher = Pi0Matcher()

        # Pre-process configuration
        process_config(io_cfg)

        # Instantiate "handlers" (IO tools)
        self.hs = prepare(io_cfg)
        self.data_set = iter(self.hs.data_io)

    def hs(self):
        return self.hs

    def data_set(self):
        return self.data_set

    def log(self, eid, pion_id, pion_mass):
        self._log.record(self._keys, [eid, pion_id, pion_mass])
        self._log.write()
        self._log.flush()

    def run(self):
        '''
        Runs the full Pi0 reconstruction chain, from 3D charge
        information to Pi0 masses for events that contain one
        or more Pi0 decay.
        '''
        n_events = len(self.hs.data_io)
        for i in range(n_events):
            self.run_loop()

    def run_loop(self):
        '''
        Runs the full Pi0 reconstruction chain on a single event,
        from 3D charge information to Pi0 masses for events that
        contain one or more Pi0 decay.
        '''
        # Reset output
        self.output = {}

        # Load data
        if not self.network:
            event = next(self.data_set)
            event_id = event['index'][0]
        else:
            event, self.output['forward'] = self.hs.trainer.forward(self.data_set)
            for key in event.keys():
                if key != 'particles':
                    event[key] = event[key][0]
            event_id = event['index']
        self.event = event

        # Filter out ghosts
        self.filter_ghosts(event)

        # Reconstruct energy
        self.reconstruct_energy(event)

        # Identify shower starting points, skip if there are fewer than 2 (no pi0)
        self.find_shower_starts(event)
        if len(self.output['showers']) < 2:
            if self.verbose:
                print('No shower start point found in event', event_id)
            return []

        # Match primary shower fragments with each start point
        if self.cfg['shower_dir'] != 'truth' or self.cfg['shower_energy'] == 'cone':
            self.match_primary_fragments(event)
            if not len(self.output['fragments']):
                if self.verbose:
                    print('Could not find a fragment for each start point in event', event_id)
                return []

        # Reconstruct shower direction vectors
        self.reconstruct_shower_directions(event)

        # Reconstruct shower energy
        self.reconstruct_shower_energy(event)

        # Identify pi0 decays
        self.identify_pi0(event)
        if not len(self.output['matches']):
            if self.verbose:
                print('No pi0 found in event', event_id)
            return []

        # Compute masses
        masses = self.pi0_mass()

        # Log masses
        for i, m in enumerate(masses):
            self.log(event_id, i, m)

    def filter_ghosts(self, event):
        '''
        Removes ghost points from the charge tensor
        '''
        if self.cfg['input'] == 'energy':
            self.output['segment'] = event['segment_label_true']
            self.output['group'] = event['group_label_true']
            self.output['dbscan'] = event['dbscan_label_true']

        elif self.cfg['segment'] == 'mask':
            mask = np.where(event['segment_label_reco'][:, -1] != 5)[0]
            self.output['charge'] = event['charge'][mask]
            self.output['segment'] = event['segment_label_reco'][mask]
            self.output['group'] = event['group_label_reco']  # group_label_reco is wrong, so no masking, TODO
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        elif self.cfg['segment'] == 'uresnet':
            # Get the segmentation output of the network
            res = self.output['forward']
            # Argmax to determine most probable label
            pred_ghost = np.argmax(res['ghost'][0], axis=1)
            pred_labels = np.argmax(res['segmentation'][0], axis=1)
            mask = np.where(pred_ghost == 0)[0]
            self.output['charge'] = event['charge'][mask]
            self.output['segment'] = copy(event['segment_label_reco'])
            self.output['segment'][:, -1] = pred_labels
            self.output['segment'] = self.output['segment'][mask]
            self.output['group'] = event['group_label_reco']  # group_label_reco is wrong, so no masking, TODO
            self.output['dbscan'] = event['dbscan_label_reco'][mask]

        else:
            raise ValueError('Semantic segmentation method not recognized:', self.cfg['segment'])

    def reconstruct_energy(self, event):
        '''
        Reconstructs energy deposition from charge
        '''
        if self.cfg['input'] == 'energy':
            self.output['energy'] = event['energy']

        elif self.cfg['response'] == 'constant':
            reco = self.cfg['response_cst'] * self.output['charge'][:, -1]
            self.output['energy'] = copy(self.output['charge'])
            self.output['energy'][:, -1] = reco

        elif self.cfg['response'] == 'average':
            self.output['energy'] = copy(self.output['charge'])
            self.output['energy'][:, -1] = self.cfg['response_average']

        elif self.cfg['response'] == 'full':
            raise NotImplementedError('Proper energy reconstruction not implemented yet')

        elif self.cfg['response'] == 'enet':
            raise NotImplementedError('ENet not implemented yet')

        else:
            raise ValueError('Energy reconstruction method not recognized:', self.cfg['response'])

    def find_shower_starts(self, event):
        '''
        Identify starting points of showers
        '''
        if self.cfg['shower_start'] == 'truth':
            # Get the true shower starting points from the particle information
            self.output['showers'] = []
            for i, part in enumerate(event['particles'][0]):
                if self.is_shower(part):
                    new_shower = Shower(start=[part.first_step().x(),
                                               part.first_step().y(),
                                               part.first_step().z()], pid=i)
                    self.output['showers'].append(new_shower)

        elif self.cfg['shower_start'] == 'ppn':
            raise NotImplementedError('PPN not implemented yet')

        else:
            raise ValueError('EM shower primary identification method not recognized:', self.cfg['shower_start'])

    def match_primary_fragments(self, event):
        '''
        For each shower start point, find the closest DBSCAN shower cluster
        '''
        # Mask out points that are not showers
        shower_mask = np.where(self.output['segment'][:, -1] == 2)[0]
        if not len(shower_mask):
            self.output['fragments'] = []
            return

        # Assign clusters
        points = np.array([s.start for s in self.output['showers']])
        clusts = self.frag_est.assign_frags_to_primary(self.output['energy'][shower_mask], points)
        if len(clusts) != len(points):
            self.output['fragments'] = []
            return

        # Return list of voxel indices for each cluster
        self.output['fragments'] = clusts

    def reconstruct_shower_directions(self, event):
        '''
        Reconstructs the direction of the showers
        '''
        if self.cfg['shower_dir'] == 'truth':
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                mom = [part.px(), part.py(), part.pz()]
                shower.direction = list(np.array(mom) / np.linalg.norm(mom))

        elif self.cfg['shower_dir'] == 'pca' or self.cfg['shower_dir'] == 'cent':
            # Apply DBSCAN, PCA on the touching cluster to get angles
            algo = self.cfg['shower_dir']
            mask = np.where(self.output['segment'][:, -1] == 2)[0]
            points = np.array([s.start for s in self.output['showers']])
            try:
                res = self.dir_est.get_directions(self.output['energy'][mask],
                                                  points, self.output['fragments'],
                                                  max_distance=float('inf'), mode=algo)
            except AssertionError as err:
                # Cluster was not found for at least one primary
                if self.verbose:
                    print('Error in direction reconstruction:', err)
                res = [[0., 0., 0.] for _ in range(len(self.output['showers']))]

            for i, shower in enumerate(self.output['showers']):
                shower.direction = res[i]

        else:
            raise ValueError('Shower direction reconstruction method not recognized:', self.cfg['shower_dir'])

    def reconstruct_shower_energy(self, event):
        '''
        Clusters the different showers, reconstructs energy of each shower
        '''
        if self.cfg['shower_energy'] == 'truth':
            # Gets the true energy information from Geant4
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                shower.energy = part.energy_init()
                pid = shower.pid
                mask = np.where(self.output['group'][:, -1] == pid)[0]
                shower.voxels = mask

        elif self.cfg['shower_energy'] == 'group':
            # Gets all the voxels in the group corresponding to the pid, adds up energy
            for shower in self.output['showers']:
                pid = shower.pid
                mask = np.where(self.output['group'][:, -1] == pid)[0]
                shower.voxels = mask
                shower.energy = np.sum(self.output['energy'][mask][:, -1])

        elif self.cfg['shower_energy'] == 'cone':
            # Fits cones to each shower, adds energies within that cone
            points = np.array([s.start for s in self.output['showers']])
            dirs = np.array([s.direction for s in self.output['showers']])
            mask = np.where(self.output['segment'][:, -1] == 2)[0]
            try:
                pred = self.clusterer.fit_predict(self.output['energy'][mask, :3],
                                                  points, self.output['fragments'], dirs)
            except (ValueError, AssertionError):
                for i, shower in enumerate(self.output['showers']):
                    shower.voxels = []
                    shower.energy = 0.
                return
            padded_pred = np.full(len(self.output['segment']), -1)
            padded_pred[mask] = pred
            for i, shower in enumerate(self.output['showers']):
                shower_mask = np.where(padded_pred == i)[0]
                if not len(shower_mask):
                    shower.energy = 0.
                    continue
                shower.voxels = shower_mask
                shower.energy = np.sum(self.output['energy'][shower_mask][:, -1])

        else:
            raise ValueError('Shower energy reconstruction method not recognized:', self.cfg['shower_energy'])

    def identify_pi0(self, event):
        '''
        Proposes pi0 candidates (match two showers)
        '''
        self.output['matches'] = []
        self.output['vertices'] = []
        n_showers = len(self.output['showers'])

        if self.cfg['shower_match'] == 'truth':
            # Get the creation point of each particle. If two gammas originate
            # from the same point, it is most likely a pi0 decay.
            creations = []
            for shower in self.output['showers']:
                part = event['particles'][0][shower.pid]
                creations.append([part.position().x(), part.position().y(), part.position().z()])

            for i, ci in enumerate(creations):
                for j in range(i + 1, n_showers):
                    if (np.array(ci) == np.array(creations[j])).all():
                        self.output['matches'].append([i, j])
                        self.output['vertices'].append(ci)

            return self.output['matches']

        elif self.cfg['shower_match'] == 'proximity':
            # Pair closest shower vectors
            points = np.array([s.start for s in self.output['showers']])
            dirs = np.array([s.direction for s in self.output['showers']])
            try:
                self.output['matches'], self.output['vertices'], dists = \
                    self.matcher.find_matches(points, dirs, self.output['segment'])
            except ValueError as err:
                if self.verbose:
                    print('Error in PID:', err)
                return

            if self.cfg['refit_dir']:
                for i, m in enumerate(self.output['matches']):
                    v = np.array(self.output['vertices'][i])
                    for j in m:
                        new_dir = np.array(points[j]) - v
                        self.output['showers'][j].direction = new_dir / np.linalg.norm(new_dir)
                if self.cfg['shower_energy'] == 'cone' and self.cfg['refit_cone']:
                    self.reconstruct_shower_energy(event)

        else:
            raise ValueError('Shower matching method not recognized:', self.cfg['shower_match'])

    def pi0_mass(self):
        '''
        Reconstructs the pi0 mass
        '''
        from math import sqrt
        masses = []
        for match in self.output['matches']:
            s1, s2 = self.output['showers'][match[0]], self.output['showers'][match[1]]
            e1, e2 = s1.energy, s2.energy
            t1, t2 = s1.direction, s2.direction
            costheta = np.dot(t1, t2)
            if abs(costheta) > 1.:
                masses.append(0.)
                continue
            masses.append(sqrt(2 * e1 * e2 * (1 - costheta)))
        self.output['masses'] = masses
        return masses

    def draw(self):
        from mlreco.visualization import plotly_layout3d
        from mlreco.visualization.voxels import scatter_voxels, scatter_label
        import plotly.plotly as py
        import plotly.graph_objs as go
        from plotly.offline import init_notebook_mode, iplot
        init_notebook_mode(connected=False)

        # Create labels for the voxels
        # Use a different color for each cluster
        labels = np.full(len(self.output['energy'][:, -1]), -1)
        for i, s in enumerate(self.output['showers']):
            labels[s.voxels] = i

        # Draw voxels with cluster labels
        voxels = self.output['energy'][:, :3]
        graph_voxels = scatter_label(voxels, labels, 2)[0]
        graph_voxels.name = 'Shower ID'
        graph_data = [graph_voxels]

        if len(self.output['showers']):
            # Add EM primary points
            points = np.array([s.start for s in self.output['showers']])
            graph_start = scatter_voxels(points)[0]
            graph_start.name = 'Shower starts'
            graph_data.append(graph_start)

            # Add EM primary directions
            dirs = np.array([s.direction for s in self.output['showers']])
            arrows = go.Cone(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                             u=dirs[:, 0], v=dirs[:, 1], w=dirs[:, 2],
                             sizemode='absolute', sizeref=0.25, anchor='tail',
                             showscale=False, opacity=0.4)
            graph_data.append(arrows)

            # Add a vertex if matches, join vertex to start points
            for i, m in enumerate(self.output['matches']):
                v = self.output['vertices'][i]
                s1, s2 = self.output['showers'][m[0]].start, self.output['showers'][m[1]].start
                points = [v, s1, v, s2]
                line = scatter_voxels(np.array(points))[0]
                line.name = 'Pi0 Decay'
                line.mode = 'lines,markers'
                graph_data.append(line)

        # Draw
        iplot(go.Figure(data=graph_data, layout=plotly_layout3d()))

    @staticmethod
    def is_shower(particle):
        '''
        Check if the particle is a shower
        '''
        pdg_code = abs(particle.pdg_code())
        if pdg_code == 22 or pdg_code == 11:
            return True
        return False
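# A minimal, self-contained sketch of the invariant-mass relation used by
# Pi0Chain.pi0_mass() above: for two photon showers with reconstructed energies
# e1, e2 and unit direction vectors t1, t2, m = sqrt(2*e1*e2*(1 - cos(theta))).
# The helper and the numbers below are illustrative only and are not part of
# the chain.
import numpy as np

def _pi0_invariant_mass(e1, e2, t1, t2):
    '''Two-photon invariant mass, in the same units as e1 and e2 (typically MeV).'''
    t1 = np.asarray(t1, dtype=float) / np.linalg.norm(t1)
    t2 = np.asarray(t2, dtype=float) / np.linalg.norm(t2)
    costheta = np.clip(np.dot(t1, t2), -1., 1.)
    return np.sqrt(2. * e1 * e2 * (1. - costheta))

# Two ~67.5 MeV photons emitted back-to-back reconstruct to ~135 MeV:
print(_pi0_invariant_mass(67.5, 67.5, [1., 0., 0.], [-1., 0., 0.]))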
def store_input(cfg, data_blob, res, logdir, iteration):
    """
    Store input data blob.

    Configuration
    -------------
    threshold: float, optional
        Default: 0.
    input_data: str, optional
    particles_label: str, optional
    segment_label: str, optional
    clusters_label: str, optional
    cluster3d_mcst_true: str, optional
    store_method: str, optional
        Can be `per-iteration` or `per-event`
    """
    method_cfg = cfg['post_processing']['store_input']

    if (method_cfg is not None and not method_cfg.get('input_data', 'input_data') in data_blob) \
            or (method_cfg is None and 'input_data' not in data_blob):
        return

    threshold = 0. if method_cfg is None else method_cfg.get('threshold', 0.)
    data_dim = 3 if method_cfg is None else method_cfg.get('data_dim', 3)

    index = data_blob.get('index', None)
    input_dat = data_blob.get('input_data' if method_cfg is None
                              else method_cfg.get('input_data', 'input_data'), None)
    label_ppn = data_blob.get('particles_label' if method_cfg is None
                              else method_cfg.get('particles_label', 'particles_label'), None)
    label_seg = data_blob.get('segment_label' if method_cfg is None
                              else method_cfg.get('segment_label', 'segment_label'), None)
    label_cls = data_blob.get('clusters_label' if method_cfg is None
                              else method_cfg.get('clusters_label', 'clusters_label'), None)
    label_mcst = data_blob.get('cluster3d_mcst_true' if method_cfg is None
                               else method_cfg.get('cluster3d_mcst_true', 'cluster3d_mcst_true'), None)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'input-iter-%07d.csv' % iteration))

    if input_dat is None:
        return

    for data_index, tree_index in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(os.path.join(logdir, 'input-event-%07d.csv' % tree_index))

        mask = input_dat[data_index][:, -1] > threshold

        # type 0 = input data
        for row in input_dat[data_index][mask]:
            coords_labels, coords = get_coords(row, data_dim, tree_index)
            fout.record(coords_labels + ('type', 'value'),
                        coords + (0, row[data_dim + 1]))
            fout.write()

        # type 1 = labels for PPN
        if label_ppn is not None:
            for row in label_ppn[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 1, row[4]))
                fout.write()

        # type 2 = UResNet labels
        if label_seg is not None:
            for row in label_seg[data_index][mask]:
                coords_labels, coords = get_coords(row, data_dim, tree_index)
                fout.record(coords_labels + ('type', 'value'),
                            coords + (2, row[data_dim + 1]))
                fout.write()

        # type 15 = group id, 16 = semantic labels, 17 = energy
        if label_cls is not None:
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 15, row[5]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 16, row[6]))
                fout.write()
            for row in label_cls[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 17, row[4]))
                fout.write()

        # type 19 = cluster3d_mcst_true
        if label_mcst is not None:
            for row in label_mcst[data_index]:
                fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                            (tree_index, row[0], row[1], row[2], 19, row[4]))
                fout.write()

        if not store_per_iteration:
            fout.close()

    if store_per_iteration:
        fout.close()
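
# Illustrative sketch of the configuration dictionary that store_input() above
# expects under cfg['post_processing']['store_input']. The key names are taken
# from the getters in the function body; the values shown here are made-up examples:
#
#   cfg['post_processing']['store_input'] = {
#       'threshold': 0.,                  # keep voxels with value > threshold
#       'data_dim': 3,                    # number of coordinate columns
#       'input_data': 'input_data',       # data_blob key for the charge tensor
#       'particles_label': 'particles_label',
#       'segment_label': 'segment_label',
#       'clusters_label': 'clusters_label',
#       'cluster3d_mcst_true': 'cluster3d_mcst_true',
#       'store_method': 'per-iteration'   # or 'per-event'
#   }
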
def uresnet_metrics(cfg, data_blob, res, logdir, iteration):
    # UResNet prediction
    if 'segmentation' not in res:
        return

    method_cfg = cfg['post_processing']['uresnet_metrics']

    index = data_blob['index']
    segment_data = res['segmentation']
    # input_data = data_blob.get('input_data' if method_cfg is None else method_cfg.get('input_data', 'input_data'), None)
    segment_label = data_blob.get('segment_label' if method_cfg is None
                                  else method_cfg.get('segment_label', 'segment_label'), None)
    num_classes = 5 if method_cfg is None else method_cfg.get('num_classes', 5)

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'uresnet-metrics-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(os.path.join(logdir, 'uresnet-metrics-event-%07d.csv' % tree_idx))

        predictions = np.argmax(segment_data[data_idx], axis=1)
        label = segment_label[data_idx][:, -1]

        acc = (predictions == label).sum() / float(len(label))

        class_acc = []
        pix = []
        for c1 in range(num_classes):
            for c2 in range(num_classes):
                class_mask = label == c1
                class_acc.append((predictions[class_mask] == c2).sum()
                                 / float(np.count_nonzero(class_mask)))
                pix.append(np.count_nonzero((label == c1) & (predictions == c2)))

        fout.record(('idx', 'acc')
                    + tuple(['confusion_%d_%d' % (c1, c2)
                             for c1 in range(num_classes) for c2 in range(num_classes)])
                    + tuple(['num_pix_%d_%d' % (c1, c2)
                             for c1 in range(num_classes) for c2 in range(num_classes)]),
                    (tree_idx, acc) + tuple(class_acc) + tuple(pix))
        fout.write()

        if not store_per_iteration:
            fout.close()

    if store_per_iteration:
        fout.close()
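
# Note on the metrics written by uresnet_metrics() above: for each pair of classes
# (c1, c2), 'confusion_c1_c2' is the fraction of voxels with true label c1 that were
# predicted as c2, and 'num_pix_c1_c2' is the corresponding raw voxel count.
# A minimal numpy sketch of the same quantities (illustrative only, toy labels):
#
#   >>> import numpy as np
#   >>> label = np.array([0, 0, 1, 1, 1])        # toy true labels
#   >>> predictions = np.array([0, 1, 1, 1, 0])  # toy predicted labels
#   >>> c1, c2 = 1, 1
#   >>> mask = label == c1
#   >>> (predictions[mask] == c2).sum() / float(np.count_nonzero(mask))  # 2/3
#   >>> np.count_nonzero((label == c1) & (predictions == c2))            # 2
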
def track_clustering(cfg, data_blob, res, logdir, iteration):
    """
    Track clustering on PPN+UResNet output.

    Parameters
    ----------
    data_blob: dict
        The input data dictionary from iotools.
    res: dict
        The output of the network, formatted using `analysis_keys`.
    cfg: dict
        Configuration.
    debug: bool, optional
        Whether to print some stats or not in the stdout.

    Notes
    -----
    Based on
    - semantic segmentation output
    - point position and type predictions
    In addition, includes a break algorithm to break track clusters into
    smaller clusters based on predicted points.
    Stores all points and information in a CSV file.
    """
    method_cfg = cfg['post_processing']['track_clustering']
    dbscan_cfg = cfg['model']['modules']['dbscan']
    data_dim = int(dbscan_cfg['data_dim'])
    min_samples = int(dbscan_cfg['minPoints'])

    debug = bool(method_cfg.get('debug', None))
    score_threshold = float(method_cfg.get('score_threshold', 0.6))
    threshold_association = float(method_cfg.get('threshold_association', 3))
    exclusion_radius = float(method_cfg.get('exclusion_radius', 5))
    type_threshold = float(method_cfg.get('type_threshold', 2))

    store_per_iteration = True
    if method_cfg is not None and method_cfg.get('store_method', None) is not None:
        assert method_cfg['store_method'] in ['per-iteration', 'per-event']
        store_per_iteration = method_cfg['store_method'] == 'per-iteration'
    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'track-clustering-iter-%07d.csv' % iteration))

    # Loop over batch index
    #for b in batch_ids:
    for batch_index, data in enumerate(data_blob['input_data']):

        event_index = data_blob['index'][batch_index]

        if not store_per_iteration:
            fout = CSVData(os.path.join(logdir, 'track-clustering-event-%07d.csv' % event_index))

        event_clusters = res['final'][batch_index]
        event_data = data[:, :data_dim]
        event_clusters_label = data_blob['clusters_label'][batch_index]
        event_particles_label = data_blob['particles_label'][batch_index]
        event_segmentation = res['segmentation'][batch_index]
        event_segmentation_label = data_blob['segment_label'][batch_index]
        points = res['points'][batch_index]
        event_xyz = points[:, :data_dim]
        event_scores = points[:, data_dim:data_dim + 2]
        event_mask = res['mask_ppn2'][batch_index]

        anchors = (event_data + 0.5)
        event_xyz = event_xyz + anchors

        dbscan_points = []
        predicted_points = []

        # 0) Post-processing on predicted pixels
        # Apply selection mask from PPN2 + score thresholding
        scores = scipy.special.softmax(event_scores, axis=1)
        event_mask = ((~(event_mask == 0)).any(axis=1)) & (scores[:, 1] > score_threshold)

        # Now loop through semantic classes and look at ppn+uresnet predictions
        uresnet_predictions = np.argmax(event_segmentation[event_mask], axis=1)
        num_classes = event_segmentation.shape[1]
        ppn_type_predictions = np.argmax(scipy.special.softmax(points[event_mask][:, 5:], axis=1),
                                         axis=1)
        for c in range(num_classes):
            uresnet_points = uresnet_predictions == c
            ppn_points = ppn_type_predictions == c
            # We want to keep only points of type X within 2px of uresnet prediction of type X
            d = scipy.spatial.distance.cdist(event_xyz[event_mask][ppn_points],
                                             event_data[event_mask][uresnet_points])
            ppn_mask = (d < type_threshold).any(axis=1)
            # dbscan_points stores coordinates only
            # predicted_points stores everything for each point
            dbscan_points.append(event_xyz[event_mask][ppn_points][ppn_mask])
            pp = points[event_mask][ppn_points][ppn_mask]
            pp[:, :3] += anchors[event_mask][ppn_points][ppn_mask]
            predicted_points.append(pp)
        dbscan_points = np.concatenate(dbscan_points, axis=0)
        predicted_points = np.concatenate(predicted_points, axis=0)
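
        # Assumed layout of each row of `predicted_points` at this stage (inferred
        # from the slicing above with data_dim == 3, not guaranteed): columns [0:3]
        # are the predicted point coordinates, already shifted to absolute positions
        # via `anchors`, and the last 5 columns are the per-class point-type scores,
        # which the delta-ray removal and type assignment below rely on.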
        # 0.5) Remove points for delta rays
        point_types = np.argmax(predicted_points[:, -5:], axis=1)
        dbscan_points = dbscan_points[point_types != 3]

        # 1) Break algorithm
        # Using PPN point predictions, the idea is to mask an area around
        # each point associated with a given predicted cluster. Dbscan
        # then tells us whether this predicted cluster should be broken in
        # two or more smaller clusters. Pixels that were masked are then
        # assigned to the closest cluster among the newly formed clusters.
        if dbscan_points.shape[0] > 0:  # If PPN predicted some points
            cluster_ids = np.unique(event_clusters[:, -1])
            final_clusters = []
            # Loop over predicted clusters
            for c in cluster_ids:
                # Find predicted points associated to this predicted cluster
                cluster = event_clusters[event_clusters[:, -1] == c][:, :data_dim]
                d = cdist(dbscan_points, cluster)
                index = d.min(axis=1) < threshold_association
                new_d = d[index.reshape((-1,)), :]
                # Now mask around these points
                new_index = (new_d > exclusion_radius).all(axis=0)
                # Main body of the cluster (far away from the points)
                new_cluster = cluster[new_index]
                # Cluster part around the points
                remaining_cluster = cluster[~new_index]
                # FIXME this might eliminate too small clusters?
                # put a threshold here? sometimes clusters with 1px only
                if new_cluster.shape[0] == 0:
                    continue

                # Now dbscan on the main body of the cluster to find if we need
                # to break it or not
                db2 = DBSCAN(eps=exclusion_radius, min_samples=min_samples).fit(new_cluster).labels_
                # All points were garbage
                if len(new_cluster[db2 == -1]) == len(new_cluster):
                    continue

                # These are going to be the new bodies of predicted clusters
                new_cluster_ids = np.unique(db2)
                new_clusters = []
                for c2 in new_cluster_ids:
                    if c2 > -1:
                        new_clusters.append([new_cluster[db2 == c2]])

                # If some points were left by dbscan, put them in remaining
                # cluster and assign them to closest cluster
                if len(new_cluster[db2 == -1]) > 0:
                    print(len(new_cluster[db2 == -1]), len(new_cluster))
                    remaining_cluster = np.concatenate([remaining_cluster, new_cluster[db2 == -1]],
                                                       axis=0)
                    # effectively remove them from new_cluster for the argmin
                    new_cluster[db2 == -1] = 100000

                # Now assign remaining pixels in remaining_cluster based on
                # their distance to the new clusters.
                # First we find which point of new_cluster was closest
                d3 = cdist(remaining_cluster, new_cluster)
                # Then we find what is the corresponding new cluster id of this
                # closest pixel
                remaining_db = db2[d3.argmin(axis=1)]
                # Now append each pixel of remaining_cluster to correct new
                # cluster
                for i, c in enumerate(remaining_cluster):
                    new_clusters[remaining_db[i]].append(c[None, :])

                # Turn everything into np arrays
                for i in range(len(new_clusters)):
                    new_clusters[i] = np.concatenate(new_clusters[i], axis=0)

                final_clusters.extend(new_clusters)
        else:
            # No predicted points: no need to break, keep predicted clusters
            final_clusters = []
            cluster_idx = np.unique(event_clusters[:, -1])
            for c in cluster_idx:
                final_clusters.append(event_clusters[event_clusters[:, -1] == c][:, :data_dim])

        # 2) Compute cluster efficiency/purity
        # ie associate final clusters after breaking with true clusters
        label_cluster_ids = np.unique(event_clusters_label[:, -1])
        true_clusters = []
        for c in label_cluster_ids:
            true_clusters.append(event_clusters_label[event_clusters_label[:, -1] == c][:, :-2])

        # Match each predicted cluster to a true cluster
        matches = []
        overlaps = []
        for predicted_cluster in final_clusters:
            overlap = []
            for true_cluster in true_clusters:
                overlap_pixel_count = np.count_nonzero(
                    (cdist(predicted_cluster, true_cluster) < 1).any(axis=0))
                overlap.append(overlap_pixel_count)
            overlap = np.array(overlap)
            if overlap.max() > 0:
                matches.append(overlap.argmax())
                overlaps.append(overlap.max())
            else:
                matches.append(-1)
                overlaps.append(0)

        # Compute cluster purity/efficiency
        purity, efficiency = [], []
        npix_predicted, npix_true = [], []
        for i, predicted_cluster in enumerate(final_clusters):
            if matches[i] > -1:
                matched_cluster = true_clusters[matches[i]]
                purity.append(overlaps[i] / predicted_cluster.shape[0])
                efficiency.append(overlaps[i] / matched_cluster.shape[0])
                npix_predicted.append(predicted_cluster.shape[0])
                npix_true.append(matched_cluster.shape[0])

        if debug:
            print("Purity: ", purity)
            print("Efficiency: ", efficiency)
            print("Match indices: ", matches)
            print("Overlaps: ", overlaps)
            print("Npix predicted: ", npix_predicted)
            print("Npix true: ", npix_true)

        # Record everything in CSV
        # Point in data and semantic class predictions/true information
        for i, point in enumerate(data):
            fout.record(('type', 'x', 'y', 'z', 'batch_id', 'value', 'predicted_class',
                         'true_class', 'cluster_id', 'point_type', 'idx'),
                        (0, point[0], point[1], point[2], batch_index, point[4],
                         np.argmax(event_segmentation[i]), event_segmentation_label[i, -1],
                         -1, -1, event_index))
            fout.write()
        # Predicted clusters
        for c, cluster in enumerate(final_clusters):
            for point in cluster:
                fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                             'predicted_class', 'true_class', 'point_type', 'idx'),
                            (1, point[0], point[1], point[2], batch_index, c,
                             -1, -1, -1, -1, event_index))
                fout.write()
        # True clusters
        for c, cluster in enumerate(true_clusters):
            for point in cluster:
                fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
                             'predicted_class', 'true_class', 'point_type', 'idx'),
                            (2, point[0], point[1], point[2], batch_index, c,
                             -1, -1, -1, -1, event_index))
                fout.write()
        # for point in event_xyz:
        #     fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
        #                  'predicted_class', 'true_class', 'point_type'),
        #                 (3, point[0], point[1], point[2], batch_index, -1, -1, -1, -1, -1))
        #     fout.write()
        # for point in event_clusters_label:
        #     fout.record(('type', 'x', 'y', 'z', 'batch_id', 'cluster_id', 'value',
        #                  'predicted_class', 'true_class', 'point_type'),
        #                 (4, point[0], point[1], point[2], batch_index, point[4], -1, -1, -1, -1))
        #     fout.write()
        # True PPN points
        for point in event_particles_label:
            fout.record(('type', 'x', 'y', 'z', 'batch_id', 'point_type', 'cluster_id',
                         'value', 'predicted_class', 'true_class', 'idx'),
                        (5, point[0], point[1], point[2], batch_index, point[4],
                         -1, -1, -1, -1, event_index))
            fout.write()
        # Predicted PPN points
        for point in predicted_points:
            fout.record(('type', 'x', 'y', 'z', 'batch_id', 'predicted_class', 'value',
                         'true_class', 'cluster_id', 'point_type', 'idx'),
                        (6, point[0], point[1], point[2], batch_index, np.argmax(point[-5:]),
                         -1, -1, -1, -1, event_index))
            fout.write()

        if not store_per_iteration:
            fout.close()

    if store_per_iteration:
        fout.close()
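
# Note on the purity/efficiency computed in track_clustering() above: for a
# predicted cluster matched to a true cluster, purity = overlap / (number of
# pixels in the predicted cluster) and efficiency = overlap / (number of pixels
# in the matched true cluster), where the overlap counts predicted pixels lying
# within 1 pixel of the true cluster. A toy example with made-up numbers: a
# predicted cluster of 80 pixels overlapping a 100-pixel true cluster on 72
# pixels gives purity = 72 / 80 = 0.9 and efficiency = 72 / 100 = 0.72.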