def get_item_specific(self, item: Tuple[str, int]):
    """ Loading method used in specific mode: ``item`` names an object and the index
        of the desired chunk within that object, e.g. ('obj1', 5) for chunk 5 of obj1.

    Returns:
        Tuple of (sample, vertex indices, source node), or (None, None, None) when the
        requested chunk index does not exist for that object.
    """
    # Only (object name, chunk index) tuples are valid requests in this mode.
    if not isinstance(item, tuple):
        raise ValueError(
            'In validation mode, items can only be requested with a tuple of object name and '
            'chunk index within that cloud.')
    name = item[0]
    chunk_ix = item[1]
    chunks = self._splitted_objs[name]
    # Files get processed sequentially here: only reload when a different object is requested.
    if self._curr_name != name:
        loaded = objects.load_obj(self._data_type, self._data + name + '.pkl')
        self._curr_obj = self._adapt_obj(loaded)
        self._curr_name = name
    # Out-of-range chunk indices (positive or negative) yield an empty result.
    if chunk_ix >= len(chunks) or abs(chunk_ix) > len(chunks):
        return None, None, None
    # Each chunk entry is (source_node, node_arr).
    source_node, node_arr = chunks[chunk_ix][0], chunks[chunk_ix][1]
    sample, idcs = objects.extract_cloud_subset(self._curr_obj, node_arr)
    return sample, idcs, source_node
def load_prediction(self, name: str):
    """ Load the cell ``name`` from disk and attach previously saved vertex
        predictions when a matching ``*_preds.pkl`` file exists at the save path. """
    cell_path = f'{self._data_path}{name}.pkl'
    self._curr_obj = objects.load_obj(self._datatype, cell_path)
    # In hybrid mode only the HybridCloud part of a CloudEnsemble is kept.
    if self._hybrid_mode and isinstance(self._curr_obj, CloudEnsemble):
        self._curr_obj = self._curr_obj.hc
    if self._label_remove is not None:
        self._curr_obj.remove_nodes(self._label_remove)
    self._curr_name = name
    pred_path = f'{self._save_path}{name}_preds.pkl'
    if os.path.exists(pred_path):
        saved_preds = basics.load_pkl(pred_path)
        self._curr_obj.set_predictions(saved_preds)
def smooth_dataset(input_path: str, output_path: str, data_type: str = 'ce'):
    """ Apply label smoothing to every pickled cell in ``input_path`` and save the
        smoothed objects under the same file names in ``output_path``.

    Args:
        input_path: Folder containing the dataset as pickle files.
        output_path: Target folder (created if missing).
        data_type: Type of dataset, 'ce': CloudEnsembles, 'hc': HybridClouds.
    """
    pkl_files = glob.glob(input_path + '*.pkl')
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    print("Starting to smooth labels of dataset...")
    for pkl_file in tqdm(pkl_files):
        # File stem = basename without the '.pkl' suffix.
        sep_positions = [ix for ix, char in enumerate(pkl_file) if char == '/']
        stem = pkl_file[sep_positions[-1] + 1:-4]
        cell = smooth_labels(objects.load_obj(data_type, pkl_file))
        cell.save2pkl(output_path + stem + '.pkl')
def preds2hc(preds: str):
    """ Load a prediction pickle and return the corresponding HybridCloud with
        prediction labels generated.

    Args:
        preds: Path to a prediction pickle. Entry 0 points at the original cell
            file, entry 1 holds the vertex predictions.
    """
    pred_file = os.path.expanduser(preds)
    payload = basics.load_pkl(pred_file)
    cell = objects.load_obj('ce', payload[0])
    cell.set_predictions(payload[1])
    cell.generate_pred_labels()
    hc = cell.hc if isinstance(cell, CloudEnsemble) else cell
    # Touch pred_node_labels before returning — presumably triggers lazy
    # generation of node-level predictions; TODO confirm against HybridCloud.
    _ = hc.pred_node_labels
    return hc
