def run_training(learn: Learner, resume_ckpt: Path = None, min_lr=0.05, head_runs=1, full_runs=1):
    "Fine-tune `learn` for `head_runs` frozen and `full_runs` unfrozen epochs, optionally resuming from `resume_ckpt`."
    if resume_ckpt:
        print(f'Loading {resume_ckpt}...')
        try:
            learn.model.load_state_dict(torch.load(resume_ckpt))
        except Exception as e:
            print(f'Error while trying to load {resume_ckpt}: {e}')
    monitor.display_average_stats_per_gpu()
    print(f"Training for {head_runs}+{full_runs} epochs at min LR {min_lr}")
    learn.fine_tune(full_runs, min_lr, freeze_epochs=head_runs)
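
# Usage sketch: `learn` is any fastai `Learner`; `monitor` is assumed to be the
# GPU-stats helper referenced above, and the checkpoint path below is
# hypothetical. Without `resume_ckpt`, training simply starts from the current
# model weights.
#
#   run_training(learn, resume_ckpt=Path('models/last.pth'),
#                min_lr=0.005, head_runs=1, full_runs=4)
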
class EnsembleLearner(GetAttr):
    "Meta-learner that trains, stores, and evaluates an ensemble of segmentation models on k-fold splits."
    _default = 'config'

    def __init__(self, image_dir='images', mask_dir=None, config=None, path=None,
                 ensemble_path=None, preproc_dir=None, label_fn=None, metrics=None,
                 cbs=None, ds_kwargs={}, dl_kwargs={}, model_kwargs={}, stats=None, files=None):
        self.config = config or Config()
        self.stats = stats
        self.dl_kwargs = dl_kwargs
        self.model_kwargs = model_kwargs
        self.add_ds_kwargs = ds_kwargs
        self.path = Path(path) if path is not None else Path('.')
        default_metrics = [Dice()] if self.n_classes == 2 else [DiceMulti()]
        self.metrics = metrics or default_metrics
        self.loss_fn = self.get_loss()
        self.cbs = cbs or [SaveModelCallback(monitor='dice' if self.n_classes == 2 else 'dice_multi')]  # ShowGraphCallback
        self.ensemble_dir = ensemble_path or self.path / self.ens_dir
        if ensemble_path is not None:
            ensemble_path.mkdir(exist_ok=True, parents=True)
            self.load_ensemble(path=ensemble_path)
        else:
            self.models = {}

        self.files = L(files) or get_image_files(self.path / image_dir, recurse=False)
        assert len(self.files) > 0, f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder'
        if any([mask_dir, label_fn]):
            if label_fn:
                self.label_fn = label_fn
            else:
                self.label_fn = get_label_fn(self.files[0], self.path / mask_dir)
            # Check that a corresponding mask exists for every image
            mask_check = [self.label_fn(x).exists() for x in self.files]
            chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
            assert len(self.files) == sum(mask_check) and len(self.files) > 0, \
                f'Please check your images and masks (and folders). {chk_str}'
            print(chk_str)
        else:
            self.label_fn = label_fn

        self.n_splits = min(len(self.files), self.max_splits)
        self._set_splits()
        self.ds = RandomTileDataset(self.files,
                                    label_fn=self.label_fn,
                                    preproc_dir=preproc_dir,
                                    instance_labels=self.instance_labels,
                                    n_classes=self.n_classes,
                                    stats=self.stats,
                                    normalize=True,
                                    sample_mult=self.sample_mult if self.sample_mult > 0 else None,
                                    verbose=0,
                                    **self.add_ds_kwargs)
        self.stats = self.ds.stats
        self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]
        self.df_val, self.df_ens, self.df_model, self.ood = None, None, None, None
        self.recorder = {}

    def _set_splits(self):
        if self.n_splits > 1:
            kf = KFold(self.n_splits, shuffle=True, random_state=self.random_state)
            self.splits = {key: (self.files[idx[0]], self.files[idx[1]])
                           for key, idx in zip(range(1, self.n_splits + 1), kf.split(self.files))}
        else:
            # Single split: train and validate on the same file list
            self.splits = {1: (self.files, self.files)}

    def _compose_albumentations(self, **kwargs):
        return _compose_albumentations(**kwargs)

    @property
    def pred_ds_kwargs(self):
        # Set default shapes and padding for prediction
        ds_kwargs = self.add_ds_kwargs.copy()
        ds_kwargs['use_preprocessed_labels'] = True
        ds_kwargs['preproc_dir'] = self.ds.preproc_dir
        ds_kwargs['instance_labels'] = self.instance_labels
        ds_kwargs['tile_shape'] = (self.tile_shape,) * 2
        ds_kwargs['n_classes'] = self.n_classes
        ds_kwargs['shift'] = self.shift
        ds_kwargs['border_padding_factor'] = self.border_padding_factor
        return ds_kwargs

    @property
    def train_ds_kwargs(self):
        # Set default shapes and padding for training (settings from config)
        ds_kwargs = self.add_ds_kwargs.copy()
        ds_kwargs['use_preprocessed_labels'] = True
        ds_kwargs['preproc_dir'] = self.ds.preproc_dir
        ds_kwargs['instance_labels'] = self.instance_labels
        ds_kwargs['stats'] = self.stats
        ds_kwargs['tile_shape'] = (self.tile_shape,) * 2
        ds_kwargs['n_classes'] = self.n_classes
        ds_kwargs['shift'] = 1.
        ds_kwargs['border_padding_factor'] = 0.
        ds_kwargs['flip'] = self.flip
        ds_kwargs['albumentations_tfms'] = self._compose_albumentations(**self.albumentation_kwargs)
        ds_kwargs['sample_mult'] = self.sample_mult if self.sample_mult > 0 else None
        return ds_kwargs

    @property
    def model_name(self):
        return f'{self.arch}_{self.encoder_name}_{self.n_classes}classes'

    def get_loss(self):
        kwargs = {'mode': self.mode,
                  'classes': [x for x in range(1, self.n_classes)],
                  'smooth_factor': self.loss_smooth_factor,
                  'alpha': self.loss_alpha,
                  'beta': self.loss_beta,
                  'gamma': self.loss_gamma}
        return get_loss(self.loss, **kwargs)

    def _get_dls(self, files, files_val=None):
        ds = []
        ds.append(RandomTileDataset(files, label_fn=self.label_fn, **self.train_ds_kwargs))
        if files_val:
            ds.append(TileDataset(files_val, label_fn=self.label_fn, **self.train_ds_kwargs))
        else:
            ds.append(ds[0])
        dls = DataLoaders.from_dsets(*ds, bs=self.batch_size, pin_memory=True, **self.dl_kwargs)
        if torch.cuda.is_available():
            dls.cuda()
        return dls

    def _create_model(self):
        model = create_smp_model(arch=self.arch,
                                 encoder_name=self.encoder_name,
                                 encoder_weights=self.encoder_weights,
                                 in_channels=self.in_channels,
                                 classes=self.n_classes,
                                 **self.model_kwargs)
        if torch.cuda.is_available():
            model.cuda()
        return model

    def fit(self, i, n_iter=None, base_lr=None, **kwargs):
        n_iter = n_iter or self.n_iter
        base_lr = base_lr or self.base_lr
        name = self.ensemble_dir / f'{self.model_name}-fold{i}.pth'
        model = self._create_model()
        files_train, files_val = self.splits[i]
        dls = self._get_dls(files_train, files_val)
        log_name = f'{name.name}_{time.strftime("%Y%m%d-%H%M%S")}.csv'
        log_dir = self.ensemble_dir / 'logs'
        log_dir.mkdir(exist_ok=True, parents=True)
        # Do not mutate self.cbs here: list.append returns None, and repeated
        # fits would otherwise accumulate one CSVLogger per fold.
        cbs = self.cbs + [CSVLogger(fname=log_dir / log_name)]
        self.learn = Learner(dls, model,
                             metrics=self.metrics,
                             wd=self.weight_decay,
                             loss_func=self.loss_fn,
                             opt_func=_optim_dict[self.optim],
                             cbs=cbs)
        self.learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mixed_precision_training:
            self.learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter, ds_length=len(dls.train_ds), bs=self.batch_size)
        # self.learn.fit_one_cycle(epochs, lr_max)
        self.learn.fine_tune(epochs, base_lr=base_lr)
        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        save_smp_model(self.learn.model, self.arch, name, stats=self.stats)
        self.models[i] = name
        self.recorder[i] = self.learn.recorder

    def fit_ensemble(self, n_iter, skip=False, **kwargs):
        for i in range(1, self.n_models + 1):
            if skip and (i in self.models):
                continue
            self.fit(i, n_iter, **kwargs)

    def set_n(self, n):
        for i in range(n, len(self.models)):
            self.models.pop(i + 1, None)
        self.n_models = n

    def get_valid_results(self, model_no=None, zarr_store=None, export_dir=None, filetype='.png', **kwargs):
        res_list = []
        model_list = self.models if not model_no else {k: v for k, v in self.models.items() if k == model_no}
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)
        for i, model_path in model_list.items():
            ep = EnsemblePredict(models_paths=[model_path], zarr_store=zarr_store)
            _, files_val = self.splits[i]
            g_smx, g_std, g_eng = ep.predict_images(files_val, bs=self.batch_size,
                                                    ds_kwargs=self.pred_ds_kwargs, **kwargs)
            del ep
            torch.cuda.empty_cache()
            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(files_val):
                msk = self.ds.get_data(f, mask=True)[0]
                pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
                m_dice = dice_score(msk, pred)
                m_path = self.models[i].name
                df_tmp = pd.Series({
                    'file': f.name,
                    'model': m_path,
                    'model_no': i,
                    'dice_score': m_dice,
                    # 'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                    'uncertainty_score': np.mean(g_std[f.name][:][pred > 0]) if g_std is not None else None,
                    'image_path': f,
                    'mask_path': self.label_fn(f),
                    'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None,
                    'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None})
                res_list.append(df_tmp)
                if export_dir:
                    save_mask(pred, pred_path / f'{df_tmp.file}_{df_tmp.model}_mask', filetype)
                    if g_std is not None:
                        save_unc(g_std[f.name][:], unc_path / f'{df_tmp.file}_{df_tmp.model}_uncertainty', filetype)
                    if g_eng is not None:
                        save_unc(g_eng[f.name][:], unc_path / f'{df_tmp.file}_{df_tmp.model}_energy', filetype)
        self.df_val = pd.DataFrame(res_list)
        if export_dir:
            self.df_val.to_csv(export_dir / 'val_results.csv', index=False)
            self.df_val.to_excel(export_dir / 'val_results.xlsx')
        return self.df_val

    def show_valid_results(self, model_no=None, files=None, **kwargs):
        if self.df_val is None:
            self.get_valid_results(**kwargs)
        df = self.df_val
        if files is not None:
            df = df.set_index('file', drop=False).loc[files]
        if model_no is not None:
            df = df[df.model_no == model_no]
        for _, r in df.iterrows():
            img = self.ds.get_data(r.image_path)[0][:]
            msk = self.ds.get_data(r.image_path, mask=True)[0]
            pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
            std = zarr.load(r.uncertainty_path)
            _d_model = f'Model {r.model_no}'
            if self.tta:
                plot_results(img, msk, pred, std, df=r, model=_d_model)
            else:
                plot_results(img, msk, pred, np.zeros_like(pred), df=r, model=_d_model)

    def load_ensemble(self, path=None):
        path = path or self.ensemble_dir
        models = sorted(get_files(path, extensions='.pth', recurse=False))
        self.models = {}
        for i, m in enumerate(models, 1):
            # enumerate starts at 1, so read n_classes from the first model
            if i == 1:
                self.n_classes = int(m.name.split('_')[2][0])
            else:
                assert self.n_classes == int(m.name.split('_')[2][0]), \
                    'Check models. Models are trained on different number of classes.'
            self.models[i] = m
        if len(self.models) > 0:
            self.set_n(len(self.models))
            print(f'Found {len(self.models)} models in folder {path}:')
            print([m.name for m in self.models.values()])
            # Reset stats from the first model
            print(f'Loading stats from {self.models[1].name}')
            _, self.stats = load_smp_model(self.models[1])

    def get_ensemble_results(self, files, zarr_store=None, export_dir=None, filetype='.png', **kwargs):
        ep = EnsemblePredict(models_paths=self.models.values(), zarr_store=zarr_store)
        g_smx, g_std, g_eng = ep.predict_images(files, bs=self.batch_size,
                                                ds_kwargs=self.pred_ds_kwargs, **kwargs)
        chunk_store = g_smx.chunk_store.path
        del ep
        torch.cuda.empty_cache()
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)
        res_list = []
        for f in files:
            pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
            df_tmp = pd.Series({
                'file': f.name,
                'ensemble': self.model_name,
                'n_models': len(self.models),
                # 'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                'uncertainty_score': np.mean(g_std[f.name][:][pred > 0]) if g_std is not None else None,
                'image_path': f,
                'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None,
                'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None})
            res_list.append(df_tmp)
            if export_dir:
                save_mask(pred, pred_path / f'{df_tmp.file}_{df_tmp.ensemble}_mask', filetype)
                if g_std is not None:
                    save_unc(g_std[f.name][:], unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_unc', filetype)
                if g_eng is not None:
                    save_unc(g_eng[f.name][:], unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_energy', filetype)
        self.df_ens = pd.DataFrame(res_list)
        return g_smx, g_std, g_eng

    def score_ensemble_results(self, mask_dir=None, label_fn=None):
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0], self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            pred = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
            self.df_ens.loc[i, 'dice_score'] = dice_score(msk, pred)
        return self.df_ens

    def show_ensemble_results(self, files=None, unc=True, unc_metric=None, metric_name='dice_score'):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens
        if files is not None:
            df = df.reset_index().set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                try:
                    msk = self.ds.labels[r.file][:]
                except:
                    msk = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
                imgs.append(msk)
                hastarget = True
            else:
                hastarget = False
            imgs.append(np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8'))
            if unc:
                imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric)

    def get_cellpose_results(self, export_dir=None):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        cl = self.cellpose_export_class
        assert cl < self.n_classes, f'{cl} not available from {self.n_classes} classes'
        smxs, preds = [], []
        for _, r in self.df_ens.iterrows():
            softmax = zarr.load(r.softmax_path)
            smxs.append(softmax)
            preds.append(np.argmax(softmax, axis=-1).astype('uint8'))
        probs = [x[..., cl] for x in smxs]
        masks = [x == cl for x in preds]
        cp_masks = run_cellpose(probs, masks,
                                model_type=self.cellpose_model,
                                diameter=self.cellpose_diameter,
                                min_size=self.min_pixel_export,
                                gpu=torch.cuda.is_available())
        if export_dir:
            export_dir = Path(export_dir)
            cp_path = export_dir / 'cellpose_masks'
            cp_path.mkdir(parents=True, exist_ok=True)
            for idx, r in self.df_ens.iterrows():
                tifffile.imwrite(cp_path / f'{r.file}_class{cl}.tif', cp_masks[idx], compress=6)
        self.cellpose_masks = cp_masks
        return cp_masks

    def score_cellpose_results(self, mask_dir=None, label_fn=None):
        assert self.cellpose_masks is not None, 'Run get_cellpose_results() first'
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0], self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            _, msk = cv2.connectedComponents(msk, connectivity=4)
            pred = self.cellpose_masks[i]
            ap, tp, fp, fn = get_instance_segmentation_metrics(msk, pred, is_binary=False,
                                                               min_pixel=self.min_pixel_export)
            self.df_ens.loc[i, 'mean_average_precision'] = ap.mean()
            self.df_ens.loc[i, 'average_precision_at_iou_50'] = ap[0]
        return self.df_ens

    def show_cellpose_results(self, files=None, unc=True, unc_metric=None, metric_name='mean_average_precision'):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens.reset_index()
        if files is not None:
            df = df.set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                try:
                    mask = self.ds.labels[r.file][:]
                except:
                    mask = _read_msk(r.mask_path, n_classes=self.n_classes, instance_labels=self.instance_labels)
                _, comps = cv2.connectedComponents((mask == self.cellpose_export_class).astype('uint8'),
                                                   connectivity=4)
                imgs.append(label2rgb(comps, bg_label=0))
                hastarget = True
            else:
                hastarget = False
            imgs.append(label2rgb(self.cellpose_masks[r['index']], bg_label=0))
            if unc:
                imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs, df=r, hastarget=hastarget, metric_name=metric_name, unc_metric=unc_metric)

    def lr_find(self, files=None, **kwargs):
        files = files or self.files
        dls = self._get_dls(files)
        model = self._create_model()
        learn = Learner(dls, model,
                        metrics=self.metrics,
                        wd=self.weight_decay,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim])
        if self.mixed_precision_training:
            learn.to_fp16()
        sug_lrs = learn.lr_find(**kwargs)
        return sug_lrs, learn.recorder

    def export_imagej_rois(self, output_folder='ROI_sets', **kwargs):
        assert self.df_ens is not None, "Please run prediction first."
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)):
            mask = np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8')
            uncertainty = zarr.load(r.uncertainty_path)
            export_roi_set(mask, uncertainty, name=r.file, path=output_folder, ascending=False, **kwargs)

    def export_cellpose_rois(self, output_folder='cellpose_ROI_sets', **kwargs):
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        for idx, r in progress_bar(self.df_ens.iterrows(), total=len(self.df_ens)):
            mask = self.cellpose_masks[idx]
            uncertainty = zarr.load(r.uncertainty_path)
            export_roi_set(mask, uncertainty, instance_labels=True, name=r.file,
                           path=output_folder, ascending=False, **kwargs)
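
# Usage sketch (a minimal end-to-end run; the folder names and `n_iter` value
# below are illustrative assumptions, not fixed by the class):
#
#   el = EnsembleLearner(image_dir='images', mask_dir='masks', path='data')
#   el.fit_ensemble(n_iter=1000)                      # one model per k-fold split
#   el.get_valid_results(export_dir='results')        # per-fold Dice on held-out files
#   test_files = get_image_files(Path('data/test'))
#   el.get_ensemble_results(test_files, export_dir='results')
#   el.score_ensemble_results(mask_dir='masks_test')  # adds 'dice_score' to el.df_ens
#   el.export_imagej_rois('ROI_sets')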