Example #1
import mlflow
import mlflow.fastai
from fastai.learner import Learner
from torch import nn


def main():
    # Parse command-line arguments
    args = parse_args()

    # Enable auto logging
    mlflow.fastai.autolog()

    # Create Learner model
    learn = Learner(get_data_loaders(),
                    Model(),
                    loss_func=nn.MSELoss(),
                    splitter=splitter)

    # Start MLflow session
    with mlflow.start_run():
        # Train and fit with default or supplied command line arguments
        learn.fit_one_cycle(args.epochs, args.lr)
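
The helper functions referenced above (parse_args(), get_data_loaders(), Model, and splitter) are defined elsewhere in the original example and not shown here. Below is a minimal, hypothetical sketch of what they could look like for a toy single-feature regression; none of these definitions come from the original source.

# Hypothetical stand-ins for the helpers assumed above; not from the original example.
import argparse

import torch
from torch import nn
from fastai.data.core import DataLoaders


def parse_args():
    # Command-line arguments consumed by main(): epoch count and learning rate
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--lr", type=float, default=1e-2)
    return parser.parse_args()


def get_data_loaders(n=1000, bs=64):
    # Toy regression data: learn y = 2x + 1 from random inputs
    x = torch.randn(n, 1)
    y = 2 * x + 1
    cut = int(0.8 * n)
    train_ds = list(zip(x[:cut], y[:cut]))
    valid_ds = list(zip(x[cut:], y[cut:]))
    return DataLoaders.from_dsets(train_ds, valid_ds, bs=bs)


class Model(nn.Module):
    # Single linear layer matching the toy data above
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, x):
        return self.lin(x)


def splitter(model):
    # fastai splitter: returns parameter groups for discriminative learning rates
    return [list(model.lin.parameters())]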
Example #2
class AutoEmbedderCategoryEncoder(CustomCategoryEncoder):
    """Uses an `AutoEmbedder` model to perform encoding of categorical features."""
    _preprocessor_cls: Type[CategoryEncoderPreprocessor] = AutoEmbedderPreprocessor
    learn: Learner = None
    emb_szs: Dict[str, Tuple[int, int]] = None

    def encode(self, X: TabDataLoader):
        """Encodes all elements in `data`."""
        data = X if isinstance(X, TabDataLoader) else X.train
        preds = self.learn.get_preds(dl=data, reorder=False)[0].cpu().numpy()
        return pd.DataFrame(preds, columns=self.get_feature_names())

    def fit(self, X: TabularDataLoaders):
        """Creates the learner and trains it."""
        emb_szs = get_emb_sz(X.train_ds, {})
        self.emb_szs = {col: sz for col, sz in zip(self.cat_names, emb_szs)}
        n_conts = len(X.cont_names)
        n_cats = sum(list(map(lambda e: e[1], emb_szs)))
        in_sz = n_conts + n_cats
        out_sz = n_conts + len(X.cat_names)
        # Create the embedding model
        model = AutoEmbedder(in_sz, out_sz, emb_szs, [2000, 1000])
        self.learn = Learner(X, model, loss_func=EmbeddingLoss(model), wd=1.0)
        # TODO hide training progress?
        with self.learn.no_bar():
            self.learn.fit_one_cycle(20, lr_max=3e-3)

    def decode(self, X: pd.DataFrame) -> pd.DataFrame:
        """Decodes multiple items for one feature embedding."""
        df = pd.DataFrame()
        data = torch.tensor(X[self.get_feature_names()].values)
        embeddings = self.learn.model.embeddings.embeddings
        # Split data into chunks depending on embedding sizes
        data = torch.split(data,
                           list(map(lambda o: o[1], self.emb_szs.values())),
                           dim=-1)
        # Iterate over features, decoding each one for all rows
        for (embedding_vectors, embedding_layer,
             (colname, (n_unique_values,
                        embedding_size))) in zip(data, embeddings,
                                                 self.emb_szs.items()):
            # Calculate the embedding output for each category value
            cat_embeddings = embedding_layer(
                torch.arange(n_unique_values,
                             device=embedding_layer.weight.device))
            # Compute cosine similarity over embeddings
            most_similar = expanded(
                embedding_vectors, cat_embeddings,
                lambda a, b: F.cosine_similarity(a, b, dim=-1))
            # Map values to their most similar category
            most_similar = most_similar.argmax(dim=-1)
            # Save data into decoded column
            df[colname] = most_similar.cpu().numpy()
        return df

    def get_feature_names(self) -> List[str]:
        """
        Returns a list of encoded feature names.
        For embeddings, this is a list of original categorical names followed by embedding index,
        e.g. [feature_a_0, feature_a_1, feature_b_0, feature_b_1].
        """
        return [
            f"{column}_{feature_num}" for column in self.cat_names
            for feature_num in range(self.emb_szs[column][1])
        ]

    def get_emb_szs(self):
        """Returns a dict of embedding sizes for each categorical feature."""
        return self.emb_szs
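
The decode() method above depends on an expanded() helper that is not included in the snippet. Judging from how it is called (pairwise cosine similarity between each row's embedding vector and every category embedding, followed by an argmax), a plausible reconstruction could look like this; the implementation below is an assumption, not the original code.

import torch


def expanded(a, b, fn):
    """Hypothetical reconstruction of the expanded() helper used in decode().

    a has shape (n_rows, emb_size) and b has shape (n_categories, emb_size).
    Broadcasting them against each other lets fn (here cosine similarity over
    the last dim) produce an (n_rows, n_categories) score matrix, whose argmax
    over dim=-1 recovers the most similar category for every row.
    """
    return fn(a.unsqueeze(1), b.unsqueeze(0))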
Example #3
class EnsembleLearner(GetAttr):
    _default = 'config'

    def __init__(self,
                 image_dir='images',
                 mask_dir=None,
                 config=None,
                 path=None,
                 ensemble_dir=None,
                 item_tfms=None,
                 label_fn=None,
                 metrics=None,
                 cbs=None,
                 ds_kwargs=None,
                 dl_kwargs=None,
                 model_kwargs=None,
                 stats=None,
                 files=None):

        self.config = config or Config()
        self.stats = stats
        self.dl_kwargs = dl_kwargs or {}
        self.model_kwargs = model_kwargs or {}
        self.add_ds_kwargs = ds_kwargs or {}
        self.item_tfms = item_tfms
        self.path = Path(path) if path is not None else Path('.')
        self.metrics = metrics or [Iou(), Dice_f1()]
        self.loss_fn = self.get_loss()
        self.cbs = cbs or [
            SaveModelCallback(monitor='iou'), ElasticDeformCallback
        ]  #ShowGraphCallback
        self.ensemble_dir = ensemble_dir or self.path / 'ensemble'

        self.files = L(files) or get_image_files(self.path / image_dir,
                                                 recurse=False)
        assert len(self.files) > 0, \
            f'Found {len(self.files)} images in "{image_dir}". Please check your images and image folder.'
        if any([mask_dir, label_fn]):
            if label_fn: self.label_fn = label_fn
            else:
                self.label_fn = get_label_fn(self.files[0],
                                             self.path / mask_dir)
            #Check if corresponding masks exist
            mask_check = [self.label_fn(x).exists() for x in self.files]
            chk_str = f'Found {len(self.files)} images in "{image_dir}" and {sum(mask_check)} masks in "{mask_dir}".'
            assert len(self.files) == sum(mask_check) and len(self.files) > 0, \
                f'Please check your images and masks (and folders). {chk_str}'
            print(chk_str)

        else:
            self.label_fn = label_fn
        self.n_splits = min(len(self.files), self.max_splits)

        self.models = {}
        self.recorder = {}
        self._set_splits()
        self.ds = RandomTileDataset(self.files,
                                    label_fn=self.label_fn,
                                    **self.mw_kwargs,
                                    **self.ds_kwargs)
        self.in_channels = self.ds.get_data(max_n=1)[0].shape[-1]
        self.df_val, self.df_ens, self.df_model, self.ood = None, None, None, None

    @property
    def out_size(self):
        return self.ds_kwargs['tile_shape'][0] - self.ds_kwargs['padding'][0]

    def _set_splits(self):
        if self.n_splits > 1:
            kf = KFold(self.n_splits,
                       shuffle=True,
                       random_state=self.random_state)
            self.splits = {
                key: (self.files[idx[0]], self.files[idx[1]])
                for key, idx in zip(range(1, self.n_splits +
                                          1), kf.split(self.files))
            }
        else:
            self.splits = {1: (self.files[0], self.files[0])}

    def compose_albumentations(self, **kwargs):
        return _compose_albumentations(**kwargs)

    @property
    def ds_kwargs(self):
        # Setting default shapes and padding
        ds_kwargs = self.add_ds_kwargs.copy()
        for key, value in get_default_shapes(self.arch).items():
            ds_kwargs.setdefault(key, value)
        # Settings from config
        ds_kwargs['loss_weights'] = self.loss == 'WeightedSoftmaxCrossEntropy'
        ds_kwargs['zoom_sigma'] = self.zoom_sigma
        ds_kwargs['flip'] = self.flip
        ds_kwargs['deformation_grid'] = (self.deformation_grid, ) * 2
        ds_kwargs['deformation_magnitude'] = (self.deformation_magnitude, ) * 2
        if sum(self.albumentation_kwargs.values()) > 0:
            ds_kwargs['albumentation_tfms'] = self.compose_albumentations(
                **self.albumentation_kwargs)
        return ds_kwargs

    def get_loss(self):
        if self.loss == 'WeightedSoftmaxCrossEntropy':
            return WeightedSoftmaxCrossEntropy(axis=1)
        elif self.loss == 'CrossEntropyLoss':
            return CrossEntropyLossFlat(axis=1)
        else:
            kwargs = {
                'alpha': self.loss_alpha,
                'beta': self.loss_beta,
                'gamma': self.loss_gamma
            }
            return load_kornia_loss(self.loss, **kwargs)

    def get_model(self, pretrained):
        if self.arch in [
                "unet_deepflash2", "unet_falk2019", "unet_ronnberger2015",
                "unet_custom", "unext50_deepflash2"
        ]:
            model = torch.hub.load(self.repo,
                                   self.arch,
                                   pretrained=pretrained,
                                   n_classes=self.c,
                                   in_channels=self.in_channels,
                                   **self.model_kwargs)
        else:
            kwargs = dict(encoder_name=self.encoder_name,
                          encoder_weights=self.encoder_weights,
                          in_channels=self.in_channels,
                          classes=self.c,
                          **self.model_kwargs)
            model = load_smp_model(self.arch, **kwargs)
        if torch.cuda.is_available(): model.cuda()
        return model

    def get_dls(self, files, files_val=None):
        ds = []
        ds.append(
            RandomTileDataset(files,
                              label_fn=self.label_fn,
                              **self.mw_kwargs,
                              **self.ds_kwargs))
        if files_val:
            ds.append(
                TileDataset(files_val,
                            label_fn=self.label_fn,
                            **self.mw_kwargs,
                            **self.ds_kwargs))
        else:
            ds.append(ds[0])
        dls = DataLoaders.from_dsets(*ds,
                                     bs=self.bs,
                                     after_item=self.item_tfms,
                                     after_batch=self.get_batch_tfms(),
                                     **self.dl_kwargs)
        if torch.cuda.is_available(): dls.cuda()
        return dls

    def save_model(self, file, model, pickle_protocol=2):
        state = model.state_dict()
        state = {
            'model': state,
            'arch': self.arch,
            'stats': self.stats,
            'c': self.c
        }
        if self.arch in [
                "unet_deepflash2", "unet_falk2019", "unet_ronnberger2015",
                "unet_custom", "unext50_deepflash2"
        ]:
            state['repo'] = self.repo
        else:
            state['encoder_name'] = self.encoder_name
        torch.save(state,
                   file,
                   pickle_protocol=pickle_protocol,
                   _use_new_zipfile_serialization=False)

    def load_model(self, file, with_meta=True, device=None, strict=True):
        if isinstance(device, int): device = torch.device('cuda', device)
        elif device is None: device = 'cpu'
        state = torch.load(file, map_location=device)
        hasopt = 'model' in state  #set(state)=={'model', 'arch', 'repo', 'stats', 'c'}
        if hasopt:
            model_state = state['model']
            if with_meta:
                for opt in state:
                    if opt != 'model': setattr(self.config, opt, state[opt])
        else:
            model_state = state
        model = self.get_model(pretrained=None)
        model.load_state_dict(model_state, strict=strict)
        return model

    def get_batch_tfms(self):
        self.stats = self.stats or self.ds.compute_stats()
        tfms = [Normalize.from_stats(*self.stats)]
        if isinstance(self.loss_fn, WeightedSoftmaxCrossEntropy):
            tfms.append(WeightTransform(self.out_size, **self.mw_kwargs))
        return tfms

    def fit(self, i, n_iter=None, lr_max=None, **kwargs):
        n_iter = n_iter or self.n_iter
        lr_max = lr_max or self.lr
        name = self.ensemble_dir / f'{self.arch}_model-{i}.pth'
        pre = None if self.pretrained == 'new' else self.pretrained
        model = self.get_model(pretrained=pre)
        files_train, files_val = self.splits[i]
        dls = self.get_dls(files_train, files_val)
        self.learn = Learner(dls,
                             model,
                             metrics=self.metrics,
                             wd=self.wd,
                             loss_func=self.loss_fn,
                             opt_func=_optim_dict[self.optim],
                             cbs=self.cbs)
        self.learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mpt: self.learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter,
                                 ds_length=len(dls.train_ds),
                                 bs=self.bs)
        self.learn.fit_one_cycle(epochs, lr_max)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        self.save_model(name, self.learn.model)
        self.models[i] = name
        self.recorder[i] = self.learn.recorder
        #del model
        #gc.collect()
        #torch.cuda.empty_cache()

    def fit_ensemble(self, n_iter, skip=False, **kwargs):
        for i in range(1, self.n + 1):
            if skip and (i in self.models): continue
            self.fit(i, n_iter, **kwargs)

    def set_n(self, n):
        for i in range(n, len(self.models)):
            self.models.pop(i + 1, None)
        self.n = n

    def predict(self, files, model_no, path=None, **kwargs):
        model_path = self.models[model_no]
        model = self.load_model(model_path)
        ds_kwargs = self.ds_kwargs
        # Adding extra padding (overlap) for models that have the same input and output shape
        if ds_kwargs['padding'][0] == 0:
            ds_kwargs['padding'] = (self.extra_padding, ) * 2
        ds = TileDataset(files, **ds_kwargs)
        dls = DataLoaders.from_dsets(ds,
                                     batch_size=self.bs,
                                     after_batch=self.get_batch_tfms(),
                                     shuffle=False,
                                     drop_last=False,
                                     **self.dl_kwargs)
        if torch.cuda.is_available(): dls.cuda()
        learn = Learner(dls, model, loss_func=self.loss_fn)
        if self.mpt: learn.to_fp16()
        if path: path = path / f'model_{model_no}'
        return learn.predict_tiles(dl=dls.train, path=path, **kwargs)

    def get_valid_results(self,
                          model_no=None,
                          export_dir=None,
                          filetype='.png',
                          **kwargs):
        res_list = []
        model_list = self.models if not model_no else [model_no]
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            if self.tta:
                unc_path = export_dir / 'uncertainties'
                unc_path.mkdir(parents=True, exist_ok=True)
        for i in model_list:
            _, files_val = self.splits[i]
            g_smx, g_seg, g_std, g_eng = self.predict(files_val, i, **kwargs)
            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(files_val):
                msk = self.ds.get_data(f, mask=True)[0]
                pred = g_seg[f.name][:]
                m_iou = iou(msk, pred)
                m_path = self.models[i].name
                m_eng_max = energy_max(g_eng[f.name][:], ks=self.energy_ks)
                df_tmp = pd.Series({
                    'file': f.name,
                    'model': m_path,
                    'model_no': i,
                    'img_path': f,
                    'iou': m_iou,
                    'energy_max': m_eng_max.numpy(),
                    'msk_path': self.label_fn(f),
                    'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                    'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'std_path': f'{chunk_store}/{g_std.path}/{f.name}'
                })
                res_list.append(df_tmp)
                if export_dir:
                    save_mask(pred,
                              pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                              filetype)
                    if self.tta:
                        save_unc(
                            g_std[f.name][:],
                            unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                            filetype)
        self.df_val = pd.DataFrame(res_list)
        if export_dir:
            self.df_val.to_csv(export_dir / f'val_results.csv', index=False)
            self.df_val.to_excel(export_dir / f'val_results.xlsx')
        return self.df_val

    def show_valid_results(self, model_no=None, files=None, **kwargs):
        if self.df_val is None: self.get_valid_results(**kwargs)
        df = self.df_val
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        if model_no is not None: df = df[df.model_no == model_no]
        for _, r in df.iterrows():
            img = self.ds.get_data(r.img_path)[0][:]
            msk = self.ds.get_data(r.img_path, mask=True)[0]
            pred = zarr.load(r.pred_path)
            std = zarr.load(r.std_path)
            _d_model = f'Model {r.model_no}'
            if self.tta:
                plot_results(img, msk, pred, std, df=r, model=_d_model)
            else:
                plot_results(img,
                             msk,
                             pred,
                             np.zeros_like(pred),
                             df=r,
                             model=_d_model)

    def load_ensemble(self, path=None):
        path = path or self.ensemble_dir
        models = get_files(path, extensions='.pth', recurse=False)
        assert len(models) > 0, f'No models found in {path}'
        self.models = {}
        for m in models:
            model_id = int(m.stem[-1])
            self.models[model_id] = m
        print(f'Found {len(self.models)} models in folder {path}')
        print(self.models)

    def ensemble_results(self,
                         files,
                         path=None,
                         export_dir=None,
                         filetype='.png',
                         use_tta=None,
                         **kwargs):
        use_tta = use_tta or self.pred_tta
        if export_dir:
            export_dir = Path(export_dir)
            pred_path = export_dir / 'masks'
            pred_path.mkdir(parents=True, exist_ok=True)
            if use_tta:
                unc_path = export_dir / 'uncertainties'
                unc_path.mkdir(parents=True, exist_ok=True)

        store = str(path / 'ensemble') if path else zarr.storage.TempStore()
        root = zarr.group(store=store, overwrite=True)
        chunk_store = root.chunk_store.path
        g_smx, g_seg, g_std, g_eng = root.create_groups(
            'ens_smx', 'ens_seg', 'ens_std', 'ens_energy')
        res_list = []
        for f in files:
            df_fil = self.df_models[self.df_models.file == f.name]
            assert len(df_fil) == len(self.models), \
                "Predictions and models do not match."
            m_smx, m_std, m_eng = tta.Merger(), tta.Merger(), tta.Merger()
            for idx, r in df_fil.iterrows():
                m_smx.append(zarr.load(r.smx_path))
                m_std.append(zarr.load(r.std_path))
                m_eng.append(zarr.load(r.eng_path))
            smx = m_smx.result().numpy()
            g_smx[f.name] = smx
            g_seg[f.name] = np.argmax(smx, axis=-1)
            g_std[f.name] = m_std.result().numpy()
            eng = m_eng.result()
            g_eng[f.name] = eng.numpy()
            m_eng_max = energy_max(eng, ks=self.energy_ks).numpy()
            df_tmp = pd.Series({
                'file': f.name,
                'model': f'{self.arch}_ensemble',
                'energy_max': m_eng_max,
                'img_path': f,
                'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                'std_path': f'{chunk_store}/{g_std.path}/{f.name}',
                'eng_path': f'{chunk_store}/{g_eng.path}/{f.name}'
            })
            res_list.append(df_tmp)
            if export_dir:
                save_mask(g_seg[f.name][:],
                          pred_path / f'{df_tmp.file}_{df_tmp.model}_mask',
                          filetype)
                if use_tta:
                    save_unc(g_std[f.name][:],
                             unc_path / f'{df_tmp.file}_{df_tmp.model}_unc',
                             filetype)
        return pd.DataFrame(res_list)

    def get_ensemble_results(self,
                             new_files,
                             export_dir=None,
                             filetype='.png',
                             **kwargs):
        res_list = []
        for i in self.models:
            g_smx, g_seg, g_std, g_eng = self.predict(new_files, i, **kwargs)
            chunk_store = g_smx.chunk_store.path
            for j, f in enumerate(new_files):
                m_path = self.models[i].name
                df_tmp = pd.Series({
                    'file': f.name,
                    'model_no': i,
                    'model': m_path,
                    'img_path': f,
                    'pred_path': f'{chunk_store}/{g_seg.path}/{f.name}',
                    'smx_path': f'{chunk_store}/{g_smx.path}/{f.name}',
                    'std_path': f'{chunk_store}/{g_std.path}/{f.name}',
                    'eng_path': f'{chunk_store}/{g_eng.path}/{f.name}'
                })
                res_list.append(df_tmp)
        self.df_models = pd.DataFrame(res_list)
        self.df_ens = self.ensemble_results(new_files,
                                            export_dir=export_dir,
                                            filetype=filetype,
                                            **kwargs)
        return self.df_ens

    def score_ensemble_results(self, mask_dir=None, label_fn=None):
        if not label_fn:
            label_fn = get_label_fn(self.df_ens.img_path[0],
                                    self.path / mask_dir)
        for idx, r in self.df_ens.iterrows():
            msk_path = label_fn(r.img_path)
            msk = _read_msk(msk_path)
            self.df_ens.loc[idx, 'msk_path'] = msk_path
            pred = zarr.load(r.pred_path)
            self.df_ens.loc[idx, 'iou'] = iou(msk, pred)
        return self.df_ens

    def show_ensemble_results(self,
                              files=None,
                              model_no=None,
                              unc=True,
                              unc_metric=None):
        assert self.df_ens is not None, "Please run `get_ensemble_results` first."
        if model_no is None: df = self.df_ens
        else: df = self.df_models[self.df_models.model_no == model_no]
        if files is not None: df = df.set_index('file', drop=False).loc[files]
        for _, r in df.iterrows():
            imgs = []
            imgs.append(_read_img(r.img_path)[:])
            if 'iou' in r.index:
                imgs.append(_read_msk(r.msk_path))
                hastarget = True
            else:
                hastarget = False
            imgs.append(zarr.load(r.pred_path))
            if unc: imgs.append(zarr.load(r.std_path))
            plot_results(*imgs,
                         df=r,
                         hastarget=hastarget,
                         unc_metric=unc_metric)

    def lr_find(self, files=None, **kwargs):
        files = files or self.files
        dls = self.get_dls(files)
        pre = None if self.pretrained == 'new' else self.pretrained
        model = self.get_model(pretrained=pre)
        learn = Learner(dls,
                        model,
                        metrics=self.metrics,
                        wd=self.wd,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim])
        if self.mpt: learn.to_fp16()
        sug_lrs = learn.lr_find(**kwargs)
        return sug_lrs, learn.recorder

    def show_mask_weights(self, files, figsize=(12, 12), **kwargs):
        masks = [self.label_fn(Path(f)) for f in files]
        for m in masks:
            print(self.mw_kwargs)
            print(f'Calculating weights. Please wait...')
            msk = _read_msk(m)
            _, w, _ = calculate_weights(msk, n_dims=self.c, **self.mw_kwargs)
            fig, axes = plt.subplots(nrows=1,
                                     ncols=2,
                                     figsize=figsize,
                                     **kwargs)
            axes[0].imshow(msk)
            axes[0].set_axis_off()
            axes[0].set_title(f'Mask {m.name}')
            axes[1].imshow(w)
            axes[1].set_axis_off()
            axes[1].set_title('Weights')
            plt.show()

    def ood_train(self, features=['energy_max'], **kwargs):
        self.ood = Pipeline([('scaler', StandardScaler()),
                             ('svm', svm.OneClassSVM(**kwargs))])
        self.ood.fit(self.df_ens[features])

    def ood_score(self, features=['energy_max']):
        self.df_ens['ood_score'] = self.ood.score_samples(
            self.df_ens[features])

    def ood_save(self, path):
        path = Path(path).with_suffix('.pkl')
        joblib.dump(self.ood, path)
        print(f'Saved OOD model to {path}')

    def ood_load(self, path):
        path = Path(path)
        try:
            self.ood = joblib.load(path)
            print(f'Successfully loaded OOD model from {path}')
        except Exception:
            print('Error! Select a valid joblib file (.pkl)')

    def clear_tmp(self):
        try:
            shutil.rmtree('/tmp/*', ignore_errors=True)
            shutil.rmtree(self.path / '.tmp')
            print(f'Deleted temporary files from {self.path/".tmp"}')
        except:
            print(f'No temporary files to delete at {self.path/".tmp"}')
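
For context, here is a short usage sketch of EnsembleLearner based only on the methods shown above; the folder names, export directory, and iteration count are placeholder assumptions.

# Illustrative usage of EnsembleLearner; paths and hyperparameters are placeholders.
el = EnsembleLearner(image_dir='images', mask_dir='masks')

# Optionally inspect learning rate suggestions before training
# sug_lrs, recorder = el.lr_find()

# Train one model per cross-validation split
el.fit_ensemble(n_iter=2500)

# Per-split validation metrics (IoU, energy scores, prediction paths)
df_val = el.get_valid_results(export_dir='results')

# Ensemble predictions on unseen images
new_files = get_image_files('new_images')
df_ens = el.get_ensemble_results(new_files, export_dir='results')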
Example #4
    def fit(self,
            i,
            n_iter=None,
            lr_max=None,
            bs=None,
            n_jobs=-1,
            verbose=1,
            **kwargs):
        n_iter = n_iter or self.n_iter
        lr_max = lr_max or self.lr
        bs = bs or self.bs
        self.stats = self.stats or self.ds.compute_stats()
        name = self.ensemble_dir / f'{self.arch}_model-{i}.pth'
        files_train, files_val = self.splits[i]
        train_ds = RandomTileDataset(files_train,
                                     label_fn=self.label_fn,
                                     n_jobs=n_jobs,
                                     verbose=verbose,
                                     **self.mw_kwargs,
                                     **self.ds_kwargs)
        valid_ds = TileDataset(files_val,
                               label_fn=self.label_fn,
                               n_jobs=n_jobs,
                               verbose=verbose,
                               **self.mw_kwargs,
                               **self.ds_kwargs)
        batch_tfms = Normalize.from_stats(*self.stats)
        dls = DataLoaders.from_dsets(train_ds,
                                     valid_ds,
                                     bs=bs,
                                     after_item=self.item_tfms,
                                     after_batch=batch_tfms)
        pre = None if self.pretrained == 'new' else self.pretrained
        model = torch.hub.load(self.repo,
                               self.arch,
                               pretrained=pre,
                               n_classes=dls.c,
                               in_channels=self.in_channels,
                               **kwargs)
        if torch.cuda.is_available():
            dls.cuda()
            model.cuda()
        learn = Learner(dls,
                        model,
                        metrics=self.metrics,
                        wd=self.wd,
                        loss_func=self.loss_fn,
                        opt_func=_optim_dict[self.optim],
                        cbs=self.cbs)
        learn.model_dir = self.ensemble_dir.parent / '.tmp'
        if self.mpt: learn.to_fp16()
        print(f'Starting training for {name.name}')
        epochs = calc_iterations(n_iter=n_iter, ds_length=len(train_ds), bs=bs)
        learn.fit_one_cycle(epochs, lr_max)

        print(f'Saving model at {name}')
        name.parent.mkdir(exist_ok=True, parents=True)
        self.save_model(name, learn.model)
        self.models[i] = name
        self.recorder[i] = learn.recorder
        del model
        gc.collect()
        torch.cuda.empty_cache()
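
Both fit() variants convert the requested number of training iterations into whole epochs through calc_iterations(), which is not part of the snippet. One plausible implementation, assuming the intent is to schedule at least n_iter optimizer steps for the given dataset length and batch size:

import math


def calc_iterations(n_iter, ds_length, bs):
    # Hypothetical helper: number of epochs needed so that roughly n_iter
    # optimizer steps are performed, given ds_length samples per epoch
    # processed in batches of size bs.
    steps_per_epoch = max(1, math.ceil(ds_length / bs))
    return max(1, math.ceil(n_iter / steps_per_epoch))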