def visualize_data_set(input_path: str, output_path: str, random_seed: int = 4,
                       data_type: str = 'ce'):
    """ Render every pickled cell in ``input_path`` to a PNG of the same stem in
        ``output_path``.

    Args:
        input_path: Folder containing the dataset as pickle files.
        output_path: Folder where the captured images get saved (created if missing).
        random_seed: Seed forwarded to the visualization for reproducible camera views.
        data_type: Type of dataset, 'ce': CloudEnsembles, 'hc': HybridClouds.
    """
    pkl_files = glob.glob(input_path + '*.pkl')
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    for pkl_file in tqdm(pkl_files):
        # File stem = basename without the '.pkl' suffix.
        sep_positions = [ix for ix, char in enumerate(pkl_file) if char == '/']
        stem = pkl_file[sep_positions[-1] + 1:-4]
        cell = objects.load_obj(data_type, pkl_file)
        visualize_clouds([cell], capture=True, path=output_path + stem + '.png',
                         random_seed=random_seed)
def get_obj_info(self, name: str):
    """ Collect summary statistics (vertex/node counts, type/label histograms,
        length) for the object ``name``. """
    if self._specific:
        # Specific mode keeps nothing in memory — load the object individually.
        obj = self._adapt_obj(
            objects.load_obj(self._data_type, self._data + name + '.pkl'))
    else:
        # Non-specific mode caches all objects; fetch from the cache.
        obj = self._objs[self._obj_names.index(name)]
    return {
        'vertex_num': len(obj.vertices),
        'node_num': len(obj.nodes),
        'types': list(np.unique(obj.types, return_counts=True)),
        'labels': list(np.unique(obj.labels, return_counts=True)),
        'length': self.get_obj_length(name),
        'node_labels': list(np.unique(obj.node_labels, return_counts=True)),
    }
def visualize_prediction_set(input_path: str, output_path: str, gt_path: str,
                             random_seed: int = 4, data_type: str = 'ce',
                             save_to_image: bool = True):
    """ Saves images of all predicted files at input_path using the
    visualize_prediction method for each file.

    Args:
        input_path: Folder with prediction pickles named 'sso_*.pkl'.
        output_path: Folder where images get saved (created if missing).
        gt_path: Folder with the corresponding ground truth cells ('sso_*.pkl').
        random_seed: Seed forwarded to the visualization for reproducible views.
        data_type: Type of dataset, 'ce': CloudEnsembles, 'hc': HybridClouds.
        save_to_image: Forwarded to visualize_prediction.
    """
    files = glob.glob(input_path + 'sso_*.pkl')
    gt_files = glob.glob(gt_path + 'sso_*.pkl')
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    # find corresponding gt file, map predictions to labels and save images
    for file in tqdm(files):
        for gt_file in gt_files:
            slashs = [pos for pos, char in enumerate(gt_file) if char == '/']
            name = gt_file[slashs[-1] + 1:-4]
            # BUGFIX: plain substring matching (`name in file`) wrongly paired
            # e.g. gt 'sso_1' with prediction 'sso_12_...'. Require the name to
            # be followed by a delimiter ('.' or '_') so ids match exactly.
            if f'{name}.' in file or f'{name}_' in file:
                obj = objects.load_obj(data_type, gt_file)
                pred = basics.load_pkl(file)
                obj.set_predictions(pred[1])
                visualize_prediction(obj, output_path + name + '.png',
                                     random_seed=random_seed,
                                     save_to_image=save_to_image)
                # Each prediction file corresponds to exactly one gt file.
                break
