예제 #1
0
 def ensemble_results(self,
                      files,
                      save_dir=None,
                      filetype='.png',
                      use_tta=None,
                      **kwargs):
     """Merge the cached per-model predictions of every file into one ensemble result.

     For each file the temporary ``.npz`` prediction of every model in
     ``self.models`` is read from ``self.path/.tmp/<model.name>/<stem>.npz``
     and its ``smx``, ``std`` and ``enrgy`` arrays are accumulated in
     ``tta.Merger`` objects. The merged arrays (plus the arg-max segmentation)
     are written to ``self.path/.tmp/<arch>_ensemble/<stem>.npz``.

     Args:
         files: iterable of path-like objects (``.stem`` and ``.name`` are used).
         save_dir: optional directory, relative to ``self.path``; when given,
             masks are exported to ``<save_dir>/masks`` and, if TTA is active,
             uncertainty maps to ``<save_dir>/uncertainties``.
         filetype: file extension forwarded to ``save_mask`` / ``save_unc``.
         use_tta: overrides ``self.pred_tta`` when not ``None``.
         **kwargs: accepted for call compatibility; not used here.

     Returns:
         ``pd.DataFrame`` with one row per file (columns ``file``, ``model``,
         ``img_path``, ``res_path``, ``energy_max``).
     """
     # Only fall back to the instance default when the caller gave no value,
     # so an explicit ``use_tta=False`` is honoured.
     use_tta = self.pred_tta if use_tta is None else use_tta
     pth_out = self.path / '.tmp' / f'{self.arch}_ensemble'
     pth_out.mkdir(exist_ok=True, parents=True)
     if save_dir:
         save_dir = self.path / save_dir
         pred_path = save_dir / 'masks'
         pred_path.mkdir(parents=True, exist_ok=True)
         if use_tta:
             unc_path = save_dir / 'uncertainties'
             unc_path.mkdir(parents=True, exist_ok=True)
     res_list = []
     for f in files:
         # One set of mergers per *file*: creating them inside the model loop
         # would discard every model but the last one.
         m_smx, m_std, m_enrgy = tta.Merger(), tta.Merger(), tta.Merger()
         for m in self.models.values():
             pth_tmp = self.path / '.tmp' / m.name / f'{f.stem}.npz'
             with open(pth_tmp, 'rb') as file:
                 tmp = np.load(file)
                 m_smx.append(tmp['smx'])
                 m_std.append(tmp['std'])
                 m_enrgy.append(tmp['enrgy'])
         smx = m_smx.result()
         seg = np.argmax(smx, axis=-1)
         std = m_std.result()
         enrgy = m_enrgy.result()
         res_path = pth_out / f'{f.stem}.npz'
         np.savez(res_path, smx=smx, seg=seg, std=std, enrgy=enrgy)
         df_tmp = pd.Series({
             'file': f.name,
             'model': pth_out.name,
             'img_path': f,
             'res_path': res_path,
             'energy_max': enrgy.numpy()
         })
         res_list.append(df_tmp)
         if save_dir:
             save_mask(seg.numpy(),
                       pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                       filetype)
             if use_tta:
                 save_unc(std.numpy(),
                          unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                          filetype)
     return pd.DataFrame(res_list)
