Example #1
    def _evaluate_classification(self, model, epoch, data_loader, model_name,
                                 dataset_name, ranks, lr_finder):
        labelmap = []

        if data_loader.dataset.classes and get_model_attr(model, 'classification_classes') and \
                len(data_loader.dataset.classes) < len(get_model_attr(model, 'classification_classes')):

            for class_name in sorted(data_loader.dataset.classes.keys()):
                labelmap.append(data_loader.dataset.classes[class_name])

        cmc, mAP, norm_cm = metrics.evaluate_classification(
            data_loader, model, self.use_gpu, ranks, labelmap)

        if self.writer is not None and not lr_finder:
            self.writer.add_scalar(
                'Val/{}/{}/mAP'.format(dataset_name, model_name), mAP,
                epoch + 1)
            for i, r in enumerate(ranks):
                self.writer.add_scalar(
                    'Val/{}/{}/Rank-{}'.format(dataset_name, model_name, r),
                    cmc[i], epoch + 1)

        if not lr_finder:
            print('** Results ({}) **'.format(model_name))
            print('mAP: {:.2%}'.format(mAP))
            for i, r in enumerate(ranks):
                print('Rank-{:<3}: {:.2%}'.format(r, cmc[i]))
            if norm_cm.shape[0] <= 20:
                metrics.show_confusion_matrix(norm_cm)

        return cmc[0]
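The labelmap built above remaps the loader's local class indices onto the indices the model was trained with, so a validation set that annotates only a subset of the model's classes can still be scored. A minimal sketch of that remapping, with hypothetical class names and using the same substitution that score_extraction applies (see Example #6):

import torch

# hypothetical subset: the dataset annotates only 3 of the model's classes
dataset_classes = {'cat': 0, 'dog': 1, 'zebra': 4}    # class name -> index used by the model
labelmap = [dataset_classes[name] for name in sorted(dataset_classes)]

batch_labels = torch.tensor([0, 2, 1, 2])              # local indices assigned by the loader
for i, label in enumerate(labelmap):
    batch_labels[torch.where(batch_labels == i)] = label
print(batch_labels.tolist())                           # [0, 4, 1, 4]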
Example #2
    def fast_ai(self):
        criterion = self.engine.main_losses[0]

        if self.epochs_warmup != 0:
            get_model_attr(self.model, 'to')(self.model_device)
            print("Warmup the model's weights for {} epochs".format(self.epochs_warmup))
            self.engine.run(max_epoch=self.epochs_warmup, lr_finder=self.engine_cfg,
                            stop_callback=self.stop_callback, eval_freq=1)
            print("Finished warming up the model. Continuing to find the learning rate:")

        # run lr finder
        num_iter = len(self.engine.train_loader)
        lr_finder = WrappedLRFinder(self.model, self.optimizer, criterion, device=self.model_device)
        lr_finder.range_test(self.engine.train_loader, start_lr=self.min_lr, end_lr=self.max_lr,
                             smooth_f=self.smooth_f, num_iter=num_iter, step_mode='exp')
        ax, optim_lr = lr_finder.plot(suggest_lr=True)
        # save plot if needed
        if self.path_to_savefig:
            fig = ax.get_figure()
            fig.savefig(self.path_to_savefig)

        # reset weights and optimizer state
        if self.epochs_warmup != 0:
            self.engine.restore_model()
        else:
            lr_finder.reset()

        return optim_lr
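WrappedLRFinder is presumably a thin wrapper around the torch-lr-finder package; a bare-bones version of the same range-test flow on a toy model (an illustrative sketch, not the project's wrapper) would look like this:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch_lr_finder import LRFinder  # assumes the torch-lr-finder package is installed

# toy model and data, just to exercise the range test
model = nn.Sequential(nn.Flatten(), nn.Linear(8, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,))),
                          batch_size=16)

lr_finder = LRFinder(model, optimizer, criterion, device='cpu')
lr_finder.range_test(train_loader, start_lr=1e-6, end_lr=1.0,
                     smooth_f=0.01, num_iter=50, step_mode='exp')
ax, suggested_lr = lr_finder.plot(suggest_lr=True)  # returns (axes, lr) when suggest_lr=True
lr_finder.reset()                                   # roll model/optimizer back to their initial state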
Example #3
    def restore_model(self):
        print("restoring model and seeds to initial state...")
        model_device = next(
            self.models[self.main_model_name].parameters()).device
        get_model_attr(self.models[self.main_model_name],
                       'load_state_dict')(self.state_cacher.retrieve("model"))
        self.optims[self.main_model_name].load_state_dict(
            self.state_cacher.retrieve("optimizer"))
        get_model_attr(self.models[self.main_model_name], 'to')(model_device)
        set_random_seed(self.seed)
Example #4
    def backup_model(self):
        print("backing up the model...")
        model_device = next(
            self.models[self.main_model_name].parameters()).device
        # explicitly put the model on the CPU before storing it in memory
        self.state_cacher.store(key="model",
                                state_dict=get_model_attr(
                                    self.models[self.main_model_name],
                                    'cpu')().state_dict())
        self.state_cacher.store(
            key="optimizer",
            state_dict=self.optims[self.main_model_name].state_dict())
        # restore the model device
        get_model_attr(self.models[self.main_model_name], 'to')(model_device)
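Both methods lean on a StateCacher that keeps CPU copies of the state dicts between backup and restore. A minimal in-memory stand-in that satisfies the same store/retrieve contract (a hypothetical simplification, not the project's implementation):

from copy import deepcopy

class InMemoryStateCacher:
    """Keeps deep copies of state dicts keyed by name; tensors are expected to be on the CPU."""

    def __init__(self):
        self._cache = {}

    def store(self, key, state_dict):
        self._cache[key] = deepcopy(state_dict)

    def retrieve(self, key):
        if key not in self._cache:
            raise KeyError(f'Target {key} was not cached.')
        return self._cache[key]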
Example #5
def check_classification_classes(model, datamanager, classes, test_only=False):
    def check_classes_consistency(ref_classes, probe_classes, strict=False):
        if strict:
            if len(ref_classes) != len(probe_classes):
                return False
            return sorted(probe_classes.keys()) == sorted(ref_classes.keys())

        if len(ref_classes) > len(probe_classes):
            return False
        probe_names = probe_classes.keys()
        for cl in ref_classes.keys():
            if cl not in probe_names:
                return False

        return True

    classes_map = {v: k for k, v in enumerate(sorted(classes))} if classes else {}
    if test_only:
        for name, dataloader in datamanager.test_loader.items():
            if not dataloader['query'].dataset.classes:
                # the current text annotation doesn't contain class names
                print(
                    f'Warning: classes are not defined for validation dataset {name}'
                )
                continue
            if not get_model_attr(model, 'classification_classes'):
                print(
                    'Warning: classes are not provided in the current snapshot. Consistency checks are skipped.'
                )
                continue
            if not check_classes_consistency(
                    get_model_attr(model, 'classification_classes'),
                    dataloader['query'].dataset.classes,
                    strict=False):
                raise ValueError('Inconsistent classes in evaluation dataset')
            if classes and not check_classes_consistency(
                    classes_map,
                    get_model_attr(model, 'classification_classes'),
                    strict=True):
                raise ValueError(
                    'Classes provided via --classes should be the same as in the loaded model'
                )
    elif classes:
        if not check_classes_consistency(
                classes_map, datamanager.train_loader.dataset.classes,
                strict=True):
            raise ValueError('Inconsistent classes in training dataset')
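Strict mode requires the two class sets to match exactly, while the relaxed mode only requires every reference class to be present among the probe classes. A quick illustration with toy dicts (hypothetical class names):

ref_classes = {'cat': 0, 'dog': 1}                 # e.g. classes stored with the snapshot
probe_classes = {'cat': 0, 'dog': 1, 'zebra': 2}   # e.g. classes of the evaluation dataset

# relaxed (strict=False): every reference class must appear in the probe set
relaxed_ok = all(cl in probe_classes for cl in ref_classes)   # True
# strict: the sorted key lists must be identical
strict_ok = sorted(ref_classes) == sorted(probe_classes)      # False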
Example #6
def score_extraction(data_loader,
                     model,
                     use_gpu,
                     labelmap=[],
                     head_id=0,
                     perf_monitor=None,
                     feature_dump_mode='none'):

    assert feature_dump_mode in __FEATURE_DUMP_MODES
    return_featuremaps = feature_dump_mode != __FEATURE_DUMP_MODES[0]

    with torch.no_grad():
        out_scores, gt_labels, all_feature_maps, all_feature_vecs = [], [], [], []
        for batch_idx, data in enumerate(data_loader):
            batch_images, batch_labels = data[0], data[1]
            if perf_monitor: perf_monitor.on_test_batch_begin(batch_idx, None)
            if use_gpu:
                batch_images = batch_images.cuda()

            if labelmap:
                for i, label in enumerate(labelmap):
                    batch_labels[torch.where(batch_labels == i)] = label

            if perf_monitor: perf_monitor.on_test_batch_end(batch_idx, None)

            if return_featuremaps:
                logits, features, global_features = model.forward(
                    batch_images, return_all=return_featuremaps)[head_id]
                if feature_dump_mode == __FEATURE_DUMP_MODES[1]:
                    all_feature_maps.append(features)
                all_feature_vecs.append(global_features)
            else:
                logits = model.forward(batch_images)[head_id]
            out_scores.append(logits * get_model_attr(model, 'scale'))
            gt_labels.append(batch_labels)

        out_scores = torch.cat(out_scores, 0).data.cpu().numpy()
        gt_labels = torch.cat(gt_labels, 0).data.cpu().numpy()

        if all_feature_vecs:
            all_feature_vecs = torch.cat(all_feature_vecs,
                                         0).data.cpu().numpy()
            all_feature_vecs = all_feature_vecs.reshape(
                all_feature_vecs.shape[0], -1)
            if feature_dump_mode == __FEATURE_DUMP_MODES[2]:
                return (out_scores, all_feature_vecs), gt_labels

        if all_feature_maps:
            all_feature_maps = torch.cat(all_feature_maps,
                                         0).data.cpu().numpy()
            return (out_scores, all_feature_maps, all_feature_vecs), gt_labels

    return out_scores, gt_labels
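The behaviour above is driven by the module-level __FEATURE_DUMP_MODES constant, which is not shown in this snippet; judging by how it is indexed, it is presumably something like the tuple below: 'none' returns only the scores, 'all' additionally returns the spatial feature maps, and 'vecs' returns the pooled feature vectors alongside the scores.

# Assumed definition (not part of the snippet above); only positions [0], [1], [2] matter here.
__FEATURE_DUMP_MODES = ('none', 'all', 'vecs')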
Example #7
def evaluate_classification(dataloader,
                            model,
                            use_gpu,
                            topk=(1, ),
                            labelmap=[]):
    if get_model_attr(model, 'is_ie_model'):
        scores, labels = score_extraction_from_ir(dataloader, model, labelmap)
    else:
        scores, labels = score_extraction(dataloader, model, use_gpu, labelmap)

    m_ap = mean_average_precision(scores, labels)

    cmc = []
    for k in topk:
        cmc.append(mean_top_k_accuracy(scores, labels, k=k))

    norm_cm = norm_confusion_matrix(scores, labels)

    return cmc, m_ap, norm_cm
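A hypothetical call site, assuming a prepared validation loader and model, would unpack the three results like this:

cmc, m_ap, norm_cm = evaluate_classification(val_loader, model, use_gpu=True, topk=(1, 5))
print('top-1: {:.2%}  top-5: {:.2%}  mAP: {:.2%}'.format(cmc[0], cmc[1], m_ap))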
Example #8
    def model_eval_fn(model):
        """
        Runs evaluation of the model on the validation set and
        returns the target metric value.
        Used to evaluate the original model before compression
        if NNCF-based accuracy-aware training is used.
        """
        from torchreid.metrics.classification import (
            evaluate_classification, evaluate_multilabel_classification)

        if test_loader is None:
            raise RuntimeError(
                'Cannot perform a model evaluation on the validation '
                'dataset since the validation data loader was not passed '
                'to wrap_nncf_model')

        model_type = get_model_attr(model, 'type')
        targets = list(test_loader.keys())
        use_gpu = cur_device.type == 'cuda'
        for dataset_name in targets:
            domain = 'source' if dataset_name in datamanager_for_init.sources else 'target'
            print(f'##### Evaluating {dataset_name} ({domain}) #####')
            if model_type == 'classification':
                cmc, _, _ = evaluate_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = cmc[0]
            elif model_type == 'multilabel':
                mAP, _, _, _, _, _, _ = evaluate_multilabel_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = mAP
            else:
                raise ValueError(
                    'Cannot perform a model evaluation on the validation dataset '
                    f'since the model has unsupported model_type {model_type or "None"}'
                )

        return accuracy
Example #9
    def test(self,
             epoch,
             dist_metric='euclidean',
             normalize_feature=False,
             visrank=False,
             visrank_topk=10,
             save_dir='',
             use_metric_cuhk03=False,
             ranks=(1, 5, 10, 20),
             rerank=False,
             lr_finder=False,
             test_only=False):
        r"""Tests model on target datasets.

        .. note::

            This function has been called in ``run()``.

        .. note::

            The test pipeline implemented in this function suits both image- and
            video-reid. In general, a subclass of Engine only needs to re-implement
            ``extract_features()`` and ``parse_data_for_eval()`` (most of the time),
            but not a must. Please refer to the source code for more details.
        """

        self.set_model_mode('eval')
        targets = list(self.test_loader.keys())
        top1, cur_top1, ema_top1 = [-1] * 3
        should_save_ema_model = False

        for dataset_name in targets:
            domain = 'source' if dataset_name in self.datamanager.sources else 'target'
            print('##### Evaluating {} ({}) #####'.format(
                dataset_name, domain))
            # TODO: reduce the amount of duplicated code in the evaluation functions (DRY)
            for model_id, (model_name,
                           model) in enumerate(self.models.items()):
                ema_condition = (self.use_ema_decay and not lr_finder
                                 and not test_only
                                 and model_name == self.main_model_name)
                model_type = get_model_attr(model, 'type')
                if model_type == 'classification':
                    # do not evaluate second model till last epoch
                    if (model_name != self.main_model_name and not test_only
                            and epoch != (self.max_epoch - 1)):
                        continue
                    cur_top1 = self._evaluate_classification(
                        model=model,
                        epoch=epoch,
                        data_loader=self.test_loader[dataset_name]['query'],
                        model_name=model_name,
                        dataset_name=dataset_name,
                        ranks=ranks,
                        lr_finder=lr_finder)
                    if ema_condition:
                        ema_top1 = self._evaluate_classification(
                            model=self.ema_model.module,
                            epoch=epoch,
                            data_loader=self.test_loader[dataset_name]
                            ['query'],
                            model_name='EMA model',
                            dataset_name=dataset_name,
                            ranks=ranks,
                            lr_finder=lr_finder)
                elif model_type == 'contrastive':
                    pass
                elif model_type == 'multilabel':
                    # do not evaluate second model till last epoch
                    if (model_name != self.main_model_name and not test_only
                            and epoch != (self.max_epoch - 1)):
                        continue
                    # we compute mAP, but consider it top1 for consistency
                    # with single label classification
                    cur_top1 = self._evaluate_multilabel_classification(
                        model=model,
                        epoch=epoch,
                        data_loader=self.test_loader[dataset_name]['query'],
                        model_name=model_name,
                        dataset_name=dataset_name,
                        lr_finder=lr_finder)
                    if ema_condition:
                        ema_top1 = self._evaluate_multilabel_classification(
                            model=self.ema_model.module,
                            epoch=epoch,
                            data_loader=self.test_loader[dataset_name]
                            ['query'],
                            model_name='EMA model',
                            dataset_name=dataset_name,
                            lr_finder=lr_finder)
                elif dataset_name == 'lfw':
                    self._evaluate_pairwise(
                        model=model,
                        epoch=epoch,
                        data_loader=self.test_loader[dataset_name]['pairs'],
                        model_name=model_name)
                else:
                    cur_top1 = self._evaluate_reid(
                        model=model,
                        epoch=epoch,
                        model_name=model_name,
                        dataset_name=dataset_name,
                        query_loader=self.test_loader[dataset_name]['query'],
                        gallery_loader=self.test_loader[dataset_name]
                        ['gallery'],
                        dist_metric=dist_metric,
                        normalize_feature=normalize_feature,
                        visrank=visrank,
                        visrank_topk=visrank_topk,
                        save_dir=save_dir,
                        use_metric_cuhk03=use_metric_cuhk03,
                        ranks=ranks,
                        rerank=rerank,
                        lr_finder=lr_finder)

                if model_id == 0:
                    # the function should return accuracy results for the first (main) model only
                    if self.use_ema_decay and ema_top1 >= cur_top1:
                        should_save_ema_model = True
                        top1 = ema_top1
                    else:
                        top1 = cur_top1

        return top1, should_save_ema_model
Example #10
    def __init__(self,
                 datamanager,
                 models,
                 optimizers,
                 schedulers,
                 use_gpu=True,
                 save_all_chkpts=True,
                 train_patience=10,
                 lr_decay_factor=1000,
                 lr_finder=None,
                 early_stopping=False,
                 should_freeze_aux_models=False,
                 nncf_metainfo=None,
                 compression_ctrl=None,
                 initial_lr=None,
                 target_metric='train_loss',
                 epoch_interval_for_aux_model_freeze=None,
                 epoch_interval_for_turn_off_mutual_learning=None,
                 use_ema_decay=False,
                 ema_decay=0.999,
                 seed=5):

        self.datamanager = datamanager
        self.train_loader = self.datamanager.train_loader
        self.test_loader = self.datamanager.test_loader
        self.use_gpu = (torch.cuda.is_available() and use_gpu)
        self.save_all_chkpts = save_all_chkpts
        self.writer = None
        self.use_ema_decay = use_ema_decay
        self.start_epoch = 0
        self.lr_finder = lr_finder
        self.fixbase_epoch = 0
        self.iter_to_wait = 0
        self.best_metric = 0.0
        self.max_epoch = None
        self.num_batches = None
        assert target_metric in ['train_loss', 'test_acc']
        self.target_metric = target_metric
        self.epoch = None
        self.train_patience = train_patience
        self.early_stopping = early_stopping
        self.state_cacher = StateCacher(in_memory=True, cache_dir=None)
        self.param_history = set()
        self.seed = seed
        self.models = OrderedDict()
        self.optims = OrderedDict()
        self.scheds = OrderedDict()
        self.ema_model = None
        if should_freeze_aux_models:
            print(
                f'Engine: should_freeze_aux_models={should_freeze_aux_models}')
        self.should_freeze_aux_models = should_freeze_aux_models
        self.nncf_metainfo = deepcopy(nncf_metainfo)
        self.compression_ctrl = compression_ctrl
        self.initial_lr = initial_lr
        self.epoch_interval_for_aux_model_freeze = epoch_interval_for_aux_model_freeze
        self.epoch_interval_for_turn_off_mutual_learning = epoch_interval_for_turn_off_mutual_learning
        self.model_names_to_freeze = []
        self.current_lr = None

        if isinstance(models, (tuple, list)):
            assert isinstance(optimizers, (tuple, list))
            assert isinstance(schedulers, (tuple, list))

            num_models = len(models)
            assert len(optimizers) == num_models
            assert len(schedulers) == num_models

            for model_id, (model, optimizer, scheduler) in enumerate(
                    zip(models, optimizers, schedulers)):
                model_name = 'main_model' if model_id == 0 else f'aux_model_{model_id}'
                self.register_model(model_name, model, optimizer, scheduler)
                if use_ema_decay and model_id == 0:
                    self.ema_model = ModelEmaV2(model, decay=ema_decay)
                if should_freeze_aux_models and model_id > 0:
                    self.model_names_to_freeze.append(model_name)
        else:
            assert not isinstance(optimizers, (tuple, list))
            assert not isinstance(schedulers, (tuple, list))
            assert not isinstance(models, (tuple, list))
            self.register_model('main_model', models, optimizers, schedulers)
            if use_ema_decay:
                self.ema_model = ModelEmaV2(models, decay=ema_decay)
        self.main_model_name = self.get_model_names()[0]
        self.scales = dict()
        for model_name, model in self.models.items():
            scale = get_model_attr(model, 'scale')
            if not get_model_attr(model,
                                  'use_angle_simple_linear') and scale != 1.:
                print(
                    f"WARNING: Angle Linear is not used but the loss scale parameter {scale} != 1."
                )
            self.scales[model_name] = scale
        self.am_scale = self.scales[
            self.main_model_name]  # for loss initialization
        assert initial_lr is not None
        self.lb_lr = initial_lr / lr_decay_factor
        self.per_batch_annealing = isinstance(
            self.scheds[self.main_model_name],
            (CosineAnnealingCycleRestart, OneCycleLR))
Example #11
def evaluate_multilabel_classification(dataloader, model, use_gpu):
    def average_precision(output, target):
        epsilon = 1e-8

        # sort examples
        indices = output.argsort()[::-1]
        # Computes prec@i
        total_count_ = np.cumsum(np.ones((len(output), 1)))

        target_ = target[indices]
        ind = target_ == 1
        pos_count_ = np.cumsum(ind)
        total = pos_count_[-1]
        pos_count_[np.logical_not(ind)] = 0
        pp = pos_count_ / total_count_
        precision_at_i_ = np.sum(pp)
        precision_at_i = precision_at_i_ / (total + epsilon)

        return precision_at_i

    def mAP(targs, preds, pos_thr=0.5):
        """Computes the mean average precision together with per-class and
        overall precision/recall/F1 scores at the given positive threshold.
        Returns:
            (mAP, mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o)
        """
        if np.size(preds) == 0:
            return 0
        ap = np.zeros((preds.shape[1]))
        # compute average precision for each class
        for k in range(preds.shape[1]):
            scores = preds[:, k]
            targets = targs[:, k]
            ap[k] = average_precision(scores, targets)
        tp, fp, fn, tn = [], [], [], []
        for k in range(preds.shape[0]):
            scores = preds[k, :]
            targets = targs[k, :]
            pred = (scores > pos_thr).astype(np.int32)
            tp.append(((pred + targets) == 2).sum())
            fp.append(((pred - targets) == 1).sum())
            fn.append(((pred - targets) == -1).sum())
            tn.append(((pred + targets) == 0).sum())

        p_c = [
            tp[i] / (tp[i] + fp[i]) if tp[i] > 0 else 0.0
            for i in range(len(tp))
        ]
        r_c = [
            tp[i] / (tp[i] + fn[i]) if tp[i] > 0 else 0.0
            for i in range(len(tp))
        ]
        f_c = [
            2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0
            for i in range(len(tp))
        ]

        mean_p_c = sum(p_c) / len(p_c)
        mean_r_c = sum(r_c) / len(r_c)
        mean_f_c = sum(f_c) / len(f_c)

        p_o = sum(tp) / (np.array(tp) + np.array(fp)).sum()
        r_o = sum(tp) / (np.array(tp) + np.array(fn)).sum()
        f_o = 2 * p_o * r_o / (p_o + r_o)

        return ap.mean(), mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o

    if get_model_attr(model, 'is_ie_model'):
        scores, labels = score_extraction_from_ir(dataloader, model)
    else:
        scores, labels = score_extraction(dataloader, model, use_gpu)

    scores = 1. / (1 + np.exp(-scores))
    mAP_score = mAP(labels, scores)

    return mAP_score
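As a quick sanity check of the average_precision helper above: for scores [0.9, 0.8, 0.3] with targets [1, 0, 1], the positives sit at ranks 1 and 3, so AP = (1/1 + 2/3) / 2 ≈ 0.833. The same number falls out of the cumulative-sum formulation used in the code (assuming numpy):

import numpy as np

output = np.array([0.9, 0.8, 0.3])
target = np.array([1, 0, 1])

indices = output.argsort()[::-1]              # ranks by descending score
target_ = target[indices]
pos_count = np.cumsum(target_ == 1)           # positives seen up to each rank
total_count = np.arange(1, len(output) + 1)   # rank index
precision_at_pos = np.where(target_ == 1, pos_count / total_count, 0.0)
ap = precision_at_pos.sum() / pos_count[-1]
print(round(ap, 3))                           # 0.833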
Example #12
def _build_optim(model,
                 optim='adam',
                 base_optim='sgd',
                 lr=0.0003,
                 weight_decay=5e-04,
                 momentum=0.9,
                 sgd_dampening=0,
                 sgd_nesterov=False,
                 rmsprop_alpha=0.99,
                 adam_beta1=0.9,
                 adam_beta2=0.99,
                 staged_lr=False,
                 new_layers='',
                 base_lr_mult=0.1,
                 nbd=False,
                 lr_finder=False,
                 sam_rho=0.05,
                 sam_adaptive=False):

    param_groups = []
    if optim not in AVAI_OPTIMS:
        raise ValueError('Unsupported optimizer: {}. Must be one of {}'.format(
            optim, AVAI_OPTIMS))

    if isinstance(base_optim, SAM):
        raise ValueError('Invalid base optimizer. SAM cannot be the base one')

    if not isinstance(model, nn.Module):
        raise TypeError(
            'model given to build_optimizer must be an instance of nn.Module')

    if staged_lr:
        if isinstance(new_layers, str):
            if not new_layers:
                warnings.warn(
                    'new_layers is empty, therefore, staged_lr is useless')
            new_layers = [new_layers]

        base_params = []
        base_layers = []
        new_params = []

        for name, module in model.named_children():
            if name in new_layers:
                new_params += [p for p in module.parameters()]
            else:
                base_params += [p for p in module.parameters()]
                base_layers.append(name)

        param_groups = [
            {
                'params': base_params,
                'lr': lr * base_lr_mult
            },
            {
                'params': new_params
            },
        ]

    # we switch off nbd (no bias decay) when the lr_finder is enabled,
    # because the optimizer is built only once and the lr of the bias groups would not be updated
    elif nbd and not lr_finder:
        compression_params = set()
        CompressionParameter = get_compression_parameter()
        if CompressionParameter:
            for param_group in get_model_attr(model, 'get_config_optim')(lr):
                layer_params = param_group['params']
                for name, param in layer_params:
                    if param.requires_grad and isinstance(
                            param, CompressionParameter):
                        compression_params.add(param)

        for param_group in get_model_attr(model, 'get_config_optim')(lr):
            if 'weight_decay' in param_group:
                # weight_decay is already set for these parameters
                param_groups.append(param_group)
                continue

            decay, bias_no_decay, weight_no_decay = [], [], []
            group_lr = param_group['lr']
            layer_params = param_group['params']
            for name, param in layer_params:
                if param in compression_params:
                    continue  # Param is already registered
                elif not param.requires_grad:
                    continue  # frozen weights
                elif name.endswith("bias"):
                    bias_no_decay.append(param)
                elif len(param.shape) == 1:
                    weight_no_decay.append(param)
                elif (name.endswith("weight")
                      and ("norm" in name or "query_embed" in name)):
                    weight_no_decay.append(param)
                else:
                    decay.append(param)

            cur_params = [{
                'params': decay,
                'lr': group_lr,
                'weight_decay': weight_decay
            }, {
                'params': bias_no_decay,
                'lr': 2 * group_lr,
                'weight_decay': 0.0
            }, {
                'params': weight_no_decay,
                'lr': group_lr,
                'weight_decay': 0.0
            }]
            param_groups.extend(cur_params)

        if compression_params:
            param_groups.append({
                'params': list(compression_params),
                'lr': lr,
                'weight_decay': 0.0
            })
    else:
        for param_group in get_model_attr(model, 'get_config_optim')(lr):
            group_weight_decay = param_group.get('weight_decay', weight_decay)
            param_groups.append({
                'params': [param for _, param in param_group['params']],
                'lr': param_group['lr'],
                'weight_decay': group_weight_decay
            })

    if optim == 'adam':
        optimizer = torch.optim.AdamW(
            param_groups,
            betas=(adam_beta1, adam_beta2),
        )

    elif optim == 'amsgrad':
        optimizer = torch.optim.AdamW(
            param_groups,
            betas=(adam_beta1, adam_beta2),
            amsgrad=True,
        )

    elif optim == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups,
            momentum=momentum,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
        )

    elif optim == 'rmsprop':
        optimizer = torch.optim.RMSprop(
            param_groups,
            momentum=momentum,
            alpha=rmsprop_alpha,
        )

    elif optim == 'radam':
        optimizer = RAdam(param_groups, betas=(adam_beta1, adam_beta2))

    if optim == 'sam':
        if not base_optim:
            raise ValueError("SAM cannot operate without base optimizer. "
                             "Please add it to configuration file")
        optimizer = SAM(params=param_groups,
                        base_optimizer=base_optim,
                        rho=sam_rho,
                        adaptive=sam_adaptive)

    return optimizer
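The nbd branch implements the common no-bias-decay trick: biases and 1-D weights (norm layers) get weight_decay=0 (and, in the code above, a doubled learning rate for biases), while everything else keeps the configured decay. A stripped-down sketch of that grouping on a plain nn.Module, without the get_config_optim and compression handling used above:

import torch
from torch import nn

def no_bias_decay_groups(model, lr, weight_decay):
    """Splits parameters into decay / no-decay groups (simplified sketch)."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue                      # frozen weights
        if name.endswith('bias') or param.ndim == 1:
            no_decay.append(param)        # biases and norm weights: no decay
        else:
            decay.append(param)
    return [
        {'params': decay, 'lr': lr, 'weight_decay': weight_decay},
        {'params': no_decay, 'lr': lr, 'weight_decay': 0.0},
    ]

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.Linear(16, 3))
optimizer = torch.optim.SGD(no_bias_decay_groups(model, lr=0.01, weight_decay=5e-4),
                            momentum=0.9)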