def __init__(self,
             data: Union[str, SuperSegmentationDataset],
             sample_num: int,
             density_mode: bool = True,
             bio_density: float = None,
             tech_density: int = None,
             ctx_size: int = None,
             transform: clouds.Compose = clouds.Compose([clouds.Identity()]),
             specific: bool = False,
             data_type: str = 'ce',
             obj_feats: dict = None,
             label_mappings: List[Tuple[int, int]] = None,
             hybrid_mode: bool = False,
             splitting_redundancy: int = 1,
             label_remove: List[int] = None,
             sampling: bool = True,
             force_split: bool = False,
             padding: int = None,
             verbose: bool = False,
             split_on_demand: bool = False,
             split_jitter: int = 0,
             epoch_size: int = None,
             workers: int = 2,
             voxel_sizes: Optional[dict] = None,
             ssd_exclude: List[int] = None,
             ssd_include: List[int] = None,
             ssd_labels: str = None,
             exclude_borders: int = 0,
             rebalance: dict = None):
    """
    Args:
        data: Path to objects saved as pickle files, or a SuperSegmentationDataset.
            Existing chunking information would be available in the folder 'splitted'
            at this location.
        sample_num: Number of vertices which should be sampled from the surface of each
            chunk. Should be equal to the capacity of the given network architecture.
        density_mode: Choose between density-based chunking (True) and fixed
            context-size chunking (False).
        bio_density: chunk sampling density in point/um². This determines the size of
            the chunks. If previous chunking information should be used, this
            information must be available in the splitted/ folder with 'bio_density'
            as name.
        tech_density: poisson sampling density with which data set was preprocessed
            in point/um².
        ctx_size: Context (chunk) size; required when density_mode is False.
        transform: Transformations which should be applied to the chunks before
            returning them (e.g. see :func:`morphx.processing.clouds.Compose`).
        specific: Flag for setting mode of requesting specific or rather randomly
            drawn chunks.
        data_type: Type of dataset, 'ce': CloudEnsembles, 'hc': HybridClouds.
        obj_feats: Only used when inputs are CloudEnsembles. Dict with feature array
            (1, n) keyed by the name of the corresponding object in the CloudEnsemble.
            The HybridCloud gets addressed with 'hc'.
        label_mappings: list of labels which should get replaced by other labels.
            E.g. [(1, 2), (3, 2)] means that the labels 1 and 3 will get replaced by 2.
        hybrid_mode: Only use the HybridCloud part of a CloudEnsemble
            (assumption from attribute usage — TODO confirm).
        splitting_redundancy: indicates how many times each skeleton node is included
            in different contexts.
        label_remove: List of labels indicating which nodes should be removed from the
            dataset. This is independent from the label_mappings, as the label removal
            is done during splitting.
        sampling: Flag for random sampling from the extracted subsets.
        force_split: Split dataset again even if splitting information exists.
        padding: add padded points if a subset contains less points than there should
            be sampled.
        verbose: Return additional information about size of subsets.
        split_on_demand: Do not generate splitting information in advance, but rather
            generate chunks on the fly.
        split_jitter: Used only if split_on_demand = True. Adds jitter to the context
            size of the generated chunks.
        epoch_size: Parameter for epoch size that can be used when dataset size is
            unknown and epoch size should somehow be bounded.
        workers: Number of workers in case of ssd dataset.
        voxel_sizes: Voxelization options in case of ssd dataset use. Given as dict
            with voxel sizes keyed by cell part identifier (e.g. 'sv' or 'mi').
        ssd_exclude: ssv ids to skip when enqueueing ssd objects.
        ssd_include: ssv ids to use; defaults to the 200 largest ssvs when None.
        ssd_labels: Label identifier for ssd datasets; required when data is a
            SuperSegmentationDataset.
        exclude_borders: Offset radius (chunk_size - exclude_border) for excluding
            border regions of chunks from loss calculation.
        rebalance: dict for rebalancing of dataset if certain classes dominate. dict
            contains factor keyed by labels where the factor indicates how often the
            labels should get resampled.
    """
    # Dataset source: either an ssd dataset object or a path to pickled cells.
    if type(data) == SuperSegmentationDataset:
        self._data = data
    else:
        self._data = os.path.expanduser(data)
        if not os.path.exists(self._data):
            os.makedirs(self._data)
    # --- split cells into chunks and save this split information to file for later loading ---
    if not split_on_demand:
        if not os.path.exists(self._data + 'splitted/'):
            os.makedirs(self._data + 'splitted/')
        self._splitfile = ''
        # Split-file name encodes the chunking parameters so different
        # configurations get distinct cache files.
        if density_mode:
            if bio_density is None or tech_density is None:
                raise ValueError(
                    "Density mode requires bio_density and tech_density")
            self._splitfile = f'{self._data}splitted/d{bio_density}_p{sample_num}' \
                              f'_r{splitting_redundancy}_lr{label_remove}.pkl'
        else:
            if ctx_size is None:
                raise ValueError("Context mode requires chunk_size.")
            self._splitfile = f'{self._data}splitted/s{ctx_size}_r{splitting_redundancy}_lr{label_remove}.pkl'
        self._splitted_objs = None
        orig_splitfile = self._splitfile
        while os.path.exists(self._splitfile):
            if not force_split:
                # continue with existing split information
                with open(self._splitfile, 'rb') as f:
                    self._splitted_objs = pickle.load(f)
                f.close()  # redundant: 'with' already closed the file
                break
            else:
                # generate new split information without overriding the old:
                # bump the version suffix until a free file name is found
                version = re.findall(r"v(\d+).", self._splitfile)
                if len(version) == 0:
                    self._splitfile = self._splitfile[:-4] + '_v1.pkl'
                else:
                    version = int(version[0])
                    self._splitfile = orig_splitfile[:-4] + f'_v{version + 1}.pkl'
        # actual splitting happens here (existing splits are passed in so only
        # missing objects get processed — assumption, TODO confirm in splitting.split)
        splitting.split(data, self._splitfile, bio_density=bio_density,
                        capacity=sample_num, tech_density=tech_density,
                        density_splitting=density_mode, chunk_size=ctx_size,
                        splitted_hcs=self._splitted_objs,
                        redundancy=splitting_redundancy, label_remove=label_remove,
                        split_jitter=split_jitter)
        with open(self._splitfile, 'rb') as f:
            self._splitted_objs = pickle.load(f)
        f.close()  # redundant: 'with' already closed the file
    # Default voxel sizes per cell part, overridable by the caller.
    self._voxel_sizes = dict(sv=80, mi=100, syn_ssv=100, vc=100)
    if voxel_sizes is not None:
        self._voxel_sizes = voxel_sizes
    self._sample_num = sample_num
    self._transform = transform
    self._specific = specific
    self._data_type = data_type
    self._obj_feats = obj_feats
    self._label_mappings = label_mappings
    self._hybrid_mode = hybrid_mode
    self._label_remove = label_remove
    self._sampling = sampling
    self._padding = padding
    self._verbose = verbose
    self._split_on_demand = split_on_demand
    self._bio_density = bio_density
    self._tech_density = tech_density
    self._density_mode = density_mode
    self._chunk_size = ctx_size
    self._splitting_redundancy = splitting_redundancy
    self._split_jitter = split_jitter
    self._epoch_size = epoch_size
    self._workers = workers
    self._ssd_labels = ssd_labels
    self._ssd_exclude = ssd_exclude
    self._rebalance = rebalance
    self._exclude_borders = exclude_borders
    if ssd_exclude is None:
        self._ssd_exclude = []
    self._ssd_include = ssd_include
    if self._ssd_labels is None and type(
            self._data) == SuperSegmentationDataset:
        raise ValueError(
            "ssd_labels must be specified when working with a SuperSegmentationDataset!")
    self._obj_names = []
    self._objs = []
    self._chunk_list = []
    self._parts = {}
    # Dispatch the item loader according to dataset type / mode.
    if type(data) == SuperSegmentationDataset:
        self._load_func = self.get_item_ssd
    elif self._specific:
        self._load_func = self.get_item_specific
    else:
        self._load_func = self.get_item
    # --- dataloader for experiments when using CMN predictions as ground truth ---
    if type(self._data) == SuperSegmentationDataset:
        for key in self._obj_feats:
            self._parts[key] = [
                self._voxel_sizes[key], self._obj_feats[key]
            ]
        # If ssd dataset is given, multiple workers are used for splitting the ssvs of the given dataset.
        self._obj_names = Queue()
        self._chunk_list = Queue(maxsize=10000)
        if self._ssd_include is None:
            # Default to the 200 largest ssvs of the dataset.
            sizes = [sso.size for sso in self._data.ssvs]
            idcs = np.argsort(sizes)
            self._ssd_include = np.array(self._data.ssv_ids)[idcs[-200:]]
        for ssv in self._ssd_include:
            if ssv not in self._ssd_exclude:
                self._obj_names.put(ssv)
        self._splitters = [
            Process(target=worker_split,
                    args=(self._obj_names, self._chunk_list, self._data,
                          self._chunk_size,
                          self._chunk_size / self._splitting_redundancy,
                          self._parts, self._ssd_labels,
                          self._label_mappings, self._split_jitter))
            for ix in range(workers)
        ]
        for splitter in self._splitters:
            splitter.start()
    # --- dataloader for experiments with cells saved as pickle files ---
    else:
        files = glob.glob(data + '*.pkl')
        for file in files:
            # File stem = basename without the '.pkl' suffix.
            slashs = [pos for pos, char in enumerate(file) if char == '/']
            name = file[slashs[-1] + 1:-4]
            self._obj_names.append(name)
            if not self._specific:
                # load entire dataset into memory
                obj = self._adapt_obj(
                    objects.load_obj(self._data_type, file))
                self._objs.append(obj)
        if not self._specific:
            if split_on_demand:
                # do not use split information from file but split cells on the fly
                for ix, obj in enumerate(tqdm(self._objs)):
                    base_nodes = np.arange(len(obj.nodes)).reshape(
                        -1, 1)[obj.node_labels != -1]
                    # Sample a third of the labeled nodes (with replacement) as
                    # context centers.
                    base_nodes = np.random.choice(base_nodes,
                                                  int(len(base_nodes) / 3),
                                                  replace=True)
                    chunks = context_splitting_kdt(obj, base_nodes,
                                                   self._chunk_size)
                    for chunk in chunks:
                        self._chunk_list.append((ix, chunk))
            else:
                # use split information from file
                for item in self._splitted_objs:
                    if item in self._obj_names:
                        for idx in range(len(self._splitted_objs[item])):
                            self._chunk_list.append((item, idx))
            if self._rebalance is not None:
                # rebalance occurence of chunks by using chunks which contain
                # specific labels multiple times
                print("Rebalancing...")
                balance = {}
                for key in self._rebalance:
                    balance[key] = 0
                for ix in tqdm(range(len(self._chunk_list))):
                    item = self._chunk_list[ix]
                    obj = self._objs[self._obj_names.index(item[0])]
                    for key in self._rebalance:
                        if key in np.unique(obj.labels):
                            for i in range(self._rebalance[key]):
                                self._chunk_list.append(item)
                                # counts every appended copy — TODO confirm
                                # intended granularity of this counter
                                balance[key] += 1
                print("Done with rebalancing!")
                print(balance)
            random.shuffle(self._chunk_list)
    self._curr_obj = None
    self._curr_name = None
    self._ix = 0
    # NOTE(review): for ssd datasets _chunk_list is a multiprocessing Queue,
    # which does not support len() — confirm how _size is meant to be derived
    # in that mode (possibly via epoch_size).
    self._size = len(self._chunk_list)