예제 #2
0
    def ensemble_results(self,
                         files,
                         path=None,
                         export_dir=None,
                         filetype='.png',
                         use_tta=None,
                         **kwargs):
        """Merge the stored per-model predictions of every file into a zarr-backed ensemble.

        For each file the rows of ``self.df_models`` belonging to it are looked
        up, the arrays referenced by ``smx_path``, ``std_path`` and ``eng_path``
        are loaded with ``zarr.load`` and accumulated in ``tta.Merger`` objects.
        The merged results are stored in the zarr groups ``ens_smx``,
        ``ens_seg``, ``ens_std`` and ``ens_energy``.

        Args:
            files: iterable of path-like objects (``.name`` is used as key).
            path: optional directory (``str`` or ``Path``); the store is created
                at ``<path>/ensemble``. A ``zarr.storage.TempStore`` is used
                when omitted.
            export_dir: optional directory; when given, masks are exported to
                ``<export_dir>/masks`` and, if TTA is active, uncertainty maps
                to ``<export_dir>/uncertainties``.
            filetype: file extension forwarded to ``save_mask`` / ``save_unc``.
            use_tta: overrides ``self.pred_tta`` when not ``None``.
            **kwargs: accepted for call compatibility; not used here.

        Returns:
            ``pd.DataFrame`` with one row per file holding the file name, the
            ensemble model name, ``energy_max`` and the store paths of the
            merged arrays.

        Raises:
            AssertionError: if the number of stored predictions for a file
                differs from ``len(self.models)``.
        """
        # Only fall back to the instance default when the caller gave no value,
        # so an explicit ``use_tta=False`` is honoured.
        use_tta = self.pred_tta if use_tta is None else use_tta
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            if use_tta:
                unc_path = export_dir / 'uncertainties'
                unc_path.mkdir(parents=True, exist_ok=True)

        # Coerce ``path`` like ``export_dir`` above so plain strings work too.
        store = (str(Path(path) / 'ensemble')
                 if path else zarr.storage.TempStore())
        root = zarr.group(store=store, overwrite=True)
        chunk_store = root.chunk_store.path
        g_smx, g_seg, g_std, g_eng = root.create_groups(
            'ens_smx', 'ens_seg', 'ens_std', 'ens_energy')
        res_list = []
        for f in files:
            df_fil = self.df_models[self.df_models.file == f.name]
            assert len(df_fil) == len(
                self.models), "Predictions and models to not match."
            m_smx, m_std, m_eng = tta.Merger(), tta.Merger(), tta.Merger()
            for _, r in df_fil.iterrows():
                m_smx.append(zarr.load(r.smx_path))
                m_std.append(zarr.load(r.std_path))
                m_eng.append(zarr.load(r.eng_path))
            smx = m_smx.result().numpy()
            g_smx[f.name] = smx
            g_seg[f.name] = np.argmax(smx, axis=-1)
            g_std[f.name] = m_std.result().numpy()
            eng = m_eng.result()
            g_eng[f.name] = eng.numpy()
            m_eng_max = energy_max(eng, ks=self.energy_ks).numpy()
            df_tmp = pd.Series({
                'file': f.name,
                'model': f'{self.arch}_ensemble',
                'energy_max': m_eng_max,
                'img_path': f,
                'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                'std_path': f'{chunk_store}/{g_std.path}/{f.name}',
                'eng_path': f'{chunk_store}/{g_eng.path}/{f.name}'
            })
            res_list.append(df_tmp)
            if export_dir:
                # Read back from the store so exactly what was persisted is exported.
                save_mask(g_seg[f.name][:],
                          pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                          filetype)
                if use_tta:
                    save_unc(g_std[f.name][:],
                             unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                             filetype)
        return pd.DataFrame(res_list)
