def run_training(learn: Learner,
                 resume_ckpt: Path,
                 min_lr=0.005,
                 head_runs=1,
                 full_runs=1):
    """Optionally resume from a checkpoint, then fine-tune `learn`.

    Args:
        learn: fastai Learner to train.
        resume_ckpt: path to a saved ``state_dict``; falsy to skip resuming.
        min_lr: base learning rate passed to ``fine_tune``.
        head_runs: number of frozen (head-only) epochs.
        full_runs: number of unfrozen (full-model) epochs.
    """
    if resume_ckpt:
        print(f'Loading {resume_ckpt}...')
        try:
            # map_location='cpu' lets a GPU-saved checkpoint load on a
            # CPU-only host; load_state_dict then copies the tensors onto
            # whatever device the model's parameters live on.
            learn.model.load_state_dict(
                torch.load(resume_ckpt, map_location='cpu'))
        except Exception as e:
            # Best-effort resume: report the failure and train from the
            # current weights instead of aborting.
            print(f'Error while trying to load {resume_ckpt}: {e}')
    monitor.display_average_stats_per_gpu()
    print(f"Training for {head_runs}+{full_runs} epochs at min LR {min_lr}")
    learn.fine_tune(full_runs, min_lr, freeze_epochs=head_runs)
# Example #2 — separator text left over from the scraped source (commented out so the file parses)
def run_training(learn: Learner, min_lr=0.05, head_runs=1, full_runs=1):
    """Fine-tune `learn`: `head_runs` frozen epochs, then `full_runs` full epochs.

    NOTE(review): this shadows the earlier `run_training` definition in this
    file (no resume support, different default `min_lr`) — scraped duplicate;
    keep only one of the two.
    """
    monitor.display_average_stats_per_gpu()
    print(f"Training for {head_runs}+{full_runs} epochs at min LR {min_lr}")
    learn.fine_tune(full_runs, min_lr, freeze_epochs=head_runs)