def evaluate_cell_predictions(prediction: str,
                              total_evaluation: dict = None,
                              evaluation_mode: str = 'mv',
                              label_names: list = None,
                              skeleton_smoothing: bool = False,
                              remove_unpredicted: bool = True,
                              data_type: str = 'obj',
                              label_mapping: List[Tuple[int, int]] = None,
                              label_remove: List[int] = None) -> tuple:
    """ Evaluate a single cell prediction against its ground truth, on both the
        vertex level and the skeleton-node level.

    See `evaluate_prediction_set` for argument descriptions.

    Returns:
        Tuple of (reports dict with sklearn classification reports keyed by mode,
        human-readable report text).
    """
    reports = {}
    reports_txt = ""
    prediction = os.path.expanduser(prediction)
    # --- load predictions into corresponding cell (predictions contain pointer to original cell file) ---
    vertex_predictions = basics.load_pkl(prediction)
    obj = objects.load_obj(data_type, vertex_predictions[0])
    if label_remove is not None:
        obj.remove_nodes(label_remove)
    obj.set_predictions(vertex_predictions[1])
    reports['pred_num'] = obj.pred_num
    if label_mapping is not None:
        obj.map_labels(label_mapping)
    # --- reduce multiple predictions to single label by either:
    # d: taking first prediction as label
    # mv: taking majority vote on all predictions as new label
    # mvs: taking majority vote on all predictions as label and apply smoothing ---
    if evaluation_mode == 'd':
        obj.generate_pred_labels(False)
    elif evaluation_mode == 'mv':
        obj.generate_pred_labels()
    elif evaluation_mode == 'mvs':
        obj.generate_pred_labels()
        obj.hc.prediction_smoothing()
    else:
        raise ValueError(f"Mode {evaluation_mode} is not known.")
    # Work on the HybridCloud part of a CloudEnsemble.
    if isinstance(obj, CloudEnsemble):
        hc = obj.hc
    else:
        hc = obj
    if len(hc.pred_labels) != len(hc.labels):
        raise ValueError(
            "Length of predicted label array doesn't match with length of label array."
        )
    coverage = hc.get_coverage()
    reports['cov'] = coverage
    # --- vertex evaluation ---
    vertex_labels = hc.labels
    vertex_predictions = hc.pred_labels
    if remove_unpredicted:
        # Drop vertices that are unlabeled or received no prediction (-1).
        mask = np.logical_and(vertex_labels != -1, vertex_predictions != -1)
        vertex_labels, vertex_predictions = vertex_labels[
            mask], vertex_predictions[mask]
    targets = get_target_names(vertex_labels, vertex_predictions, label_names)
    reports[evaluation_mode] = sm.classification_report(vertex_labels,
                                                        vertex_predictions,
                                                        output_dict=True,
                                                        target_names=targets)
    # coverage appears to be (uncovered, total) — TODO confirm in get_coverage
    reports_txt += evaluation_mode + '\n\n' + \
        f'Coverage: {coverage[1] - coverage[0]} of {coverage[1]}, ' \
        f'{round((1 - coverage[0] / coverage[1]) * 100)} %\n\n' + \
        f'Number of predictions: {obj.pred_num}\n\n' + \
        sm.classification_report(vertex_labels, vertex_predictions,
                                 target_names=targets) + '\n\n'
    cm = sm.confusion_matrix(vertex_labels, vertex_predictions)
    reports_txt += write_confusion_matrix(cm, targets) + '\n\n'
    # --- skeleton evaluation ---
    evaluation_mode += '_skel'
    if skeleton_smoothing:
        hc.node_sliding_window_bfs(neighbor_num=20)
    node_labels = hc.node_labels
    node_predictions = hc.pred_node_labels
    if remove_unpredicted:
        # Drop skeleton nodes that are unlabeled or received no prediction (-1).
        mask = np.logical_and(node_labels != -1, node_predictions != -1)
        node_labels, node_predictions = node_labels[mask], node_predictions[
            mask]
    targets = get_target_names(node_labels, node_predictions, label_names)
    reports[evaluation_mode] = sm.classification_report(node_labels,
                                                        node_predictions,
                                                        output_dict=True,
                                                        target_names=targets)
    reports_txt += evaluation_mode + '\n\n' + sm.classification_report(
        node_labels, node_predictions, target_names=targets) + '\n\n'
    cm = sm.confusion_matrix(node_labels, node_predictions)
    reports_txt += write_confusion_matrix(cm, targets) + '\n\n'
    # --- save generated labels for total evaluation ---
    if total_evaluation is not None:
        total_evaluation['vertex_predictions'] = np.append(
            total_evaluation['vertex_predictions'], vertex_predictions)
        total_evaluation['node_predictions'] = np.append(
            total_evaluation['node_predictions'], node_predictions)
        total_evaluation['vertex_labels'] = np.append(
            total_evaluation['vertex_labels'], vertex_labels)
        total_evaluation['node_labels'] = np.append(
            total_evaluation['node_labels'], node_labels)
        total_evaluation['coverage'][0] += coverage[0]
        total_evaluation['coverage'][1] += coverage[1]
    return reports, reports_txt