예제 #3
0
def predict_tiles(self: Learner,
                  ds_idx=1,
                  dl=None,
                  path=None,
                  mc_dropout=False,
                  n_times=1,
                  use_tta=False,
                  tta_merge='mean',
                  tta_tfms=None,
                  uncertainty_estimates=True,
                  energy_T=1):
    """Predict all tiles of a dataloader and reassemble them per image in a zarr store.

    Args:
        ds_idx: index into ``self.dls`` used when ``dl`` is not given.
        dl: dataloader whose ``dataset`` is a ``TileDataset``; built from
            ``self.dls[ds_idx]`` (unshuffled, keeping the last batch) if ``None``.
        path: optional location of the zarr store; a
            ``zarr.storage.TempStore`` is used when omitted.
        mc_dropout: if true, ``self.apply_dropout()`` is called after
            switching the model to eval mode.
        n_times: number of forward passes per test-time transform.
        use_tta: enable test-time augmentation.
        tta_merge: accepted for call compatibility; not used in this function.
        tta_tfms: transforms used when ``use_tta`` is true; defaults to a
            horizontal flip plus 90/180/270 degree rotations.
        uncertainty_estimates: also store the channel-averaged softmax ``std``
            and the merged energy score. When false the ``std`` and ``energy``
            arrays are still created but stay zero-filled.
        energy_T: temperature of the energy score.

    Returns:
        Tuple ``(g_smx, g_seg, g_std, g_eng)`` of zarr groups, each holding
        one array per image file name.
    """

    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    assert isinstance(
        dl.dataset, TileDataset), "Provide dataloader containing a TileDataset"
    if use_tta:
        tfms = tta_tfms or [
            tta.HorizontalFlip(),
            tta.Rotate90(angles=[90, 180, 270])
        ]
    else:
        tfms = []

    self.model.eval()
    if mc_dropout: self.apply_dropout()

    store = str(path) if path else zarr.storage.TempStore()
    root = zarr.group(store=store, overwrite=True)
    g_smx, g_seg, g_std, g_eng = root.create_groups('smx', 'seg', 'std',
                                                    'energy')

    i = 0  # global index of the first tile of the current batch
    last_file = None
    for data in progress_bar(dl, leave=False):
        if isinstance(data, TensorImage): images = data
        else: images, _, _ = data
        m_smx = tta.Merger()
        m_energy = tta.Merger()
        for t in tta.Compose(tfms):
            for _ in range(n_times):
                aug_images = t.augment_image(images)
                with torch.no_grad():
                    out = self.model(aug_images)
                out = t.deaugment_mask(out)
                # Pad the output symmetrically when its size difference to the
                # input (last dim) is not the one the dataloader expects.
                if dl.padding[0] != images.shape[-1] - out.shape[-1]:
                    padding = (
                        (images.shape[-1] - out.shape[-1] - dl.padding[0]) //
                        2, ) * 4
                    out = F.pad(out, padding)
                m_smx.append(F.softmax(out, dim=1))
                if uncertainty_estimates:
                    e = (energy_T * torch.logsumexp(out / energy_T, dim=1)
                         )  # negative energy score
                    m_energy.append(e)

        # Per-tile results of this batch; softmax is moved to channels-last.
        ll = [list(m_smx.result().permute(0, 2, 3, 1).cpu().numpy())]
        if uncertainty_estimates:
            ll.append(list(torch.mean(m_smx.result('std'), 1).cpu().numpy()))
            ll.append(list(m_energy.result().cpu().numpy()))
        for j, preds in enumerate(zip(*ll)):
            if len(preds) == 3: smx, std, eng = preds
            else: smx = preds[0]
            idx = i + j
            f = dl.files[dl.image_indices[idx]]
            out_shape = dl.image_shapes[idx]
            out_slice = dl.out_slices[idx]
            in_slice = dl.in_slices[idx]
            # Allocate the output arrays when the first tile of a new image appears.
            if last_file != f:
                z_smx = g_smx.zeros(f.name,
                                    shape=(*out_shape, dl.c),
                                    dtype='float32')
                z_seg = g_seg.zeros(f.name, shape=out_shape, dtype='uint8')
                z_std = g_std.zeros(f.name, shape=out_shape, dtype='float32')
                z_eng = g_eng.zeros(f.name, shape=out_shape, dtype='float32')
                last_file = f
            z_smx[out_slice] = smx[in_slice]
            z_seg[out_slice] = np.argmax(smx, axis=-1)[in_slice]
            if uncertainty_estimates:
                z_std[out_slice] = std[in_slice]
                z_eng[out_slice] = eng[in_slice]
        # Advance by the number of tiles actually processed rather than the
        # nominal batch size, so a short batch cannot shift later indices.
        i += len(ll[0])

    return g_smx, g_seg, g_std, g_eng
