Example #1
 def lr_find(self, files=None, bs=None, n_jobs=-1, verbose=1, **kwargs):
     bs = bs or self.bs
     files = files or self.files
     train_ds = RandomTileDataset(files,
                                  label_fn=self.label_fn,
                                  n_jobs=n_jobs,
                                  verbose=verbose,
                                  **self.mw_kwargs,
                                  **self.ds_kwargs)
     dls = DataLoaders.from_dsets(train_ds, train_ds, bs=bs)
     pre = None if self.pretrained == 'new' else self.pretrained
     model = torch.hub.load(self.repo,
                            self.arch,
                            pretrained=pre,
                            n_classes=dls.c,
                            in_channels=self.in_channels)
     if torch.cuda.is_available(): dls.cuda(), model.cuda()
     learn = Learner(dls,
                     model,
                     metrics=self.metrics,
                     wd=self.wd,
                     loss_func=self.loss_fn,
                     opt_func=_optim_dict[self.optim])
     if self.mpt: learn.to_fp16()
     sug_lrs = learn.lr_find(**kwargs)
     return sug_lrs, learn.recorder
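
A minimal usage sketch for this variant, assuming `el` is an already constructed instance of the surrounding class (hypothetical name) and the standard fastai `lr_find` semantics:

# Hypothetical usage of the lr_find method above.
sug_lrs, recorder = el.lr_find()
print(sug_lrs)           # e.g. SuggestedLRs(valley=...) in recent fastai versions
recorder.plot_lr_find()  # assumes fastai v2, where plot_lr_find is patched onto Recorder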
Example #2
 def lr_find(self, files=None, **kwargs):
     files = files or self.files
     dls = self._get_dls(files)
     model = self._create_model()
     learn = Learner(dls,
                     model,
                     metrics=self.metrics,
                     wd=self.weight_decay,
                     loss_func=self.loss_fn,
                     opt_func=_optim_dict[self.optim])
     if self.mixed_precision_training: learn.to_fp16()
     sug_lrs = learn.lr_find(**kwargs)
     return sug_lrs, learn.recorder
Example #3
 def lr_find(self, files=None, **kwargs):
     files = files or self.files
     dls = self.get_dls(files)
     pre = None if self.pretrained == 'new' else self.pretrained
     model = self.get_model(pretrained=pre)
     learn = Learner(dls,
                     model,
                     metrics=self.metrics,
                     wd=self.wd,
                     loss_func=self.loss_fn,
                     opt_func=_optim_dict[self.optim])
     if self.mpt: learn.to_fp16()
     sug_lrs = learn.lr_find(**kwargs)
     return sug_lrs, learn.recorder
Example #4
 def predict(self, files, model_no, path=None, **kwargs):
     model_path = self.models[model_no]
     model = self.load_model(model_path)
     ds_kwargs = self.ds_kwargs
     # Adding extra padding (overlap) for models that have the same input and output shape
     if ds_kwargs['padding'][0] == 0:
         ds_kwargs['padding'] = (self.extra_padding, ) * 2
     ds = TileDataset(files, **ds_kwargs)
     dls = DataLoaders.from_dsets(ds,
                                  batch_size=self.bs,
                                  after_batch=self.get_batch_tfms(),
                                  shuffle=False,
                                  drop_last=False,
                                  **self.dl_kwargs)
     if torch.cuda.is_available(): dls.cuda()
     learn = Learner(dls, model, loss_func=self.loss_fn)
     if self.mpt: learn.to_fp16()
     if path: path = path / f'model_{model_no}'
     return learn.predict_tiles(dl=dls.train, path=path, **kwargs)
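
For context, a hedged sketch of calling this predict variant; the four-way unpacking of zarr groups mirrors how get_valid_results consumes it in Example #6 below, and `el` is again a hypothetical instance:

# Hypothetical usage: run tiled inference with model no. 1 of the ensemble.
g_smx, g_seg, g_std, g_eng = el.predict(files, model_no=1)
pred = g_seg[files[0].name][:]  # materialize one predicted mask as a NumPy array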
Example #5
 def predict(self, files, model_no, bs=None, **kwargs):
     bs = bs or self.bs
     model_path = self.models[model_no]
     model = self.load_model(model_path)
     batch_tfms = Normalize.from_stats(*self.stats)
     ds = TileDataset(files, **self.ds_kwargs)
     dls = DataLoaders.from_dsets(ds,
                                  batch_size=bs,
                                  after_batch=batch_tfms,
                                  shuffle=False,
                                  drop_last=False,
                                  num_workers=0)
     if torch.cuda.is_available(): dls.cuda(), model.cuda()
     learn = Learner(dls, model, loss_func=self.loss_fn)
     if self.mpt: learn.to_fp16()
     results = learn.predict_tiles(dl=dls.train, **kwargs)
     pth_tmp = self.path / '.tmp' / model_path.name
     save_tmp(pth_tmp, files, results)
     return results
Example #6
class EnsembleLearner(GetAttr):
    _default = 'config'

    def __init__(self,
                 image_dir='images',
                 mask_dir=None,
                 config=None,
                 path=None,
                 ensemble_dir=None,
                 item_tfms=None,
                 label_fn=None,
                 metrics=None,
                 cbs=None,
                 ds_kwargs={},
                 dl_kwargs={},
                 model_kwargs={},
                 stats=None,
                 files=None):

        self.config = config or Config()
        self.stats = stats
        self.dl_kwargs = dl_kwargs
        self.model_kwargs = model_kwargs
        self.add_ds_kwargs = ds_kwargs
        self.item_tfms = item_tfms
        self.path = Path(path) if path is not None else Path('.')
        self.metrics = metrics or [Iou(), Dice_f1()]
        self.loss_fn = self.get_loss()
        self.cbs = cbs or [
            SaveModelCallback(monitor='iou'), ElasticDeformCallback
        ]  #ShowGraphCallback
        self.ensemble_dir = ensemble_dir or self.path / 'ensemble'

        self.files = L(files) or get_image_files(self.path / image_dir,
                                                 recurse=False)
        assert len(self.files) > 0, \
            f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder.'
        if any([mask_dir, label_fn]):
            if label_fn: self.label_fn = label_fn
            else:
                self.label_fn = get_label_fn(self.files[0],
                                             self.path / mask_dir)
            #Check if corresponding masks exist
            mask_check = [self.label_fn(x).exists() for x in self.files]
            chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
            assert len(self.files) == sum(mask_check) and len(self.files) > 0, \
                f'Please check your images and masks (and folders). {chk_str}'
            print(chk_str)

        else:
            self.label_fn = label_fn
        self.n_splits = min(len(self.files), self.max_splits)

        self.models = {}
        self.recorder = {}
        self._set_splits()
        self.ds = RandomTileDataset(self.files,
                                    label_fn=self.label_fn,
                                    **self.mw_kwargs,
                                    **self.ds_kwargs)
        self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]
        self.df_val, self.df_ens, self.df_model, self.ood = None, None, None, None

    @property
    def out_size(self):
        return self.ds_kwargs['tile_shape'][0] - self.ds_kwargs['padding'][0]

    def _set_splits(self):
        if self.n_splits > 1:
            kf = KFold(self.n_splits,
                       shuffle=True,
                       random_state=self.random_state)
            self.splits = {
                key: (self.files[idx[0]], self.files[idx[1]])
                for key, idx in zip(range(1, self.n_splits + 1),
                                    kf.split(self.files))
            }
        else:
            self.splits = {1: (self.files[0], self.files[0])}

    def compose_albumentations(self, **kwargs):
        return _compose_albumentations(**kwargs)

    @property
    def ds_kwargs(self):
        # Setting default shapes and padding
        ds_kwargs = self.add_ds_kwargs.copy()
        for key, value in get_default_shapes(self.arch).items():
            ds_kwargs.setdefault(key, value)
        # Settings from config
        ds_kwargs['loss_weights'] = self.loss == 'WeightedSoftmaxCrossEntropy'
        ds_kwargs['zoom_sigma'] = self.zoom_sigma
        ds_kwargs['flip'] = self.flip
        ds_kwargs['deformation_grid'] = (self.deformation_grid, ) * 2
        ds_kwargs['deformation_magnitude'] = (self.deformation_magnitude, ) * 2
        if sum(self.albumentation_kwargs.values()) > 0:
            ds_kwargs['albumentation_tfms'] = self.compose_albumentations(
                **self.albumentation_kwargs)
        return ds_kwargs

    def get_loss(self):
        if self.loss == 'WeightedSoftmaxCrossEntropy':
            return WeightedSoftmaxCrossEntropy(axis=1)
        if self.loss == 'CrossEntropyLoss': return CrossEntropyLossFlat(axis=1)
        else:
            kwargs = {
                'alpha': self.loss_alpha,
                'beta': self.loss_beta,
                'gamma': self.loss_gamma
            }
            return load_kornia_loss(self.loss, **kwargs)

    def get_model(self, pretrained):
        if self.arch in [
                "unet_deepflash2", "unet_falk2019", "unet_ronnberger2015",
                "unet_custom", "unext50_deepflash2"
        ]:
            model = torch.hub.load(self.repo,
                                   self.arch,
                                   pretrained=pretrained,
                                   n_classes=self.c,
                                   in_channels=self.in_channels,
                                   **self.model_kwargs)
        else:
            kwargs = dict(encoder_name=self.encoder_name,
                          encoder_weights=self.encoder_weights,
                          in_channels=self.in_channels,
                          classes=self.c,
                          **self.model_kwargs)
            model = load_smp_model(self.arch, **kwargs)
        if torch.cuda.is_available(): model.cuda()
        return model

    def get_dls(self, files, files_val=None):
        ds = []
        ds.append(
            RandomTileDataset(files,
                              label_fn=self.label_fn,
                              **self.mw_kwargs,
                              **self.ds_kwargs))
        if files_val:
            ds.append(
                TileDataset(files_val,
                            label_fn=self.label_fn,
                            **self.mw_kwargs,
                            **self.ds_kwargs))
        else:
            ds.append(ds[0])
        dls = DataLoaders.from_dsets(*ds,
                                     bs=self.bs,
                                     after_item=self.item_tfms,
                                     after_batch=self.get_batch_tfms(),
                                     **self.dl_kwargs)
        if torch.cuda.is_available(): dls.cuda()
        return dls

    def save_model(self, file, model, pickle_protocol=2):
        state = model.state_dict()
        state = {
            'model': state,
            'arch': self.arch,
            'stats': self.stats,
            'c': self.c
        }
        if self.arch in [
                "unet_deepflash2", "unet_falk2019", "unet_ronnberger2015",
                "unet_custom", "unext50_deepflash2"
        ]:
            state['repo'] = self.repo
        else:
            state['encoder_name'] = self.encoder_name
        torch.save(state,
                   file,
                   pickle_protocol=pickle_protocol,
                   _use_new_zipfile_serialization=False)

    def load_model(self, file, with_meta=True, device=None, strict=True):
        if isinstance(device, int): device = torch.device('cuda', device)
        elif device is None: device = 'cpu'
        state = torch.load(file, map_location=device)
        hasopt = 'model' in state  #set(state)=={'model', 'arch', 'repo', 'stats', 'c'}
        if hasopt:
            model_state = state['model']
            if with_meta:
                for opt in state:
                    if opt != 'model': setattr(self.config, opt, state[opt])
        else:
            model_state = state
        model = self.get_model(pretrained=None)
        model.load_state_dict(model_state, strict=strict)
        return model

    def get_batch_tfms(self):
        self.stats = self.stats or self.ds.compute_stats()
        tfms = [Normalize.from_stats(*self.stats)]
        if isinstance(self.loss_fn, WeightedSoftmaxCrossEntropy):
            tfms.append(WeightTransform(self.out_size, **self.mw_kwargs))
        return tfms

    def fit(self, i, n_iter=None, lr_max=None, **kwargs):
        n_iter = n_iter or self.n_iter
        lr_max = lr_max or self.lr
        name = self.ensemble_dir / f'{self.arch}_model-{i}.pth'
        pre = None if self.pretrained == 'new' else self.pretrained
        model = self.get_model(pretrained=pre)
        files_train, files_val = self.splits[i]
        dls = self.get_dls(files_train, files_val)
        self.learn = Learner(dls,
                             model,
                             metrics=self.metrics,
                             wd=self.wd,
                             loss_func=self.loss_fn,
                             opt_func=_optim_dict[self.optim],
                             cbs=self.cbs)
        self.learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mpt: self.learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter,
                                 ds_length=len(dls.train_ds),
                                 bs=self.bs)
        self.learn.fit_one_cycle(epochs, lr_max)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        self.save_model(name, self.learn.model)
        self.models[i] = name
        self.recorder[i] = self.learn.recorder
        #del model
        #gc.collect()
        #torch.cuda.empty_cache()

    def fit_ensemble(self, n_iter, skip=False, **kwargs):
        for i in range(1, self.n + 1):
            if skip and (i in self.models): continue
            self.fit(i, n_iter, **kwargs)

    def set_n(self, n):
        for i in range(n, len(self.models)):
            self.models.pop(i + 1, None)
        self.n = n

    def predict(self, files, model_no, path=None, **kwargs):
        model_path = self.models[model_no]
        model = self.load_model(model_path)
        ds_kwargs = self.ds_kwargs
        # Adding extra padding (overlap) for models that have the same input and output shape
        if ds_kwargs['padding'][0] == 0:
            ds_kwargs['padding'] = (self.extra_padding, ) * 2
        ds = TileDataset(files, **ds_kwargs)
        dls = DataLoaders.from_dsets(ds,
                                     batch_size=self.bs,
                                     after_batch=self.get_batch_tfms(),
                                     shuffle=False,
                                     drop_last=False,
                                     **self.dl_kwargs)
        if torch.cuda.is_available(): dls.cuda()
        learn = Learner(dls, model, loss_func=self.loss_fn)
        if self.mpt: learn.to_fp16()
        if path: path = path / f'model_{model_no}'
        return learn.predict_tiles(dl=dls.train, path=path, **kwargs)

    def get_valid_results(self,
                          model_no=None,
                          export_dir=None,
                          filetype='.png',
                          **kwargs):
        res_list = []
        model_list = self.models if not model_no else [model_no]
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            if self.tta:
                unc_path = export_dir / 'uncertainties'
                unc_path.mkdir(parents=True, exist_ok=True)
        for i in model_list:
            _, files_val = self.splits[i]
            g_smx, g_seg, g_std, g_eng = self.predict(files_val, i, **kwargs)
            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(files_val):
                msk = self.ds.get_data(f, mask=True)[0]
                pred = g_seg[f.name][:]
                m_iou = iou(msk, pred)
                m_path = self.models[i].name
                m_eng_max = energy_max(g_eng[f.name][:], ks=self.energy_ks)
                df_tmp = pd.Series({
                    'file': f.name,
                    'model': m_path,
                    'model_no': i,
                    'img_path': f,
                    'iou': m_iou,
                    'energy_max': m_eng_max.numpy(),
                    'msk_path': self.label_fn(f),
                    'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                    'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'std_path': f'{chunk_store}/{g_std.path}/{f.name}'
                })
                res_list.append(df_tmp)
                if export_dir:
                    save_mask(pred,
                              pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                              filetype)
                    if self.tta:
                        save_unc(
                            g_std[f.name][:],
                            unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                            filetype)
        self.df_val = pd.DataFrame(res_list)
        if export_dir:
            self.df_val.to_csv(export_dir / 'val_results.csv', index=False)
            self.df_val.to_excel(export_dir / 'val_results.xlsx')
        return self.df_val

    def show_valid_results(self, model_no=None, files=None, **kwargs):
        if self.df_val is None: self.get_valid_results(**kwargs)
        df = self.df_val
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        if model_no is not None: df = df[df.model_no == model_no]
        for _, r in df.iterrows():
            img = self.ds.get_data(r.img_path)[0][:]
            msk = self.ds.get_data(r.img_path, mask=True)[0]
            pred = zarr.load(r.pred_path)
            std = zarr.load(r.std_path)
            _d_model = f'Model {r.model_no}'
            if self.tta:
                plot_results(img, msk, pred, std, df=r, model=_d_model)
            else:
                plot_results(img,
                             msk,
                             pred,
                             np.zeros_like(pred),
                             df=r,
                             model=_d_model)

    def load_ensemble(self, path=None):
        path = path or self.ensemble_dir
        models = get_files(path, extensions='.pth', recurse=False)
        assert len(models) > 0, f'No models found in {path}'
        self.models = {}
        for m in models:
            model_id = int(m.stem[-1])
            self.models[model_id] = m
        print(f'Found {len(self.models)} models in folder {path}')
        print(self.models)

    def ensemble_results(self,
                         files,
                         path=None,
                         export_dir=None,
                         filetype='.png',
                         use_tta=None,
                         **kwargs):
        use_tta = use_tta or self.pred_tta
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            if use_tta:
                unc_path = export_dir / 'uncertainties'
                unc_path.mkdir(parents=True, exist_ok=True)

        store = str(path / 'ensemble') if path else zarr.storage.TempStore()
        root = zarr.group(store=store, overwrite=True)
        chunk_store = root.chunk_store.path
        g_smx, g_seg, g_std, g_eng = root.create_groups(
            'ens_smx', 'ens_seg', 'ens_std', 'ens_energy')
        res_list = []
        for f in files:
            df_fil = self.df_models[self.df_models.file == f.name]
            assert len(df_fil) == len(self.models), \
                "Predictions and models do not match."
            m_smx, m_std, m_eng = tta.Merger(), tta.Merger(), tta.Merger()
            for idx, r in df_fil.iterrows():
                m_smx.append(zarr.load(r.smx_path))
                m_std.append(zarr.load(r.std_path))
                m_eng.append(zarr.load(r.eng_path))
            smx = m_smx.result().numpy()
            g_smx[f.name] = smx
            g_seg[f.name] = np.argmax(smx, axis=-1)
            g_std[f.name] = m_std.result().numpy()
            eng = m_eng.result()
            g_eng[f.name] = eng.numpy()
            m_eng_max = energy_max(eng, ks=self.energy_ks).numpy()
            df_tmp = pd.Series({
                'file': f.name,
                'model': f'{self.arch}_ensemble',
                'energy_max': m_eng_max,
                'img_path': f,
                'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                'std_path': f'{chunk_store}/{g_std.path}/{f.name}',
                'eng_path': f'{chunk_store}/{g_eng.path}/{f.name}'
            })
            res_list.append(df_tmp)
            if export_dir:
                save_mask(g_seg[f.name][:],
                          pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                          filetype)
                if use_tta:
                    save_unc(g_std[f.name][:],
                             unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                             filetype)
        return pd.DataFrame(res_list)

    def get_ensemble_results(self,
                             new_files,
                             export_dir=None,
                             filetype='.png',
                             **kwargs):
        res_list = []
        for i in self.models:
            g_smx, g_seg, g_std, g_eng = self.predict(new_files, i, **kwargs)
            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(new_files):
                m_path = self.models[i].name
                df_tmp = pd.Series({
                    'file': f.name,
                    'model_no': i,
                    'model': m_path,
                    'img_path': f,
                    'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                    'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'std_path': f'{chunk_store}/{g_std.path}/{f.name}',
                    'eng_path': f'{chunk_store}/{g_eng.path}/{f.name}'
                })
                res_list.append(df_tmp)
        self.df_models = pd.DataFrame(res_list)
        self.df_ens = self.ensemble_results(new_files,
                                            export_dir=export_dir,
                                            filetype=filetype,
                                            **kwargs)
        return self.df_ens

    def score_ensemble_results(self, mask_dir=None, label_fn=None):
        if not label_fn:
            label_fn = get_label_fn(self.df_ens.img_path[0],
                                    self.path / mask_dir)
        for idx, r in self.df_ens.iterrows():
            msk_path = label_fn(r.img_path)
            msk = _read_msk(msk_path)
            self.df_ens.loc[idx, 'msk_path'] = msk_path
            pred = zarr.load(r.pred_path)
            self.df_ens.loc[idx, 'iou'] = iou(msk, pred)
        return self.df_ens

    def show_ensemble_results(self,
                              files=None,
                              model_no=None,
                              unc=True,
                              unc_metric=None):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        if model_no is None: df = self.df_ens
        else: df = self.df_models[self.df_models.model_no == model_no]
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.img_path)[:])
            if 'iou' in r.index:
                imgs.append(_read_msk(r.msk_path))
                hastarget = True
            else:
                hastarget = False
            imgs.append(zarr.load(r.pred_path))
            if unc: imgs.append(zarr.load(r.std_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         unc_metric=unc_metric)

    def lr_find(self, files=None, **kwargs):
        files = files or self.files
        dls = self.get_dls(files)
        pre = None if self.pretrained == 'new' else self.pretrained
        model = self.get_model(pretrained=pre)
        learn = Learner(dls,
                        model,
                        metrics=self.metrics,
                        wd=self.wd,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim])
        if self.mpt: learn.to_fp16()
        sug_lrs = learn.lr_find(**kwargs)
        return sug_lrs, learn.recorder

    def show_mask_weights(self, files, figsize=(12, 12), **kwargs):
        masks = [self.label_fn(Path(f)) for f in files]
        for m in masks:
            print(self.mw_kwargs)
            print('Calculating weights. Please wait...')
            msk = _read_msk(m)
            _, w, _ = calculate_weights(msk, n_dims=self.c, **self.mw_kwargs)
            fig, axes = plt.subplots(nrows=1,
                                     ncols=2,
                                     figsize=figsize,
                                     **kwargs)
            axes[0].imshow(msk)
            axes[0].set_axis_off()
            axes[0].set_title(f'Mask {m.name}')
            axes[1].imshow(w)
            axes[1].set_axis_off()
            axes[1].set_title('Weights')
            plt.show()

    def ood_train(self, features=['energy_max'], **kwargs):
        self.ood = Pipeline([('scaler', StandardScaler()),
                             ('svm', svm.OneClassSVM(**kwargs))])
        self.ood.fit(self.df_ens[features])

    def ood_score(self, features=['energy_max']):
        self.df_ens['ood_score'] = self.ood.score_samples(
            self.df_ens[features])

    def ood_save(self, path):
        path = Path(path).with_suffix('.pkl')
        joblib.dump(self.ood, path)
        print(f'Saved OOD model to {path}')

    def ood_load(self, path):
        path = Path(path)
        try:
            self.ood = joblib.load(path)
            print(f'Successfully loaded OOD model from {path}')
        except Exception:
            print('Error! Select a valid joblib file (.pkl)')

    def clear_tmp(self):
        try:
            # Note: shutil.rmtree does not expand wildcards, so remove the
            # project's tmp directory directly
            shutil.rmtree(self.path / '.tmp')
            print(f'Deleted temporary files from {self.path/".tmp"}')
        except Exception:
            print(f'No temporary files to delete at {self.path/".tmp"}')
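
Putting the pieces together, a typical session with this class might look like the sketch below; the import path and directory layout are assumptions, not part of the code above:

from deepflash2.learner import EnsembleLearner  # assumed import path

el = EnsembleLearner(image_dir='images', mask_dir='masks', path='data')
sug_lrs, _ = el.lr_find()        # optional learning-rate sweep
el.fit_ensemble(n_iter=2500)     # train one model per K-fold split
df_val = el.get_valid_results()  # per-model IoU on the held-out splits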
Example #7
    def fit(self,
            i,
            n_iter=None,
            lr_max=None,
            bs=None,
            n_jobs=-1,
            verbose=1,
            **kwargs):
        n_iter = n_iter or self.n_iter
        lr_max = lr_max or self.lr
        bs = bs or self.bs
        self.stats = self.stats or self.ds.compute_stats()
        name = self.ensemble_dir / f'{self.arch}_model-{i}.pth'
        files_train, files_val = self.splits[i]
        train_ds = RandomTileDataset(files_train,
                                     label_fn=self.label_fn,
                                     n_jobs=n_jobs,
                                     verbose=verbose,
                                     **self.mw_kwargs,
                                     **self.ds_kwargs)
        valid_ds = TileDataset(files_val,
                               label_fn=self.label_fn,
                               n_jobs=n_jobs,
                               verbose=verbose,
                               **self.mw_kwargs,
                               **self.ds_kwargs)
        batch_tfms = Normalize.from_stats(*self.stats)
        dls = DataLoaders.from_dsets(train_ds,
                                     valid_ds,
                                     bs=bs,
                                     after_item=self.item_tfms,
                                     after_batch=batch_tfms)
        pre = None if self.pretrained == 'new' else self.pretrained
        model = torch.hub.load(self.repo,
                               self.arch,
                               pretrained=pre,
                               n_classes=dls.c,
                               in_channels=self.in_channels,
                               **kwargs)
        if torch.cuda.is_available(): dls.cuda(), model.cuda()
        learn = Learner(dls,
                        model,
                        metrics=self.metrics,
                        wd=self.wd,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim],
                        cbs=self.cbs)
        learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mpt: learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter, ds_length=len(train_ds), bs=bs)
        learn.fit_one_cycle(epochs, lr_max)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        self.save_model(name, learn.model)
        self.models[i] = name
        self.recorder[i] = learn.recorder
        del model
        gc.collect()
        torch.cuda.empty_cache()
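
The helper `calc_iterations` is not shown in these examples. A plausible sketch, under the assumption that it simply rounds the requested iteration budget up to whole epochs:

import math

def calc_iterations(n_iter, ds_length, bs):
    # Assumed behavior: one epoch covers ds_length / bs batches; round the
    # requested number of iterations up to a whole number of epochs.
    batches_per_epoch = max(ds_length // bs, 1)
    return math.ceil(n_iter / batches_per_epoch)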
Example #8
class EnsembleLearner(GetAttr):
    _default = 'config'

    def __init__(self,
                 image_dir='images',
                 mask_dir=None,
                 config=None,
                 path=None,
                 ensemble_path=None,
                 preproc_dir=None,
                 label_fn=None,
                 metrics=None,
                 cbs=None,
                 ds_kwargs={},
                 dl_kwargs={},
                 model_kwargs={},
                 stats=None,
                 files=None):

        self.config = config or Config()
        self.stats = stats
        self.dl_kwargs = dl_kwargs
        self.model_kwargs = model_kwargs
        self.add_ds_kwargs = ds_kwargs
        self.path = Path(path) if path is not None else Path('.')
        default_metrics = [Dice()] if self.n_classes == 2 else [DiceMulti()]
        self.metrics = metrics or default_metrics
        self.loss_fn = self.get_loss()
        self.cbs = cbs or [
            SaveModelCallback(
                monitor='dice' if self.n_classes == 2 else 'dice_multi')
        ]  #ShowGraphCallback
        self.ensemble_dir = ensemble_path or self.path / self.ens_dir
        if ensemble_path is not None:
            ensemble_path.mkdir(exist_ok=True, parents=True)
            self.load_ensemble(path=ensemble_path)
        else:
            self.models = {}

        self.files = L(files) or get_image_files(self.path / image_dir,
                                                 recurse=False)
        assert len(self.files) > 0, \
            f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder.'
        if any([mask_dir, label_fn]):
            if label_fn: self.label_fn = label_fn
            else:
                self.label_fn = get_label_fn(self.files[0],
                                             self.path / mask_dir)
            #Check if corresponding masks exist
            mask_check = [self.label_fn(x).exists() for x in self.files]
            chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
            assert len(self.files) == sum(mask_check) and len(self.files) > 0, \
                f'Please check your images and masks (and folders). {chk_str}'
            print(chk_str)

        else:
            self.label_fn = label_fn
        self.n_splits = min(len(self.files), self.max_splits)
        self._set_splits()
        self.ds = RandomTileDataset(
            self.files,
            label_fn=self.label_fn,
            preproc_dir=preproc_dir,
            instance_labels=self.instance_labels,
            n_classes=self.n_classes,
            stats=self.stats,
            normalize=True,
            sample_mult=self.sample_mult if self.sample_mult > 0 else None,
            verbose=0,
            **self.add_ds_kwargs)

        self.stats = self.ds.stats
        self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]
        self.df_val, self.df_ens, self.df_model, self.ood = None, None, None, None
        self.recorder = {}

    def _set_splits(self):
        if self.n_splits > 1:
            kf = KFold(self.n_splits,
                       shuffle=True,
                       random_state=self.random_state)
            self.splits = {
                key: (self.files[idx[0]], self.files[idx[1]])
                for key, idx in zip(range(1, self.n_splits + 1),
                                    kf.split(self.files))
            }
        else:
            self.splits = {1: (self.files[0], self.files[0])}

    def _compose_albumentations(self, **kwargs):
        return _compose_albumentations(**kwargs)

    @property
    def pred_ds_kwargs(self):
        # Setting default shapes and padding
        ds_kwargs = self.add_ds_kwargs.copy()
        ds_kwargs['use_preprocessed_labels'] = True
        ds_kwargs['preproc_dir'] = self.ds.preproc_dir
        ds_kwargs['instance_labels'] = self.instance_labels
        ds_kwargs['tile_shape'] = (self.tile_shape, ) * 2
        ds_kwargs['n_classes'] = self.n_classes
        ds_kwargs['shift'] = self.shift
        ds_kwargs['border_padding_factor'] = self.border_padding_factor
        return ds_kwargs

    @property
    def train_ds_kwargs(self):
        # Setting default shapes and padding
        ds_kwargs = self.add_ds_kwargs.copy()
        # Settings from config
        ds_kwargs['use_preprocessed_labels'] = True
        ds_kwargs['preproc_dir'] = self.ds.preproc_dir
        ds_kwargs['instance_labels'] = self.instance_labels
        ds_kwargs['stats'] = self.stats
        ds_kwargs['tile_shape'] = (self.tile_shape, ) * 2
        ds_kwargs['n_classes'] = self.n_classes
        ds_kwargs['shift'] = 1.
        ds_kwargs['border_padding_factor'] = 0.
        ds_kwargs['flip'] = self.flip
        ds_kwargs['albumentations_tfms'] = self._compose_albumentations(
            **self.albumentation_kwargs)
        ds_kwargs['sample_mult'] = self.sample_mult if self.sample_mult > 0 else None
        return ds_kwargs

    @property
    def model_name(self):
        return f'{self.arch}_{self.encoder_name}_{self.n_classes}classes'

    def get_loss(self):
        kwargs = {
            'mode': self.mode,
            'classes': [x for x in range(1, self.n_classes)],
            'smooth_factor': self.loss_smooth_factor,
            'alpha': self.loss_alpha,
            'beta': self.loss_beta,
            'gamma': self.loss_gamma
        }
        return get_loss(self.loss, **kwargs)

    def _get_dls(self, files, files_val=None):
        ds = []
        ds.append(
            RandomTileDataset(files,
                              label_fn=self.label_fn,
                              **self.train_ds_kwargs))
        if files_val:
            ds.append(
                TileDataset(files_val,
                            label_fn=self.label_fn,
                            **self.train_ds_kwargs))
        else:
            ds.append(ds[0])
        dls = DataLoaders.from_dsets(*ds,
                                     bs=self.batch_size,
                                     pin_memory=True,
                                     **self.dl_kwargs)
        if torch.cuda.is_available(): dls.cuda()
        return dls

    def _create_model(self):
        model = create_smp_model(arch=self.arch,
                                 encoder_name=self.encoder_name,
                                 encoder_weights=self.encoder_weights,
                                 in_channels=self.in_channels,
                                 classes=self.n_classes,
                                 **self.model_kwargs)
        if torch.cuda.is_available(): model.cuda()
        return model

    def fit(self, i, n_iter=None, base_lr=None, **kwargs):
        n_iter = n_iter or self.n_iter
        base_lr = base_lr or self.base_lr
        name = self.ensemble_dir / f'{self.model_name}-fold{i}.pth'
        model = self._create_model()
        files_train, files_val = self.splits[i]
        dls = self._get_dls(files_train, files_val)
        log_name = f'{name.name}_{time.strftime("%Y%m%d-%H%M%S")}.csv'
        log_dir = self.ensemble_dir / 'logs'
        log_dir.mkdir(exist_ok=True, parents=True)
        cbs = self.cbs + [CSVLogger(fname=log_dir / log_name)]
        self.learn = Learner(dls,
                             model,
                             metrics=self.metrics,
                             wd=self.weight_decay,
                             loss_func=self.loss_fn,
                             opt_func=_optim_dict[self.optim],
                             cbs=cbs)
        self.learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mixed_precision_training: self.learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter,
                                 ds_length=len(dls.train_ds),
                                 bs=self.batch_size)
        #self.learn.fit_one_cycle(epochs, lr_max)
        self.learn.fine_tune(epochs, base_lr=base_lr)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        save_smp_model(self.learn.model, self.arch, name, stats=self.stats)
        self.models[i] = name
        self.recorder[i] = self.learn.recorder

    def fit_ensemble(self, n_iter, skip=False, **kwargs):
        for i in range(1, self.n_models + 1):
            if skip and (i in self.models): continue
            self.fit(i, n_iter, **kwargs)

    def set_n(self, n):
        for i in range(n, len(self.models)):
            self.models.pop(i + 1, None)
        self.n_models = n

    def get_valid_results(self,
                          model_no=None,
                          zarr_store=None,
                          export_dir=None,
                          filetype='.png',
                          **kwargs):
        res_list = []
        model_list = self.models if not model_no else {
            k: v
            for k, v in self.models.items() if k == model_no
        }
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)

        for i, model_path in model_list.items():
            ep = EnsemblePredict(models_paths=[model_path],
                                 zarr_store=zarr_store)
            _, files_val = self.splits[i]
            g_smx, g_std, g_eng = ep.predict_images(
                files_val,
                bs=self.batch_size,
                ds_kwargs=self.pred_ds_kwargs,
                **kwargs)
            del ep
            torch.cuda.empty_cache()

            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(files_val):
                msk = self.ds.get_data(f, mask=True)[0]
                pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
                m_dice = dice_score(msk, pred)
                m_path = self.models[i].name
                df_tmp = pd.Series({
                    'file': f.name,
                    'model': m_path,
                    'model_no': i,
                    'dice_score': m_dice,
                    #'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                    'uncertainty_score': np.mean(g_std[f.name][:][pred > 0]) if g_std is not None else None,
                    'image_path': f,
                    'mask_path': self.label_fn(f),
                    'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None,
                    'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None
                })
                res_list.append(df_tmp)
                if export_dir:
                    save_mask(pred,
                              pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                              filetype)
                    if g_std is not None:
                        save_unc(
                            g_std[f.name][:], unc_path /
                            f'{df_tmp.file}_{df_tmp.model}_uncertainty',
                            filetype)
                    if g_eng is not None:
                        save_unc(
                            g_eng[f.name][:],
                            unc_path / f'{df_tmp.file}_{df_tmp.model}_energy',
                            filetype)
        self.df_val = pd.DataFrame(res_list)
        if export_dir:
            self.df_val.to_csv(export_dir / 'val_results.csv', index=False)
            self.df_val.to_excel(export_dir / 'val_results.xlsx')
        return self.df_val

    def show_valid_results(self, model_no=None, files=None, **kwargs):
        if self.df_val is None: self.get_valid_results(**kwargs)
        df = self.df_val
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        if model_no is not None: df = df[df.model_no == model_no]
        for _, r in df.iterrows():
            img = self.ds.get_data(r.image_path)[0][:]
            msk = self.ds.get_data(r.image_path, mask=True)[0]
            pred = np.argmax(zarr.load(r.softmax_path),
                             axis=-1).astype('uint8')
            std = zarr.load(r.uncertainty_path)
            _d_model = f'Model {r.model_no}'
            if self.tta:
                plot_results(img, msk, pred, std, df=r, model=_d_model)
            else:
                plot_results(img,
                             msk,
                             pred,
                             np.zeros_like(pred),
                             df=r,
                             model=_d_model)

    def load_ensemble(self, path=None):
        path = path or self.ensemble_dir
        models = sorted(get_files(path, extensions='.pth', recurse=False))
        self.models = {}

        for i, m in enumerate(models, 1):
            if i == 1: self.n_classes = int(m.name.split('_')[2][0])
            else:
                assert self.n_classes == int(m.name.split('_')[2][0]), \
                    'Check models. Models are trained on a different number of classes.'
            self.models[i] = m

        if len(self.models) > 0:
            self.set_n(len(self.models))
            print(f'Found {len(self.models)} models in folder {path}:')
            print([m.name for m in self.models.values()])

            # Reset stats
            print(f'Loading stats from {self.models[1].name}')
            _, self.stats = load_smp_model(self.models[1])

    def get_ensemble_results(self,
                             files,
                             zarr_store=None,
                             export_dir=None,
                             filetype='.png',
                             **kwargs):
        ep = EnsemblePredict(models_paths=self.models.values(),
                             zarr_store=zarr_store)
        g_smx, g_std, g_eng = ep.predict_images(files,
                                                bs=self.batch_size,
                                                ds_kwargs=self.pred_ds_kwargs,
                                                **kwargs)
        chunk_store = g_smx.chunk_store.path
        del ep
        torch.cuda.empty_cache()

        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            unc_path = export_dir / 'uncertainties'
            unc_path.mkdir(parents=True, exist_ok=True)

        res_list = []
        for f in files:
            pred = np.argmax(g_smx[f.name][:], axis=-1).astype('uint8')
            df_tmp = pd.Series({
                'file': f.name,
                'ensemble': self.model_name,
                'n_models': len(self.models),
                #'mean_energy': np.mean(g_eng[f.name][:][pred>0]),
                'uncertainty_score': np.mean(g_std[f.name][:][pred > 0]) if g_std is not None else None,
                'image_path': f,
                'softmax_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                'uncertainty_path': f'{chunk_store}/{g_std.path}/{f.name}' if g_std is not None else None,
                'energy_path': f'{chunk_store}/{g_eng.path}/{f.name}' if g_eng is not None else None
            })
            res_list.append(df_tmp)
            if export_dir:
                save_mask(pred,
                          pred_path / f'{df_tmp.file}_{df_tmp.ensemble}_mask',
                          filetype)
                if g_std is not None:
                    save_unc(g_std[f.name][:],
                             unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_unc',
                             filetype)
                if g_eng is not None:
                    save_unc(
                        g_eng[f.name][:],
                        unc_path / f'{df_tmp.file}_{df_tmp.ensemble}_energy',
                        filetype)

        self.df_ens = pd.DataFrame(res_list)
        return g_smx, g_std, g_eng

    def score_ensemble_results(self, mask_dir=None, label_fn=None):
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0],
                                    self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path,
                                n_classes=self.n_classes,
                                instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            pred = np.argmax(zarr.load(r.softmax_path),
                             axis=-1).astype('uint8')
            self.df_ens.loc[i, 'dice_score'] = dice_score(msk, pred)
        return self.df_ens

    def show_ensemble_results(self,
                              files=None,
                              unc=True,
                              unc_metric=None,
                              metric_name='dice_score'):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens
        if files is not None:
            df = df.reset_index().set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                try:
                    msk = self.ds.labels[r.file][:]
                except Exception:
                    msk = _read_msk(r.mask_path,
                                    n_classes=self.n_classes,
                                    instance_labels=self.instance_labels)
                imgs.append(msk)
                hastarget = True
            else:
                hastarget = False
            imgs.append(
                np.argmax(zarr.load(r.softmax_path), axis=-1).astype('uint8'))
            if unc: imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         metric_name=metric_name,
                         unc_metric=unc_metric)

    def get_cellpose_results(self, export_dir=None):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        cl = self.cellpose_export_class
        assert cl < self.n_classes, f'Class {cl} not available from {self.n_classes} classes'

        smxs, preds = [], []
        for _, r in self.df_ens.iterrows():
            softmax = zarr.load(r.softmax_path)
            smxs.append(softmax)
            preds.append(np.argmax(softmax, axis=-1).astype('uint8'))

        probs = [x[..., cl] for x in smxs]
        masks = [x == cl for x in preds]
        cp_masks = run_cellpose(probs,
                                masks,
                                model_type=self.cellpose_model,
                                diameter=self.cellpose_diameter,
                                min_size=self.min_pixel_export,
                                gpu=torch.cuda.is_available())

        if export_dir:
            export_dir = Path(export_dir)
            cp_path = export_dir / 'cellpose_masks'
            cp_path.mkdir(parents=True, exist_ok=True)
            for idx, r in self.df_ens.iterrows():
                tifffile.imwrite(cp_path / f'{r.file}_class{cl}.tif',
                                 cp_masks[idx],
                                 compress=6)

        self.cellpose_masks = cp_masks
        return cp_masks

    def score_cellpose_results(self, mask_dir=None, label_fn=None):
        assert self.cellpose_masks is not None, 'Run get_cellpose_results() first'
        if mask_dir is not None and label_fn is None:
            label_fn = get_label_fn(self.df_ens.image_path[0],
                                    self.path / mask_dir)
        for i, r in self.df_ens.iterrows():
            if label_fn is not None:
                msk_path = label_fn(r.image_path)
                msk = _read_msk(msk_path,
                                n_classes=self.n_classes,
                                instance_labels=self.instance_labels)
                self.df_ens.loc[i, 'mask_path'] = msk_path
            else:
                msk = self.ds.labels[r.file][:]
            _, msk = cv2.connectedComponents(msk, connectivity=4)
            pred = self.cellpose_masks[i]
            ap, tp, fp, fn = get_instance_segmentation_metrics(
                msk, pred, is_binary=False, min_pixel=self.min_pixel_export)
            self.df_ens.loc[i, 'mean_average_precision'] = ap.mean()
            self.df_ens.loc[i, 'average_precision_at_iou_50'] = ap[0]
        return self.df_ens

    def show_cellpose_results(self,
                              files=None,
                              unc=True,
                              unc_metric=None,
                              metric_name='mean_average_precision'):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        df = self.df_ens.reset_index()
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.image_path)[:])
            if metric_name in r.index:
                try:
                    mask = self.ds.labels[r.file][:]
                except Exception:
                    mask = _read_msk(r.mask_path,
                                     n_classes=self.n_classes,
                                     instance_labels=self.instance_labels)
                _, comps = cv2.connectedComponents(
                    (mask == self.cellpose_export_class).astype('uint8'),
                    connectivity=4)
                imgs.append(label2rgb(comps, bg_label=0))
                hastarget = True
            else:
                hastarget = False
            imgs.append(label2rgb(self.cellpose_masks[r['index']], bg_label=0))
            if unc: imgs.append(zarr.load(r.uncertainty_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         metric_name=metric_name,
                         unc_metric=unc_metric)

    def lr_find(self, files=None, **kwargs):
        files = files or self.files
        dls = self._get_dls(files)
        model = self._create_model()
        learn = Learner(dls,
                        model,
                        metrics=self.metrics,
                        wd=self.weight_decay,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim])
        if self.mixed_precision_training: learn.to_fp16()
        sug_lrs = learn.lr_find(**kwargs)
        return sug_lrs, learn.recorder

    def export_imagej_rois(self, output_folder='ROI_sets', **kwargs):
        assert self.df_ens is not None, "Please run prediction first."

        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        for idx, r in progress_bar(self.df_ens.iterrows(),
                                   total=len(self.df_ens)):
            mask = np.argmax(zarr.load(r.softmax_path),
                             axis=-1).astype('uint8')
            uncertainty = zarr.load(r.uncertainty_path)
            export_roi_set(mask,
                           uncertainty,
                           name=r.file,
                           path=output_folder,
                           ascending=False,
                           **kwargs)

    def export_cellpose_rois(self,
                             output_folder='cellpose_ROI_sets',
                             **kwargs):
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True, parents=True)
        for idx, r in progress_bar(self.df_ens.iterrows(),
                                   total=len(self.df_ens)):
            mask = self.cellpose_masks[idx]
            uncertainty = zarr.load(r.uncertainty_path)
            export_roi_set(mask,
                           uncertainty,
                           instance_labels=True,
                           name=r.file,
                           path=output_folder,
                           ascending=False,
                           **kwargs)
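
Finally, a hedged end-to-end sketch for this newer segmentation-models-pytorch based variant, using only methods defined above; the concrete directory names are assumptions:

# Hypothetical session with the EnsembleLearner from Example #8.
el = EnsembleLearner(image_dir='images', mask_dir='masks')
el.fit_ensemble(n_iter=2500)                             # one model per fold, saved as .pth
g_smx, g_std, g_eng = el.get_ensemble_results(el.files)  # zarr groups: softmax, uncertainty, energy
df = el.score_ensemble_results(mask_dir='masks')         # adds a dice_score column to df_ens
el.export_imagej_rois('ROI_sets')                        # write one ImageJ ROI set per image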