"Recall: What percentage of points from one true label have been predicted as that label \n\n" reports = {} target_names = [ 'dendrite', 'axon', 'soma', 'bouton', 'terminal', 'neck', 'head' ] for sso_id in sso_ids: reports_txt += str(sso_id) + '\n\n' # 0: axon, 1: bouton, 2: terminal with open( base_path + f'abt/20_09_27_test_eval_red{red}_valiter1_batchsize-1/epoch_370/sso_' + str(sso_id) + '_preds.pkl', 'rb') as f: abt_preds = pkl.load(f) abt = objects.load_obj('ce', abt_preds[0]) abt.set_predictions(abt_preds[1]) abt.generate_pred_labels() # 0: dendrite, 1: neck, 2: head with open( base_path + f'dnh/20_09_27_test_eval_red{red}_valiter1_batchsize-1/epoch_390/sso_' + str(sso_id) + '_preds.pkl', 'rb') as f: dnh_preds = pkl.load(f) dnh = objects.load_obj('ce', dnh_preds[0]) dnh.set_predictions(dnh_preds[1]) dnh.generate_pred_labels() # 0: dendrite, 1: axon, 2: soma, with open( base_path + f'ads/20_09_27_test_eval_red{red}_nokdt_valiter1_batchsize-1/epoch_760/sso_'