예제 #4
0
def predict_tiles(self: Learner,
                  ds_idx=1,
                  dl=None,
                  mc_dropout=False,
                  n_times=1,
                  use_tta=False,
                  tta_merge='mean',
                  energy_T=1,
                  energy_ks=20,
                  padding=(0, 0, 0, 0)):  # e.g. (-52, -52, -52, -52) to crop
    """Predict all tiles of a dataloader and reconstruct full-size results.

    Args:
        ds_idx: index into ``self.dls`` used when ``dl`` is not given.
        dl: dataloader providing ``reconstruct_from_tiles``; built from
            ``self.dls[ds_idx]`` (unshuffled, keeping the last batch) if ``None``.
        mc_dropout: if true, ``self.apply_dropout()`` is called after
            switching the model to eval mode.
        n_times: number of forward passes per test-time transform.
        use_tta: enable test-time augmentation (horizontal flip plus
            90/180/270 degree rotations).
        tta_merge: accepted for call compatibility; not used in this function.
        energy_T: temperature of the energy score.
        energy_ks: second argument passed to ``energy_max`` for every
            reconstructed energy map; ``None`` skips that step.
        padding: 4-tuple forwarded to ``F.pad`` for the model output. Any
            non-zero entry triggers padding; negative values crop.

    Returns:
        Tuple ``(softmax_scores, segmentations, std_deviations, energy_scores)``
        of per-image lists as returned by ``dl.reconstruct_from_tiles``
        (segmentations are the arg-max of the softmax scores).
    """

    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    if use_tta:
        tfms = [tta.HorizontalFlip(), tta.Rotate90(angles=[90, 180, 270])]
    else:
        tfms = []

    self.model.eval()
    if mc_dropout: self.apply_dropout()

    # ``sum(padding) > 0`` would skip negative (cropping) padding and any
    # padding that sums to zero; test for non-zero entries instead.
    apply_padding = any(padding)

    smx_means, energy_means, stds = [], [], []
    for data in progress_bar(dl, leave=False):
        if isinstance(data, TensorImage): images = data
        else: images, _, _ = data
        m_smx = tta.Merger()
        m_energy = tta.Merger()
        for t in tta.Compose(tfms):
            for _ in range(n_times):
                aug_images = t.augment_image(images)
                with torch.no_grad():
                    out = self.model(aug_images)
                out = t.deaugment_mask(out)
                if apply_padding: out = F.pad(out, padding)
                m_smx.append(F.softmax(out, dim=1))
                e = (energy_T * torch.logsumexp(out / energy_T, dim=1)
                     )  # negative energy score
                m_energy.append(e)

        stds.append(m_smx.result('std'))
        smx_means.append(m_smx.result())
        energy_means.append(m_energy.result())

    # Concatenate all batches and split into per-tile arrays (channels last).
    softmax_pred = torch.cat(smx_means).permute(0, 2, 3, 1)
    smx_tiles = list(softmax_pred.cpu().numpy())

    std_pred = torch.cat(stds).permute(0, 2, 3, 1)
    # Only the first channel of the std map is kept.
    std_tiles = [x[..., 0] for x in std_pred.cpu().numpy()]

    energy_pred = torch.cat(energy_means)
    energy_tiles = list(energy_pred.cpu().numpy())

    smxcores = dl.reconstruct_from_tiles(smx_tiles)
    segmentations = [np.argmax(x, axis=-1) for x in smxcores]
    std_deviations = dl.reconstruct_from_tiles(std_tiles)
    energy_scores = dl.reconstruct_from_tiles(energy_tiles)

    if energy_ks is not None:
        energy_scores = [energy_max(e, energy_ks) for e in energy_scores]

    return smxcores, segmentations, std_deviations, energy_scores