# Example #3 — separator text left over from the scraped source (commented out so the file parses)
class EnsembleLearner(GetAttr):
    """Train, evaluate and export an ensemble of segmentation models fitted
    on K-fold splits of one image/mask folder.

    Attribute lookups not found on the instance are delegated to
    ``self.config`` (fastcore ``GetAttr`` with ``_default = 'config'``),
    e.g. ``n_classes``, ``tile_shape``, ``base_lr``, ``max_splits``.
    """
    _default = 'config'

    def __init__(self,
                 image_dir='images',
                 mask_dir=None,
                 config=None,
                 path=None,
                 ensemble_path=None,
                 preproc_dir=None,
                 label_fn=None,
                 metrics=None,
                 cbs=None,
                 ds_kwargs=None,
                 dl_kwargs=None,
                 model_kwargs=None,
                 stats=None,
                 files=None):
        """Collect images (and optionally masks), build the tile dataset and
        prepare the K-fold splits.

        Args:
            image_dir: folder (relative to `path`) containing the images.
            mask_dir: folder with masks; used to derive `label_fn` when the
                latter is not given.
            config: configuration object; a fresh `Config()` if omitted.
            path: project root (default: current directory).
            ensemble_path: if given, load a previously trained ensemble
                from this folder.
            preproc_dir: cache directory for preprocessed labels.
            label_fn: maps an image path to its mask path.
            metrics: fastai metrics (default: Dice or DiceMulti).
            cbs: fastai callbacks (default: SaveModelCallback on dice).
            ds_kwargs / dl_kwargs / model_kwargs: extra kwargs forwarded to
                the dataset, DataLoaders and model factory.
            stats: normalization stats; taken from the dataset if None.
            files: explicit image file list (overrides the `image_dir` scan).
        """
        self.config = config or Config()
        self.stats = stats
        # NOTE(fix): these three previously used mutable `{}` defaults,
        # which are shared across every instance of the class; normalize a
        # `None` default to a fresh dict per instance instead.
        self.dl_kwargs = dl_kwargs or {}
        self.model_kwargs = model_kwargs or {}
        self.add_ds_kwargs = ds_kwargs or {}
        self.path = Path(path) if path is not None else Path('.')
        default_metrics = [Dice()] if self.n_classes == 2 else [DiceMulti()]
        self.metrics = metrics or default_metrics
        self.loss_fn = self.get_loss()
        self.cbs = cbs or [
            SaveModelCallback(
                monitor='dice' if self.n_classes == 2 else 'dice_multi')
        ]  #ShowGraphCallback
        self.ensemble_dir = ensemble_path or self.path / self.ens_dir
        if ensemble_path is not None:
            # assumes ensemble_path is a pathlib.Path — TODO confirm callers
            ensemble_path.mkdir(exist_ok=True, parents=True)
            self.load_ensemble(path=ensemble_path)
        else:
            self.models = {}

        self.files = L(files) or get_image_files(self.path / image_dir,
                                                 recurse=False)
        assert len(
            self.files
        ) > 0, f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder'
        if any([mask_dir, label_fn]):
            if label_fn: self.label_fn = label_fn
            else:
                self.label_fn = get_label_fn(self.files[0],
                                             self.path / mask_dir)
            # Check that a corresponding mask exists for every image
            mask_check = [self.label_fn(x).exists() for x in self.files]
            chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
            assert len(self.files) == sum(mask_check) and len(
                self.files
            ) > 0, f'Please check your images and masks (and folders). {chk_str}'
            print(chk_str)

        else:
            self.label_fn = label_fn
        self.n_splits = min(len(self.files), self.max_splits)
        self._set_splits()
        self.ds = RandomTileDataset(
            self.files,
            label_fn=self.label_fn,
            preproc_dir=preproc_dir,
            instance_labels=self.instance_labels,
            n_classes=self.n_classes,
            stats=self.stats,
            normalize=True,
            sample_mult=self.sample_mult if self.sample_mult > 0 else None,
            verbose=0,
            **self.add_ds_kwargs)

        self.stats = self.ds.stats
        # Infer channel count from one sample (last axis of the image array).
        self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]
        self.df_val, self.df_ens, self.df_model, self.ood = None, None, None, None
        self.recorder = {}

    def _set_splits(self):
        """Build `self.splits`: fold number -> (train files, valid files)."""
        if self.n_splits <= 1:
            # Single file: train and validate on the same image.
            self.splits = {1: (self.files[0], self.files[0])}
            return
        kf = KFold(self.n_splits,
                   shuffle=True,
                   random_state=self.random_state)
        self.splits = {}
        for fold, (train_idx, valid_idx) in enumerate(kf.split(self.files), 1):
            self.splits[fold] = (self.files[train_idx], self.files[valid_idx])

    def _compose_albumentations(self, **kwargs):
        """Thin wrapper around the module-level `_compose_albumentations`."""
        tfms = _compose_albumentations(**kwargs)
        return tfms

    @property
    def pred_ds_kwargs(self):
        """Dataset kwargs for inference tiles (shape, shift, border padding
        from config), layered over the user-supplied extras."""
        kwargs = dict(self.add_ds_kwargs)
        kwargs.update(
            use_preprocessed_labels=True,
            preproc_dir=self.ds.preproc_dir,
            instance_labels=self.instance_labels,
            tile_shape=(self.tile_shape, ) * 2,
            n_classes=self.n_classes,
            shift=self.shift,
            border_padding_factor=self.border_padding_factor,
        )
        return kwargs

    @property
    def train_ds_kwargs(self):
        """Dataset kwargs for training tiles: full shift, no border padding,
        flips and albumentations augmentation, layered over user extras."""
        kwargs = dict(self.add_ds_kwargs)
        kwargs.update(
            use_preprocessed_labels=True,
            preproc_dir=self.ds.preproc_dir,
            instance_labels=self.instance_labels,
            stats=self.stats,
            tile_shape=(self.tile_shape, ) * 2,
            n_classes=self.n_classes,
            shift=1.,
            border_padding_factor=0.,
            flip=self.flip,
            albumentations_tfms=self._compose_albumentations(
                **self.albumentation_kwargs),
            sample_mult=self.sample_mult if self.sample_mult > 0 else None,
        )
        return kwargs

    @property
    def model_name(self):
        """Canonical name encoding architecture, encoder and class count."""
        parts = (str(self.arch), str(self.encoder_name),
                 f'{self.n_classes}classes')
        return '_'.join(parts)

    def get_loss(self):
        """Build the configured loss function.

        Foreground classes are 1..n_classes-1 (class 0 = background).
        """
        kwargs = {
            'mode': self.mode,
            # NOTE(fix): `[x for x in range(...)]` simplified to
            # `list(range(...))` (unnecessary comprehension, ruff C416).
            'classes': list(range(1, self.n_classes)),
            'smooth_factor': self.loss_smooth_factor,
            'alpha': self.loss_alpha,
            'beta': self.loss_beta,
            'gamma': self.loss_gamma
        }
        return get_loss(self.loss, **kwargs)

    def _get_dls(self, files, files_val=None):
        """Build train/valid DataLoaders; without `files_val` the training
        set doubles as the validation set.

        NOTE(review): the validation TileDataset also receives
        `train_ds_kwargs` (incl. augmentation settings) — confirm intended.
        """
        train_ds = RandomTileDataset(files,
                                     label_fn=self.label_fn,
                                     **self.train_ds_kwargs)
        if files_val:
            valid_ds = TileDataset(files_val,
                                   label_fn=self.label_fn,
                                   **self.train_ds_kwargs)
        else:
            valid_ds = train_ds
        dls = DataLoaders.from_dsets(train_ds,
                                     valid_ds,
                                     bs=self.batch_size,
                                     pin_memory=True,
                                     **self.dl_kwargs)
        if torch.cuda.is_available():
            dls.cuda()
        return dls

    def _create_model(self):
        """Instantiate the segmentation model (moved to GPU when available)."""
        model = create_smp_model(arch=self.arch,
                                 encoder_name=self.encoder_name,
                                 encoder_weights=self.encoder_weights,
                                 in_channels=self.in_channels,
                                 classes=self.n_classes,
                                 **self.model_kwargs)
        if torch.cuda.is_available():
            # nn.Module.cuda() moves in place and returns the same object.
            model = model.cuda()
        return model

    def fit(self, i, n_iter=None, base_lr=None, **kwargs):
        """Train fold `i` and save the resulting model.

        Args:
            i: fold number (key into `self.splits`).
            n_iter: training iterations (default: config `n_iter`).
            base_lr: base learning rate for `fine_tune` (default: config).
        """
        n_iter = n_iter or self.n_iter
        base_lr = base_lr or self.base_lr
        name = self.ensemble_dir / f'{self.model_name}-fold{i}.pth'
        model = self._create_model()
        files_train, files_val = self.splits[i]
        dls = self._get_dls(files_train, files_val)
        log_name = f'{name.name}_{time.strftime("%Y%m%d-%H%M%S")}.csv'
        log_dir = self.ensemble_dir / 'logs'
        log_dir.mkdir(exist_ok=True, parents=True)
        # NOTE(fix): this used `cbs = self.cbs.append(...)`, which assigns
        # None to `cbs` and permanently grows `self.cbs` by one CSVLogger on
        # every call. Build a per-run callback list instead and pass it on.
        cbs = list(self.cbs) + [CSVLogger(fname=log_dir / log_name)]
        self.learn = Learner(dls,
                             model,
                             metrics=self.metrics,
                             wd=self.weight_decay,
                             loss_func=self.loss_fn,
                             opt_func=_optim_dict[self.optim],
                             cbs=cbs)
        self.learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mixed_precision_training: self.learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter,
                                 ds_length=len(dls.train_ds),
                                 bs=self.batch_size)
        self.learn.fine_tune(epochs, base_lr=base_lr)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        save_smp_model(self.learn.model, self.arch, name, stats=self.stats)
        self.models[i] = name
        self.recorder[i] = self.learn.recorder

    def fit_ensemble(self, n_iter, skip=False, **kwargs):
        """Train all folds; with `skip=True`, folds that already have a
        trained model are left untouched."""
        for fold in range(1, self.n_models + 1):
            if skip and fold in self.models:
                continue
            self.fit(fold, n_iter, **kwargs)

    def set_n(self, n):
        """Shrink the ensemble to its first `n` models (keys 1..n)."""
        # Drop fold keys n+1 .. len(models); the range is fixed up front so
        # popping while looping is safe.
        for key in range(n + 1, len(self.models) + 1):
            self.models.pop(key, None)
        self.n_models = n

    def get_valid_results(self,
                          model_no=None,
                          zarr_store=None,
                          export_dir=None,
                          filetype='.png',
                          **kwargs):
        """Predict each fold's validation images with that fold's model and
        collect per-image scores into `self.df_val`.

        Args:
            model_no: restrict evaluation to a single fold.
            zarr_store: optional zarr store for the predictions.
            export_dir: if given, masks and uncertainty maps are saved there.
            filetype: file extension for exported masks.

        Returns:
            DataFrame with one row per (model, validation image).
        """
        res_list = []
        if model_no:
            model_list = {
                k: v
                for k, v in self.models.items() if k == model_no
            }
        else:
            model_list = self.models
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)

        for i, model_path in model_list.items():
            # Single-model "ensemble" predictor for this fold.
            ep = EnsemblePredict(models_paths=[model_path],
                                 zarr_store=zarr_store)
            _, files_val = self.splits[i]
            g_smx, g_std, g_eng = ep.predict_images(
                files_val,
                bs=self.batch_size,
                ds_kwargs=self.pred_ds_kwargs,
                **kwargs)
            # Release the predictor and its GPU memory before scoring.
            del ep
            torch.cuda.empty_cache()

            chunk_store = g_smx.chunk_store.path
            for f in files_val:
                msk = self.ds.get_data(f, mask=True)[0]
                # Class prediction = argmax over the softmax channels.
                pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
                m_dice = dice_score(msk, pred)
                m_path = self.models[i].name
                df_tmp = pd.Series({
                    'file': f.name,
                    'model': m_path,
                    'model_no': i,
                    'dice_score': m_dice,
                    #'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                    'uncertainty_score':
                        np.mean(g_std[f.name][:][pred > 0])
                        if g_std is not None else None,
                    'image_path': f,
                    'mask_path': self.label_fn(f),
                    'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    # NOTE(fix): key was misspelled 'engergy_path'; renamed
                    # to match `get_ensemble_results`.
                    'energy_path':
                        f'{chunk_store}/{g_eng.path}/{f.name}'
                        if g_eng is not None else None,
                    'uncertainty_path':
                        f'{chunk_store}/{g_std.path}/{f.name}'
                        if g_std is not None else None
                })
                res_list.append(df_tmp)
                if export_dir:
                    save_mask(pred,
                              pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                              filetype)
                    if g_std is not None:
                        save_unc(
                            g_std[f.name][:], unc_path /
                            f'{df_tmp.file}_{df_tmp.model}_uncertainty',
                            filetype)
                    if g_eng is not None:
                        save_unc(
                            g_eng[f.name][:],
                            unc_path / f'{df_tmp.file}_{df_tmp.model}_energy',
                            filetype)
        self.df_val = pd.DataFrame(res_list)
        if export_dir:
            # NOTE(fix): dropped pointless f-prefix on constant file names.
            self.df_val.to_csv(export_dir / 'val_results.csv', index=False)
            self.df_val.to_excel(export_dir / 'val_results.xlsx')
        return self.df_val

    def show_valid_results(self, model_no=None, files=None, **kwargs):
        """Plot image / mask / prediction (and uncertainty when TTA is on)
        for the validation results, computing them first if necessary."""
        if self.df_val is None:
            self.get_valid_results(**kwargs)
        df = self.df_val
        if files is not None:
            df = df.set_index('file', drop=False).loc[files]
        if model_no is not None:
            df = df[df.model_no == model_no]
        for _, row in df.iterrows():
            img = self.ds.get_data(row.image_path)[0][:]
            msk = self.ds.get_data(row.image_path, mask=True)[0]
            pred = np.argmax(zarr.load(row.softmax_path),
                             axis=-1).astype('uint8')
            std = zarr.load(row.uncertainty_path)
            model_label = f'Model {row.model_no}'
            # Without TTA the uncertainty map is not meaningful; show zeros.
            unc_map = std if self.tta else np.zeros_like(pred)
            plot_results(img, msk, pred, unc_map, df=row, model=model_label)

    def load_ensemble(self, path=None):
        """Load all `.pth` models found in `path` into `self.models`.

        The class count is parsed from the file name (third `_`-separated
        token, first character — single-digit counts only); all models must
        agree on it. Normalization stats are restored from the first model.
        """
        path = path or self.ensemble_dir
        models = sorted(get_files(path, extensions='.pth', recurse=False))
        self.models = {}

        for i, m in enumerate(models, 1):
            # NOTE(fix): with `enumerate(models, 1)` the old `if i == 0`
            # branch was dead, so n_classes was never taken from the first
            # model; use the first model (i == 1) as the reference instead.
            if i == 1:
                self.n_classes = int(m.name.split('_')[2][0])
            else:
                assert self.n_classes == int(
                    m.name.split('_')[2][0]
                ), 'Check models. Models are trained on different number of classes.'
            self.models[i] = m

        if len(self.models) > 0:
            self.set_n(len(self.models))
            print(f'Found {len(self.models)} models in folder {path}:')
            print([m.name for m in self.models.values()])

            # Reset stats from the first saved model.
            print(f'Loading stats from {self.models[1].name}')
            _, self.stats = load_smp_model(self.models[1])

    def get_ensemble_results(self,
                             files,
                             zarr_store=None,
                             export_dir=None,
                             filetype='.png',
                             **kwargs):
        """Predict `files` with the full ensemble and collect one row of
        result paths/scores per image into `self.df_ens`.

        Args:
            files: image paths to predict.
            zarr_store: optional zarr store receiving the predictions.
            export_dir: if given, masks and uncertainty maps are saved there.
            filetype: file extension for exported masks.

        Returns:
            Tuple of zarr groups `(g_smx, g_std, g_eng)` — softmax,
            uncertainty and energy; the last two may be None depending on
            the predictor settings.
        """
        ep = EnsemblePredict(models_paths=self.models.values(),
                             zarr_store=zarr_store)
        g_smx, g_std, g_eng = ep.predict_images(files,
                                                bs=self.batch_size,
                                                ds_kwargs=self.pred_ds_kwargs,
                                                **kwargs)
        chunk_store = g_smx.chunk_store.path
        # Release the predictor (and its models) before further processing.
        del ep
        torch.cuda.empty_cache()

        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)

        res_list = []
        for f in files:
            # Class prediction = argmax over the (ensemble-averaged) softmax.
            pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
            df_tmp = pd.Series({
                'file':
                f.name,
                'ensemble':
                self.model_name,
                'n_models':
                len(self.models),
                #'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                # Mean per-pixel std over predicted-foreground pixels only.
                'uncertainty_score':
                np.mean(g_std[f.name][:][pred > 0])
                if g_std is not None else None,
                'image_path':
                f,
                'softmax_path':
                f'{chunk_store}/{g_smx.path}/{f.name}',
                'uncertainty_path':
                f'{chunk_store}/{g_std.path}/{f.name}'
                if g_std is not None else None,
                'energy_path':
                f'{chunk_store}/{g_eng.path}/{f.name}'
                if g_eng is not None else None
            })
            res_list.append(df_tmp)
            if export_dir:
                save_mask(pred,
                          pred_path / f'{df_tmp.file}_{df_tmp.ensemble}_mask',
                          filetype)
                if g_std is not None:
                    save_unc(g_std[f.name][:],
                             unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_unc',
                             filetype)
                if g_eng is not None:
                    save_unc(
                        g_eng[f.name][:],
                        unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_energy',
                        filetype)

        self.df_ens = pd.DataFrame(res_list)
        return g_smx, g_std, g_eng

    def score_ensemble_results(self, mask_dir=None, label_fn=None):
        """Add ground-truth mask paths and dice scores to `self.df_ens`.

        Args:
            mask_dir: folder with masks; used to derive `label_fn` if absent.
            label_fn: maps an image path to its mask path. When neither is
                given, the masks cached in `self.ds.labels` are used.
        """
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0],
                                    self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                # NOTE(fix): previously called `self.label_fn`, ignoring the
                # `label_fn` derived above (and crashing when the instance
                # attribute is None).
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path,
                                n_classes=self.n_classes,
                                instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            pred = np.argmax(zarr.load(r.softmax_path),
                             axis=-1).astype('uint8')
            self.df_ens.loc[i, 'dice_score'] = dice_score(msk, pred)
        return self.df_ens

    def show_ensemble_results(self,
                              files=None,
                              unc=True,
                              unc_metric=None,
                              metric_name='dice_score'):
        """Plot ensemble predictions, plus ground truth and uncertainty
        when available (ground truth is shown only after scoring added a
        `metric_name` column)."""
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens
        if files is not None:
            df = df.reset_index().set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                # Prefer the cached label; fall back to reading the mask
                # file. NOTE(fix): was a bare `except:`; narrowed so
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                try:
                    msk = self.ds.labels[r.file][:]
                except Exception:
                    msk = _read_msk(r.mask_path,
                                    n_classes=self.n_classes,
                                    instance_labels=self.instance_labels)
                imgs.append(msk)
                hastarget = True
            else:
                hastarget = False
            imgs.append(
                np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8'))
            if unc: imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         metric_name=metric_name,
                         unc_metric=unc_metric)

    def get_cellpose_results(self, export_dir=None):
        """Run cellpose instance segmentation on the ensemble softmax of the
        configured export class; caches and returns the instance masks."""
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        cl = self.cellpose_export_class
        # NOTE(fix): corrected typo 'avaialable' in the assertion message.
        assert cl < self.n_classes, f'{cl} not available from {self.n_classes} classes'

        smxs, preds = [], []
        for _, r in self.df_ens.iterrows():
            softmax = zarr.load(r.softmax_path)
            smxs.append(softmax)
            preds.append(np.argmax(softmax, axis=-1).astype('uint8'))

        # Per-image probability map and binary mask for the export class.
        probs = [x[..., cl] for x in smxs]
        masks = [x == cl for x in preds]
        cp_masks = run_cellpose(probs,
                                masks,
                                model_type=self.cellpose_model,
                                diameter=self.cellpose_diameter,
                                min_size=self.min_pixel_export,
                                gpu=torch.cuda.is_available())

        if export_dir:
            export_dir = Path(export_dir)
            cp_path = export_dir / 'cellpose_masks'
            cp_path.mkdir(parents=True, exist_ok=True)
            for idx, r in self.df_ens.iterrows():
                # NOTE(review): `compress=` is removed in newer tifffile
                # releases (use `compression=`) — confirm the pinned version.
                tifffile.imwrite(cp_path / f'{r.file}_class{cl}.tif',
                                 cp_masks[idx],
                                 compress=6)

        self.cellpose_masks = cp_masks
        return cp_masks

    def score_cellpose_results(self, mask_dir=None, label_fn=None):
        """Score cellpose instance masks against ground truth and write
        average-precision metrics into `self.df_ens`."""
        assert self.cellpose_masks is not None, 'Run get_cellpose_results() first'
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0],
                                    self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                # NOTE(fix): previously used `self.label_fn`, ignoring the
                # locally derived `label_fn` (same bug as in
                # `score_ensemble_results`).
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path,
                                n_classes=self.n_classes,
                                instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            # Semantic mask -> connected components as instance labels.
            _, msk = cv2.connectedComponents(msk, connectivity=4)
            pred = self.cellpose_masks[i]
            ap, tp, fp, fn = get_instance_segmentation_metrics(
                msk, pred, is_binary=False, min_pixel=self.min_pixel_export)
            self.df_ens.loc[i, 'mean_average_precision'] = ap.mean()
            # First AP entry corresponds to the IoU=0.50 threshold.
            self.df_ens.loc[i, 'average_precision_at_iou_50'] = ap[0]
        return self.df_ens

    def show_cellpose_results(self,
                              files=None,
                              unc=True,
                              unc_metric=None,
                              metric_name='mean_average_precision'):
        """Plot image, ground-truth instances, cellpose instances and
        (optionally) uncertainty for each scored row of `self.df_ens`."""
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens.reset_index()
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                # NOTE(fix): `self.ds.labels[idx]` referenced an undefined
                # `idx` (the NameError was masked by the former bare except,
                # so the cached label was never used); key by file name as
                # in `show_ensemble_results`, and narrow the except.
                try:
                    mask = self.ds.labels[r.file][:]
                except Exception:
                    mask = _read_msk(r.mask_path,
                                     n_classes=self.n_classes,
                                     instance_labels=self.instance_labels)
                _, comps = cv2.connectedComponents(
                    (mask == self.cellpose_export_class).astype('uint8'),
                    connectivity=4)
                imgs.append(label2rgb(comps, bg_label=0))
                hastarget = True
            else:
                hastarget = False
            imgs.append(label2rgb(self.cellpose_masks[r['index']], bg_label=0))
            if unc: imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         metric_name=metric_name,
                         unc_metric=unc_metric)

    def lr_find(self, files=None, **kwargs):
        """Run fastai's LR finder on a throwaway Learner built from `files`
        (defaults to all files); returns (suggested LRs, recorder)."""
        train_files = files or self.files
        dls = self._get_dls(train_files)
        finder_learn = Learner(dls,
                               self._create_model(),
                               metrics=self.metrics,
                               wd=self.weight_decay,
                               loss_func=self.loss_fn,
                               opt_func=_optim_dict[self.optim])
        if self.mixed_precision_training:
            finder_learn.to_fp16()
        suggestions = finder_learn.lr_find(**kwargs)
        return suggestions, finder_learn.recorder

    def export_imagej_rois(self, output_folder='ROI_sets', **kwargs):
        """Export one ImageJ ROI set per predicted mask in `self.df_ens`."""
        assert self.df_ens is not None, "Please run prediction first."

        out_dir = Path(output_folder)
        out_dir.mkdir(exist_ok=True, parents=True)
        for _, row in progress_bar(self.df_ens.iterrows(),
                                   total=len(self.df_ens)):
            mask = np.argmax(zarr.load(row.softmax_path),
                             axis=-1).astype('uint8')
            uncertainty = zarr.load(row.uncertainty_path)
            export_roi_set(mask,
                           uncertainty,
                           name=row.file,
                           path=out_dir,
                           ascending=False,
                           **kwargs)

    def export_cellpose_rois(self,
                             output_folder='cellpose_ROI_sets',
                             **kwargs):
        """Export ImageJ ROI sets for the cached cellpose instance masks."""
        out_dir = Path(output_folder)
        out_dir.mkdir(exist_ok=True, parents=True)
        # NOTE(review): assumes df_ens keeps its default RangeIndex so the
        # row label also indexes `self.cellpose_masks` — confirm upstream.
        for idx, row in progress_bar(self.df_ens.iterrows(),
                                     total=len(self.df_ens)):
            mask = self.cellpose_masks[idx]
            uncertainty = zarr.load(row.uncertainty_path)
            export_roi_set(mask,
                           uncertainty,
                           instance_labels=True,
                           name=row.file,
                           path=out_dir,
                           ascending=False,
                           **kwargs)