예제 #5
0
    def predict(self,
                ds,
                use_tta=True,
                bs=4,
                use_gaussian=True,
                sigma_scale=1. / 8,
                uncertainty_estimates=True,
                uncertainty_type='uncertainty',
                energy_scores=False,
                energy_T=1.,
                verbose=0):
        """Run the model ensemble tile-wise over ``ds`` and merge the tiles into one image.

        Every batch of tiles is passed through all models in ``self.models``
        for every test-time transform; all softmax outputs are accumulated in
        a single ``tta.Merger``. The merged tiles are multiplied by the merge
        weights, summed into full-size arrays and finally divided by the
        accumulated weights.

        Args:
            ds: tile dataset yielding ``(tiles, idxs)`` and exposing
                ``image_shapes``, ``c``, ``output_shape``, ``in_slices`` and
                ``out_slices``.
            use_tta: apply horizontal and vertical flip test-time augmentation.
            bs: batch size of the internal ``DataLoader``.
            use_gaussian: weight tiles with ``_get_gaussian(ds.output_shape,
                sigma_scale)``; uniform weights otherwise.
            sigma_scale: forwarded to ``_get_gaussian``.
            uncertainty_estimates: also return the channel-averaged
                ``smx_merger.result(uncertainty_type)`` map.
            uncertainty_type: merge type requested from the merger for the
                uncertainty map.
            energy_scores: also return the merged negative energy score.
            energy_T: temperature forwarded to ``energy_score``.
            verbose: values above 0 print the model paths and the transforms.

        Returns:
            Tuple ``(softmax, stdeviation, energy)`` of ``float32`` arrays;
            ``stdeviation`` / ``energy`` are ``None`` when disabled.
        """

        if verbose > 0:
            print('Ensemble prediction with models:', self.models_paths)

        tfms = [tta.HorizontalFlip(), tta.VerticalFlip()] if use_tta else []
        if verbose > 0: print('Using Test-Time Augmentation with:', tfms)

        dl = DataLoader(ds, bs, num_workers=0, shuffle=False, pin_memory=True)

        # Create zero arrays
        data_shape = ds.image_shapes[0]
        softmax = np.zeros((*data_shape, ds.c), dtype='float32')
        merge_map = np.zeros(data_shape, dtype='float32')
        stdeviation = np.zeros(
            data_shape, dtype='float32') if uncertainty_estimates else None
        energy = np.zeros(data_shape,
                          dtype='float32') if energy_scores else None

        # Define merge weights
        if use_gaussian:
            mw_numpy = _get_gaussian(ds.output_shape, sigma_scale)
        else:
            # The tile shape lives on the dataset; the DataLoader has no
            # ``output_shape`` attribute.
            mw_numpy = np.ones(ds.output_shape)
        mw = torch.from_numpy(mw_numpy).to(self.device)

        # Loop over tiles (indices required!)
        for tiles, idxs in dl:
            tiles = tiles.to(self.device)
            smx_merger = tta.Merger()
            if energy_scores:
                energy_merger = tta.Merger()

            # Loop over tt-augmentations
            for t in tta.Compose(tfms):
                aug_tiles = t.augment_image(tiles)

                # Loop over models
                for model in self.models:
                    with torch.inference_mode():
                        logits = model(aug_tiles)
                    logits = t.deaugment_mask(logits)
                    smx_merger.append(F.softmax(logits, dim=1))
                    if energy_scores:
                        energy_merger.append(-energy_score(
                            logits, energy_T))  # negative energy score

            out_list = []
            # Apply merge weighting (broadcast over batch and channel dims)
            batch_smx = smx_merger.result() * mw.view(1, 1, *mw.shape)
            # Reshape to channels-last and split into per-tile arrays
            out_list.append(list(batch_smx.permute(0, 2, 3, 1).cpu().numpy()))

            if uncertainty_estimates:
                batch_std = torch.mean(smx_merger.result(uncertainty_type),
                                       dim=1) * mw.view(1, *mw.shape)
                out_list.append(list(batch_std.cpu().numpy()))

            if energy_scores:
                batch_energy = energy_merger.result() * mw.view(1, *mw.shape)
                out_list.append(list(batch_energy.cpu().numpy()))

            # Compose predictions
            for preds in zip(*out_list, idxs):
                if len(preds) == 4: smx, std, eng, idx = preds
                elif uncertainty_estimates: smx, std, idx = preds
                elif energy_scores: smx, eng, idx = preds
                else: smx, idx = preds
                out_slice = ds.out_slices[idx]
                in_slice = ds.in_slices[idx]
                softmax[out_slice] += smx[in_slice]
                # Track the summed weights for the normalisation below.
                merge_map[out_slice] += mw_numpy[in_slice]

                if uncertainty_estimates:
                    stdeviation[out_slice] += std[in_slice]
                if energy_scores:
                    energy[out_slice] += eng[in_slice]

        # Normalize weighting
        softmax /= merge_map[..., np.newaxis]
        if uncertainty_estimates:
            stdeviation /= merge_map
        if energy_scores:
            energy /= merge_map

        return softmax, stdeviation, energy