def __call__(self, *args, **kw):
     """
     Create a new Element specialized with the given keyword arguments.

     If a ``_magic`` callable was registered, the keyword arguments are passed
     through it first. Only keywords whose keys appear in the options are
     kept; they are layered on top of the base items. Positional arguments
     are accepted but unused.

     Returns:
         Element: a new element sharing this element's options and magic.
     """
     if self._magic:
         kw = self._magic(kw)
     # Conceptually: merged = base | (kw & options)
     allowed = ub.dict_isect(kw, self._options)
     merged = ub.dict_union(self._base, allowed)
     return Element(merged, self._options, self._magic)
Exemple #2
0
 def json_id(self):
     """
     Build an ordered dictionary describing this object.

     The result starts with a ``'__class__'`` entry holding the class name,
     followed by the items of ``self._params`` (numpy arrays are converted
     to lists) and finally the ``json_id()`` of every child.
     """
     child_ids = ub.odict()
     for name, child in self.children():
         child_ids[name] = child.json_id()

     param_items = ub.odict()
     for name, value in self._params.items():
         if isinstance(value, np.ndarray):
             # Arrays are replaced by their nested-list form
             value = value.tolist()
         param_items[name] = value

     header = ub.odict([('__class__', self.__class__.__name__)])
     return ub.dict_union(header, param_items, child_ids)
Exemple #3
0
class DetectPredictCLIConfig(scfg.Config):
    # Defaults for the command line interface: the CLI-only options listed
    # here are merged with ``DetectPredictConfig.default``. On a key
    # collision the value from ``DetectPredictConfig.default`` wins because
    # it is the last argument of the union.
    default = ub.dict_union(
        dict(
            dataset=scfg.Value(None, help='coco dataset, path to images or folder of images'),
            out_dpath=scfg.Value('./out', help='output directory'),
            draw=scfg.Value(False),
            workdir=scfg.Value('~/work/bioharn', help='work directory for sampler if needed'),
        ),
        DetectPredictConfig.default,
    )
Exemple #4
0
    def _normalize_attrs(self):
        """
        Derive the estimator attributes from ``self.clf_key``.

        The key has the form ``<est_type>`` or ``<est_type>-<wrap_type>``.
        The estimator keyword arguments are the union of the two dictionaries
        returned by ``_lookup_params`` and ``self.kw`` (later ones win).
        """
        parts = self.clf_key.split('-')

        self.est_type = parts[0]
        # The wrapper type is optional; a key without a dash has none
        self.wrap_type = parts[1] if len(parts) != 1 else None

        primary_kw, secondary_kw = self._lookup_params(self.est_type)
        self.est_kw1 = primary_kw
        self.est_kw2 = secondary_kw
        self.est_kw = ub.dict_union(primary_kw, secondary_kw, self.kw)
    def __init__(self, base, options={}, _magic=None):
        """
        Args:
            base (dict): the keys / values this schema must contain
            options (dict | set | list | tuple): the keys / values this
                schema may contain. A set, list, or tuple of keys is
                expanded to a dict that maps every key to None.
            _magic (callable): called when creating an instance of this schema
                element. Allows convenience attributes to be converted to the
                formal jsonschema specs. TODO: _magic is a terrible name, we
                need to rename it with something descriptive.
        """
        self._base = base

        # A bare collection of option names carries no default values
        if isinstance(options, (set, list, tuple)):
            options = dict.fromkeys(options)

        # The class-level generics are merged on top of the given options
        self._options = ub.dict_union(options, self.__generics__)
        self._magic = _magic
        super().__init__(base)
Exemple #6
0
    def likely_duplicates(cls, pfiles, thresh=0.2, verbose=1):
        """
        Group progressive files that are likely duplicates of each other.

        Files are grouped repeatedly with ``ProgressiveFile.group_pfiles``.
        A group with more than one member whose key's last entry is still
        below ``thresh`` is kept active and its under-threshold items are
        refined one more step; every other group is final.

        Args:
            pfiles (list): the progressive files to compare
            thresh (float): groups / items whose last key entry is below
                this value are refined further
            verbose (int): verbosity passed to the refinement step

        Returns:
            dict: maps each final group key to its group of files
        """
        mode = 'thread'
        max_workers = 6

        final_groups = {}
        pending = [pfiles]

        while pending:
            num_items = sum(len(g) for g in pending)
            print('Checking {} active groups with {} items'.format(
                len(pending), num_items))
            groups = ub.dict_union(
                *[ProgressiveFile.group_pfiles(g) for g in pending])

            to_refine = []
            unresolved = []
            for key, group in groups.items():
                if len(group) <= 1 or not (key[-1] < thresh):
                    # Nothing more to decide for this group: it is part of
                    # the solution and leaves the active set.
                    final_groups[key] = group
                    continue
                unresolved.append(group)
                to_refine.extend(
                    item for item in group if item.step_id()[-1] < thresh)

            # TODO: if there are few enough items, just refine to the
            # threshold?
            if to_refine:
                ProgressiveFile.parallel_refine(to_refine,
                                                mode=mode,
                                                step_idx='next',
                                                max_workers=max_workers,
                                                verbose=verbose)

            # Keep looping while some group still needs refinement
            pending = unresolved
        return final_groups
Exemple #7
0
    def likely_overlaps(cls, pfiles1, pfiles2, thresh=0.2, verbose=1):
        """
        This is similar to finding duplicates, but between two sets of files

        A group is only refined further while it contains files from both
        sets, has more than one member, and has not yet reached any of the
        given thresholds.

        Args:
            pfiles1 (list): first collection of progressive files
            pfiles2 (list): second collection of progressive files
            thresh (float | dict): a single value used for both the fraction
                and the byte threshold, or a dict with optional ``'frac'``
                and ``'byte'`` entries (a missing entry disables that check)
            verbose (int): verbosity passed to the refinement step

        Returns:
            tuple: ``(overlap, only1, only2)`` dictionaries mapping a group
                key to its group, split by which input sets the group
                members came from.

        Example:
            >>> fpaths = _demodata_files(num_files=100, rng=0)
            >>> fpaths1 = fpaths[0::2]
            >>> fpaths2 = fpaths[1::2]
            >>> pfiles1 = [ProgressiveFile(f) for f in fpaths1]
            >>> pfiles2 = [ProgressiveFile(f) for f in fpaths2]
            >>> overlap, only1, only2 = ProgressiveFile.likely_overlaps(pfiles1, pfiles2)
            >>> print(len(overlap))
            >>> print(len(only1))
            >>> print(len(only2))
        """
        # Remember which input collection each file object came from, so a
        # group is only refined when it mixes files from both collections.
        ids1 = {id(p) for p in pfiles1}
        ids2 = {id(p) for p in pfiles2}

        def _sources(group):
            # The set of collection labels (1 and / or 2) present in a group
            found = set()
            for p in group:
                pid = id(p)
                if pid in ids1:
                    found.add(1)
                if pid in ids2:
                    found.add(2)
            return found

        if isinstance(thresh, dict):
            frac_thresh = thresh.get('frac', None)
            byte_thresh = thresh.get('byte', None)
        else:
            frac_thresh = byte_thresh = thresh

        mode = 'thread'
        max_workers = 6

        final_groups = {}
        pending = [pfiles1 + pfiles2]

        while pending:
            num_items = sum(len(g) for g in pending)
            print('Checking {} active groups with {} items'.format(
                len(pending), num_items))
            groups = ub.dict_union(
                *[ProgressiveFile.group_pfiles(g) for g in pending])

            to_refine = []
            unresolved = []
            for key, group in groups.items():
                sources = _sources(group)

                group_frac = key[3]
                group_byte = key[1]
                # Check if we have hashed enough of the file by fraction or
                # number of bytes. With no threshold at all, every group is
                # considered good enough.
                checks = []
                if frac_thresh is not None:
                    checks.append(group_frac >= frac_thresh)
                if byte_thresh is not None:
                    checks.append(group_byte >= byte_thresh)
                good_enough = any(checks) or len(checks) == 0

                if good_enough or len(sources) <= 1 or len(group) <= 1:
                    # Nothing more to decide for this group: it is part of
                    # the solution and leaves the active set.
                    final_groups[key] = group
                    continue

                unresolved.append(group)
                to_refine.extend(
                    item for item in group
                    if not item.complete_enough(frac_thresh=frac_thresh,
                                                byte_thresh=byte_thresh))

            # TODO: if there are few enough items, just refine to the
            # threshold?
            if to_refine:
                ProgressiveFile.parallel_refine(to_refine,
                                                mode=mode,
                                                step_idx='next',
                                                max_workers=max_workers,
                                                verbose=verbose)

            # Keep looping while some group still needs refinement
            pending = unresolved

        # Split the final groups by the collections their members belong to
        overlap = {}
        only1 = {}
        only2 = {}
        for key, group in final_groups.items():
            sources = _sources(group)
            if len(sources) != 1:
                overlap[key] = group
            elif ub.peek(sources) == 1:
                only1[key] = group
            else:
                only2[key] = group

        return overlap, only1, only2
Exemple #8
0
def run_benchmark_renormalization():
    """
    See if we can renormalize probabilities after update with a faster method
    that maintains memory a bit better

    The benchmark compiles the cython variants, profiles the pure python
    variants once, times every selected method over a grid of ``T`` / ``D``
    sizes, and finally plots and prints the timings.

    Example:
        >>> import sys, ubelt
        >>> sys.path.append(ubelt.expandpath('~/misc/tests/python'))
        >>> from bench_renormalization import *  # NOQA
        >>> run_benchmark_renormalization()
    """
    import ubelt as ub
    import xdev
    import pathlib
    import timerit

    fpath = pathlib.Path('~/misc/tests/python/renormalize_cython.pyx').expanduser()
    renormalize_cython = xdev.import_module_from_pyx(fpath, annotate=True,
                                                     verbose=3, recompile=True)

    xdev.profile_now(renormalize_demo_v1)(1000, 100)
    xdev.profile_now(renormalize_demo_v2)(1000, 100)
    xdev.profile_now(renormalize_demo_v3)(1000, 100)
    xdev.profile_now(renormalize_demo_v4)(1000, 100)

    func_list = [
        # renormalize_demo_v1,
        renormalize_demo_v2,
        # renormalize_demo_v3,
        # renormalize_demo_v4,
        renormalize_cython.renormalize_demo_cython_v1,
        renormalize_cython.renormalize_demo_cython_v2,
        renormalize_cython.renormalize_demo_cython_v3,
    ]
    methods = {f.__name__: f for f in func_list}

    # Quick single-shot timing of every method
    for key, method in methods.items():
        with timerit.Timer(label=key, verbose=0) as t:
            method(1000, 100)
        print(f'{key:<30} {t.toc():0.6f}')

    arg_basis = {
        'T': [10, 20,  30,  50],
        'D': [10, 50, 100, 300],
    }
    args_grid = []
    for argkw in ub.named_product(arg_basis):
        if argkw['T'] <= argkw['D']:
            # Record the problem size on the grid item itself. (Writing it
            # into ``arg_basis`` would clobber the basis and never reach the
            # result rows.)
            argkw['size'] = argkw['T'] * argkw['D']
            args_grid.append(argkw)

    ti = timerit.Timerit(100, bestof=10, verbose=2)

    measures = []

    for method_name, method in methods.items():
        for argkw in args_grid:
            row = ub.dict_union({'method': method_name}, argkw)
            key = ub.repr2(row, compact=1)
            argkey = ub.repr2(argkw, compact=1)

            # Only T and D are real arguments; 'size' is bookkeeping
            kwargs = ub.dict_subset(argkw, ['T', 'D'])
            for timer in ti.reset('time'):
                with timer:
                    method(**kwargs)

            row['mean'] = ti.mean()
            row['min'] = ti.min()
            row['key'] = key
            row['argkey'] = argkey
            measures.append(row)

    import pandas as pd
    df = pd.DataFrame(measures)
    import kwplot
    sns = kwplot.autosns()

    kwplot.figure(fnum=1, pnum=(1, 2, 1), docla=True)
    sns.lineplot(data=df, x='D', y='min', hue='method', style='method')
    kwplot.figure(fnum=1, pnum=(1, 2, 2), docla=True)
    sns.lineplot(data=df, x='T', y='min', hue='method', style='method')

    # DataFrame.pivot only accepts keyword arguments in pandas >= 2.0
    p = df.pivot(index=['method'], columns=['argkey'], values=['mean'])
    print(p.mean(axis=1).sort_values())
Exemple #9
0
def train():
    """
    Train a DenseNet classifier on CIFAR with the fit harness.

    The input colorspace / normalization mode is selected with one of the
    command line flags ``--lab``, ``--rgb`` or ``--rgb-dep``. The optional
    ``--drop_rate`` and ``--weight_decay`` values override the model dropout
    and the optimizer weight decay.

    Raises:
        AssertionError: if none of the colorspace flags is given.

    Example:
        >>> train()
    """
    import random
    # Seed every random number generator used here. The modulo keeps each
    # seed no larger than 2 ** 32 - 2.
    np.random.seed(1031726816 % 4294967295)
    torch.manual_seed(137852547 % 4294967295)
    random.seed(2497950049 % 4294967295)

    xpu = xpu_device.XPU.from_argv()
    print('Chosen xpu = {!r}'.format(xpu))

    cifar_num = 10

    # Choose the colorspace and normalization mode from the command line
    if ub.argflag('--lab'):
        datasets = cifar_training_datasets(output_colorspace='LAB',
                                           norm_mode='independent',
                                           cifar_num=cifar_num)
    elif ub.argflag('--rgb'):
        datasets = cifar_training_datasets(output_colorspace='RGB',
                                           norm_mode='independent',
                                           cifar_num=cifar_num)
    elif ub.argflag('--rgb-dep'):
        datasets = cifar_training_datasets(output_colorspace='RGB',
                                           norm_mode='dependant',
                                           cifar_num=cifar_num)
    else:
        # NOTE(review): the message does not mention the --rgb-dep option
        raise AssertionError('specify --rgb / --lab')

    import netharn.models.densenet

    # batch_size = (128 // 3) * 3
    batch_size = 64

    # initializer_ = (initializers.KaimingNormal, {
    #     'nonlinearity': 'relu',
    # })

    # Base learning rate; the scheduler points below are multiples of it
    lr = 0.1
    initializer_ = (initializers.LSUV, {})

    # Each component is given as a (class, keyword-arguments) pair
    hyper = hyperparams.HyperParams(
        workdir=ub.ensuredir('train_cifar_work'),
        model=(
            netharn.models.densenet.DenseNet,
            {
                'cifar': True,
                'block_config': (32, 32, 32),  # 100 layer depth
                'num_classes': datasets['train'].n_classes,
                'drop_rate': float(ub.argval('--drop_rate', default=.2)),
                'groups': 1,
            }),
        optimizer=(
            torch.optim.SGD,
            {
                # 'weight_decay': .0005,
                'weight_decay':
                float(ub.argval('--weight_decay', default=.0005)),
                'momentum': 0.9,
                'nesterov': True,
                # NOTE(review): hard-coded; equals ``lr`` above today but
                # will not follow it if ``lr`` is changed
                'lr': 0.1,
            }),
        scheduler=(nh.schedulers.ListedLR, {
            # Keys are presumably epoch numbers - TODO confirm
            'points': {
                0: lr,
                150: lr * 0.1,
                250: lr * 0.01,
            },
            'interpolate': False
        }),
        monitor=(nh.Monitor, {
            'minimize': ['loss'],
            'maximize': ['mAP'],
            'patience': 314,
            'max_epoch': 314,
        }),
        initializer=initializer_,
        criterion=(torch.nn.CrossEntropyLoss, {}),
        # Specify anything else that is special about your hyperparams here
        # Especially if you make a custom_batch_runner
        augment=str(datasets['train'].augmenter),
        other=ub.dict_union(
            {
                # TODO: type of augmentation as a parameter dependency
                # 'augmenter': str(datasets['train'].augmenter),
                # 'augment': datasets['train'].augment,
                'batch_size': batch_size,
                'colorspace': datasets['train'].output_colorspace,
                'n_classes': datasets['train'].n_classes,
                # 'center_inputs': datasets['train'].center_inputs,
            },
            datasets['train'].center_inputs.__dict__),
    )
    # if ub.argflag('--rgb-indie'):
    #     hyper.other['norm'] = 'dependant'
    hyper.input_ids['train'] = datasets['train'].input_id

    # NOTE(review): this replaces the xpu chosen from argv at the top of the
    # function - confirm which of the two is intended
    xpu = xpu_device.XPU.cast('auto')
    print('xpu = {}'.format(xpu))

    data_kw = {'batch_size': batch_size}
    if xpu.is_gpu():
        data_kw.update({'num_workers': 8, 'pin_memory': True})

    tags = ['train', 'vali', 'test']

    # One data loader per split; only the training split is shuffled
    loaders = ub.odict()
    for tag in tags:
        dset = datasets[tag]
        shuffle = tag == 'train'
        data_kw_ = data_kw.copy()
        if tag != 'train':
            # Non-training splits use a quarter of the batch size (at least 1)
            data_kw_['batch_size'] = max(batch_size // 4, 1)
        loader = torch.utils.data.DataLoader(dset, shuffle=shuffle, **data_kw_)
        loaders[tag] = loader

    harn = fit_harness.FitHarness(
        hyper=hyper,
        datasets=datasets,
        xpu=xpu,
        loaders=loaders,
    )
    # harn.monitor = early_stop.EarlyStop(patience=40)
    # NOTE(review): presumably takes precedence over the monitor given to
    # HyperParams above (different keys and patience) - confirm
    harn.monitor = monitor.Monitor(min_keys=['loss'],
                                   max_keys=['global_acc', 'class_acc'],
                                   patience=40)

    harn.initialize()
    harn.run()
Exemple #10
0
def rectify_normalizer(in_channels, key=ub.NoParam, dim=2):
    """
    Allows dictionary based specification of a normalizing layer

    Args:
        in_channels (int): number of channels the layer normalizes
        key (str | dict | None): the layer specification. Either one of the
            names ``'batch'``, ``'group'``, ``'batch+group'``, or a dict with
            a ``'type'`` entry holding one of those names plus extra keyword
            arguments for the layer. ``None`` means no normalizer. If
            unspecified, ``'batch'`` is used. For group normalization
            ``'num_groups'`` may be an int or a ``('gcd', n)`` tuple, which
            resolves to ``gcd(n, in_channels)``; when missing it defaults to
            ``('gcd', min(in_channels, 32))``.
        dim (int): dimensionality (1, 2, or 3) of the batch norm layer

    Returns:
        torch.nn.Module | None: the normalization layer, or None when
            ``key`` is None.

    Raises:
        KeyError: if the normalizer name / type is unknown
        TypeError: if ``key`` is neither a string, a dict, nor None
        ValueError: if ``dim`` is unsupported for batch norm
        AssertionError: if ``in_channels`` is not divisible by the number
            of groups

    Example:
        >>> rectify_normalizer(8)
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, 'batch')
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, {'type': 'batch'})
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, 'group')
        GroupNorm(8, 8, eps=1e-05, affine=True)
        >>> rectify_normalizer(8, {'type': 'group', 'num_groups': 2})
        GroupNorm(2, 8, eps=1e-05, affine=True)
        >>> rectify_normalizer(8, dim=3)
        BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> print(rectify_normalizer(8, None))
        None
        >>> len(rectify_normalizer(8, 'batch+group'))
        2
    """
    if key is None:
        return None

    if key is ub.NoParam:
        key = 'batch'

    # Normalize the specification into a private dict we are free to modify
    if isinstance(key, six.string_types):
        if key == 'batch':
            key = {'type': 'batch'}
        elif key == 'group':
            key = {
                'type': 'group',
                'num_groups': ('gcd', min(in_channels, 32))
            }
        elif key == 'batch+group':
            key = {'type': 'batch+group'}
        else:
            raise KeyError(key)
    elif isinstance(key, dict):
        key = key.copy()
    else:
        raise TypeError(type(key))

    # After the pop, ``key`` only holds keyword arguments for the layer
    norm_type = key.pop('type')
    if norm_type == 'batch':
        in_channels_key = 'num_features'

        if dim == 1:
            cls = torch.nn.BatchNorm1d
        elif dim == 2:
            cls = torch.nn.BatchNorm2d
        elif dim == 3:
            cls = torch.nn.BatchNorm3d
        else:
            raise ValueError(dim)
    elif norm_type == 'group':
        in_channels_key = 'num_channels'
        # A dict spec without 'num_groups' (which is what the 'batch+group'
        # branch below produces for the plain string form) falls back to
        # the same default as the 'group' string.
        key.setdefault('num_groups', ('gcd', min(in_channels, 32)))
        if isinstance(key['num_groups'], tuple):
            if key['num_groups'][0] == 'gcd':
                key['num_groups'] = gcd(key['num_groups'][1], in_channels)
        if in_channels % key['num_groups'] != 0:
            raise AssertionError(
                'Cannot divide n_inputs {} by num groups {}'.format(
                    in_channels, key['num_groups']))
        cls = torch.nn.GroupNorm
    elif norm_type == 'batch+group':
        # Batch norm followed by group norm; the remaining keyword
        # arguments are forwarded to the group norm only.
        return torch.nn.Sequential(
            rectify_normalizer(in_channels, 'batch', dim=dim),
            rectify_normalizer(in_channels,
                               ub.dict_union({'type': 'group'}, key),
                               dim=dim),
        )
    else:
        # Report the offending type ('type' was already popped from key)
        raise KeyError('unknown type: {}'.format(norm_type))
    assert in_channels_key not in key
    key[in_channels_key] = in_channels
    return cls(**key)
Exemple #11
0
def setup_harness(workers=None):
    """
    Build a fit harness that trains Darknet19 (YOLO) on the VOC dataset.

    Args:
        workers (int | None): number of data loader workers. If None, the
            value of the ``--workers`` command line argument is used, which
            defaults to half the number of logical cpus.

    Returns:
        fit_harness.FitHarness: the configured (not yet initialized) harness

    Raises:
        KeyError: if the ``--data`` command line choice is unknown

    CommandLine:
        python ~/code/clab/examples/yolo_voc.py setup_harness
        python ~/code/clab/examples/yolo_voc.py setup_harness --profile

    Example:
        >>> harn = setup_harness(workers=0)
        >>> harn.initialize()
        >>> harn.dry = True
        >>> # xdoc: +SKIP
        >>> harn.run()
    """
    workdir = ub.truepath('~/work/VOC2007')
    devkit_dpath = ub.truepath('~/data/VOC/VOCdevkit')
    YoloVOCDataset.ensure_voc_data()

    if ub.argflag('--2007'):
        dsetkw = {'years': [2007]}
    elif ub.argflag('--2012'):
        dsetkw = {'years': [2007, 2012]}
    else:
        dsetkw = {'years': [2007]}

    data_choice = ub.argval('--data', 'normal')

    if data_choice == 'combined':
        datasets = {
            'test': YoloVOCDataset(devkit_dpath, split='test', **dsetkw),
            'train': YoloVOCDataset(devkit_dpath, split='trainval', **dsetkw),
        }
    elif data_choice == 'notest':
        datasets = {
            'train': YoloVOCDataset(devkit_dpath, split='train', **dsetkw),
            'vali': YoloVOCDataset(devkit_dpath, split='val', **dsetkw),
        }
    elif data_choice == 'normal':
        datasets = {
            'train': YoloVOCDataset(devkit_dpath, split='train', **dsetkw),
            'vali': YoloVOCDataset(devkit_dpath, split='val', **dsetkw),
            'test': YoloVOCDataset(devkit_dpath, split='test', **dsetkw),
        }
    else:
        raise KeyError(data_choice)

    nice = ub.argval('--nice', default=None)

    pretrained_fpath = darknet.initial_weights()

    # NOTE: XPU implicitly supports DataParallel just pass --gpu=0,1,2,3
    xpu = xpu_device.XPU.cast('argv')
    print('xpu = {!r}'.format(xpu))

    ensure_ulimit()

    postproc_params = dict(
        conf_thresh=0.001,
        nms_thresh=0.5,
        ovthresh=0.5,
    )

    max_epoch = 160

    lr_step_points = {
        0: 0.001,
        60: 0.0001,
        90: 0.00001,
    }

    if ub.argflag('--warmup'):
        lr_step_points = {
            # warmup learning rate
            0:  0.0001,
            1:  0.0001,
            2:  0.0002,
            3:  0.0003,
            4:  0.0004,
            5:  0.0005,
            6:  0.0006,
            7:  0.0007,
            8:  0.0008,
            9:  0.0009,
            10: 0.0010,
            # cooldown learning rate
            60: 0.0001,
            90: 0.00001,
        }

    batch_size = int(ub.argval('--batch_size', default=16))
    if workers is None:
        # Only consult the command line when the caller did not choose a
        # worker count explicitly, so that e.g. workers=0 is respected.
        n_cpus = psutil.cpu_count(logical=True)
        workers = int(ub.argval('--workers', default=int(n_cpus / 2)))

    print('Making loaders')
    loaders = make_loaders(datasets, batch_size=batch_size, workers=workers)

    # Reference:
    #     Original YOLO9000 hyperparameters are defined here:
    #     https://github.com/pjreddie/darknet/blob/master/cfg/yolo-voc.2.0.cfg
    #
    #     https://github.com/longcw/yolo2-pytorch/issues/1#issuecomment-286410772
    #
    #     Notes:
    #         jitter is a translation / crop parameter
    #         https://groups.google.com/forum/#!topic/darknet/A-JJeXprvJU
    #
    #         thresh in 2.0.cfg is iou_thresh here

    print('Making hyperparams')
    hyper = hyperparams.HyperParams(

        model=(darknet.Darknet19, {
            'num_classes': datasets['train'].num_classes,
            'anchors': datasets['train'].anchors
        }),

        criterion=(darknet_loss.DarknetLoss, {
            'anchors': datasets['train'].anchors,
            'object_scale': 5.0,
            'noobject_scale': 1.0,
            'class_scale': 1.0,
            'coord_scale': 1.0,
            'iou_thresh': 0.6,
            'reproduce_longcw': ub.argflag('--longcw'),
            'denom': ub.argval('--denom', default='num_boxes'),
        }),

        optimizer=(torch.optim.SGD, dict(
            lr=lr_step_points[0],
            momentum=0.9,
            weight_decay=0.0005
        )),

        # initializer=(nninit.KaimingNormal, {}),
        initializer=(nninit.Pretrained, {
            'fpath': pretrained_fpath,
        }),

        scheduler=(ListedLR, dict(
            step_points=lr_step_points
        )),

        other=ub.dict_union({
            'nice': str(nice),
            'batch_size': loaders['train'].batch_sampler.batch_size,
        }, postproc_params),
        centering=None,

        # centering=datasets['train'].centering,
        augment=datasets['train'].augmenter,
    )

    harn = fit_harness.FitHarness(
        hyper=hyper, xpu=xpu, loaders=loaders, max_iter=max_epoch,
        workdir=workdir,
    )
    harn.postproc_params = postproc_params
    harn.nice = nice
    harn.monitor = monitor.Monitor(min_keys=['loss'],
                                   # max_keys=['global_acc', 'class_acc'],
                                   patience=max_epoch)

    @harn.set_batch_runner
    def batch_runner(harn, inputs, labels):
        """
        Custom function to compute the output of a batch and its loss.

        Example:
            >>> import sys
            >>> sys.path.append('/home/joncrall/code/clab/examples')
            >>> from yolo_voc import *
            >>> harn = setup_harness(workers=0)
            >>> harn.initialize()
            >>> batch = harn._demo_batch(0, 'train')
            >>> inputs, labels = batch
            >>> criterion = harn.criterion
            >>> weights_fpath = darknet.demo_weights()
            >>> state_dict = torch.load(weights_fpath)['model_state_dict']
            >>> harn.model.module.load_state_dict(state_dict)
            >>> outputs, loss = harn._custom_run_batch(harn, inputs, labels)
        """
        # hack for data parallel
        # if harn.current_tag == 'train':
        outputs = harn.model(*inputs)
        # else:
        #     # Run test and validation on a single GPU
        #     outputs = harn.model.module(*inputs)

        # darknet criterion needs to know the input image shape
        inp_size = tuple(inputs[0].shape[-2:])

        aoff_pred, iou_pred, prob_pred = outputs
        gt_boxes, gt_classes, orig_size, indices, gt_weights = labels

        loss = harn.criterion(aoff_pred, iou_pred, prob_pred, gt_boxes,
                              gt_classes, gt_weights=gt_weights,
                              inp_size=inp_size, epoch=harn.epoch)
        return outputs, loss

    @harn.add_batch_metric_hook
    def custom_metrics(harn, output, labels):
        # Report the individual loss terms stored on the criterion
        metrics_dict = ub.odict()
        criterion = harn.criterion
        metrics_dict['L_bbox'] = float(criterion.bbox_loss.data.cpu().numpy())
        metrics_dict['L_iou'] = float(criterion.iou_loss.data.cpu().numpy())
        metrics_dict['L_cls'] = float(criterion.cls_loss.data.cpu().numpy())
        return metrics_dict

    # Set as a harness attribute instead of using a closure
    harn.batch_confusions = []

    @harn.add_iter_callback
    def on_batch(harn, tag, loader, bx, inputs, labels, outputs, loss):
        """
        Custom hook to run on each batch (used to compute mAP on the fly)

        Example:
            >>> harn = setup_harness(workers=0)
            >>> harn.initialize()
            >>> batch = harn._demo_batch(0, 'train')
            >>> inputs, labels = batch
            >>> criterion = harn.criterion
            >>> loader = harn.loaders['train']
            >>> weights_fpath = darknet.demo_weights()
            >>> state_dict = torch.load(weights_fpath)['model_state_dict']
            >>> harn.model.module.load_state_dict(state_dict)
            >>> outputs, loss = harn._custom_run_batch(harn, inputs, labels)
            >>> tag = 'train'
            >>> bx = 0
            >>> on_batch(harn, tag, loader, bx, inputs, labels, outputs, loss)
        """
        # Accumulate relevant outputs to measure
        gt_boxes, gt_classes, orig_size, indices, gt_weights = labels
        # aoff_pred, iou_pred, prob_pred = outputs
        im_sizes = orig_size
        # The last two input dims in reversed order
        # (presumably width, height - TODO confirm)
        inp_size = inputs[0].shape[-2:][::-1]

        conf_thresh = harn.postproc_params['conf_thresh']
        nms_thresh = harn.postproc_params['nms_thresh']
        ovthresh = harn.postproc_params['ovthresh']

        postout = harn.model.module.postprocess(outputs, inp_size, im_sizes,
                                                conf_thresh, nms_thresh)
        # Compute: y_pred, y_true, and y_score for this batch
        batch_pred_boxes, batch_pred_scores, batch_pred_cls_inds = postout
        batch_true_boxes, batch_true_cls_inds = labels[0:2]
        batch_orig_sz, batch_img_inds = labels[2:4]

        y_batch = []
        # ``ix`` is the position of the image inside the batch. It is named
        # differently from the ``bx`` argument so the batch index is not
        # shadowed.
        image_indices = batch_img_inds.data.cpu().numpy().ravel()
        for ix, index in enumerate(image_indices):
            pred_boxes  = batch_pred_boxes[ix]
            pred_scores = batch_pred_scores[ix]
            pred_cxs    = batch_pred_cls_inds[ix]

            # Group groundtruth boxes by class
            true_boxes_ = batch_true_boxes[ix].data.cpu().numpy()
            true_cxs = batch_true_cls_inds[ix].data.cpu().numpy()
            true_weights = gt_weights[ix].data.cpu().numpy()

            # Unnormalize the true bboxes back to orig coords
            orig_size = batch_orig_sz[ix]
            sx, sy = np.array(orig_size) / np.array(inp_size)
            if len(true_boxes_):
                true_boxes = np.hstack([true_boxes_, true_weights[:, None]])
                true_boxes[:, 0:4:2] *= sx
                true_boxes[:, 1:4:2] *= sy
            else:
                # Without this branch an image with no groundtruth would
                # raise a NameError or silently reuse the boxes of the
                # previous image. The 5 columns mirror the layout built
                # above (assumes 4 box coordinates + 1 weight - TODO
                # confirm).
                true_boxes = np.empty((0, 5), dtype=true_boxes_.dtype)

            y = voc.EvaluateVOC.image_confusions(true_boxes, true_cxs,
                                                 pred_boxes, pred_scores,
                                                 pred_cxs, ovthresh=ovthresh)
            y['gx'] = index
            y_batch.append(y)

        harn.batch_confusions.extend(y_batch)

    @harn.add_epoch_callback
    def on_epoch(harn, tag, loader):
        # Combine the per-image confusions gathered by on_batch, log the
        # mAP of this epoch, and reset the accumulator.
        y = pd.concat(harn.batch_confusions)
        num_classes = len(loader.dataset.label_names)

        mean_ap, ap_list = voc.EvaluateVOC.compute_map(y, num_classes)

        harn.log_value(tag + ' epoch mAP', mean_ap, harn.epoch)
        # max_ap = np.nanmax(ap_list)
        # harn.log_value(tag + ' epoch max-AP', max_ap, harn.epoch)
        harn.batch_confusions.clear()

    return harn
Exemple #12
0
 def info(self):
     """
     Return the stored info extended with freshly computed entries.

     The ``'unique'`` and ``'normed'`` entries are computed on every call
     and take precedence over same-named entries in ``self._info``.
     """
     stored = self._info
     computed = {
         'unique': self.unique(),
         'normed': self.normalize(),
     }
     return ub.dict_union(stored, computed)
Exemple #13
0
def 字典_合并(*args):
    """
    Merge any number of dictionaries into a single new dictionary.

    Thin wrapper around ``ub.dict_union``; for example it can be called as
    ``字典_合并({'a': 1, 'b': 1}, {'b': 2, 'c': 2})``.
    """
    return ub.dict_union(*args)
Exemple #14
0
def setup_harness(workers=None):
    """
    Build a fully configured ``FitHarness`` for training YOLO on VOC.

    The datasets, device, batch size and number of workers are controlled
    through command line flags (``--2007`` / ``--2012``, ``--data``,
    ``--small``, ``--nice``, ``--batch_size``, ``--workers``).

    Args:
        workers (int, optional): number of data loader workers. When None,
            the ``--workers`` command line value is used, which itself
            defaults to half of the logical CPU count.

    Returns:
        the harness, with a custom batch runner, a batch metric hook and
        per-batch / per-epoch callbacks (used to compute mAP) registered.

    Raises:
        KeyError: if ``--data`` is not 'combined', 'notest' or 'normal'.

    CommandLine:
        python ~/code/clab/examples/yolo_voc2.py setup_harness
        python ~/code/clab/examples/yolo_voc2.py setup_harness --profile
        python ~/code/clab/examples/yolo_voc2.py setup_harness --flamegraph

    Example:
        >>> harn = setup_harness(workers=0)
        >>> harn.initialize()
        >>> harn.dry = True
        >>> harn.run()
    """
    workdir = ub.truepath('~/work/VOC2007')
    devkit_dpath = ub.truepath('~/data/VOC/VOCdevkit')
    YoloVOCDataset.ensure_voc_data()

    # 2007 is the default (and wins if both flags are given); the 2012 data
    # is only added when --2012 is passed without --2007.
    if ub.argflag('--2012') and not ub.argflag('--2007'):
        dsetkw = {'years': [2007, 2012]}
    else:
        dsetkw = {'years': [2007]}

    data_choice = ub.argval('--data', 'normal')

    if ub.argflag('--small'):
        dsetkw['base_wh'] = np.array([7, 7]) * 32
        dsetkw['scales'] = [-1, 1]

    # Maps each --data choice to the (loader tag, VOC split) pairs it builds.
    # The pair order is the order in which the datasets are constructed.
    split_lookup = {
        'combined': [('test', 'test'), ('train', 'trainval')],
        'notest': [('train', 'train'), ('vali', 'val')],
        'normal': [('train', 'train'), ('vali', 'val'), ('test', 'test')],
    }
    if data_choice not in split_lookup:
        raise KeyError(data_choice)
    datasets = {
        tag: YoloVOCDataset(devkit_dpath, split=split, **dsetkw)
        for tag, split in split_lookup[data_choice]
    }

    nice = ub.argval('--nice', default=None)

    pretrained_fpath = ensure_lightnet_initial_weights()

    # NOTE: XPU implicitly supports DataParallel just pass --gpu=0,1,2,3
    xpu = xpu_device.XPU.cast('argv')
    print('xpu = {!r}'.format(xpu))

    ensure_ulimit()

    postproc_params = dict(
        conf_thresh=0.001,
        # nms_thresh=0.5,
        nms_thresh=0.4,
        ovthresh=0.5,
    )

    max_epoch = 160

    # Learning rate schedule: epoch -> learning rate.
    lr_step_points = {
        # warmup learning rate
        0:  0.0001,
        1:  0.0001,
        2:  0.0002,
        3:  0.0003,
        4:  0.0004,
        5:  0.0005,
        6:  0.0006,
        7:  0.0007,
        8:  0.0008,
        9:  0.0009,
        10: 0.0010,
        # cooldown learning rate
        60: 0.0001,
        90: 0.00001,
    }

    batch_size = int(ub.argval('--batch_size', default=16))
    if workers is None:
        # Only query the CPU count when it is actually needed for the default
        n_cpus = psutil.cpu_count(logical=True)
        workers = int(ub.argval('--workers', default=int(n_cpus / 2)))

    print('Making loaders')
    loaders = make_loaders(datasets, batch_size=batch_size, workers=workers)

    # anchors = {'num': 5, 'values': list(ub.flatten(datasets['train'].anchors))}
    anchors = dict(num=5, values=[1.3221, 1.73145, 3.19275, 4.00944, 5.05587,
                                  8.09892, 9.47112, 4.84053, 11.2364, 10.0071])

    print('Making hyperparams')
    hyper = hyperparams.HyperParams(

        # model=(darknet.Darknet19, {
        model=(light_yolo.Yolo, {
            'num_classes': datasets['train'].num_classes,
            'anchors': anchors,
            'conf_thresh': postproc_params['conf_thresh'],
            'nms_thresh': postproc_params['nms_thresh'],
        }),

        criterion=(RegionLoss, {
            'num_classes': datasets['train'].num_classes,
            'anchors': anchors,
            # 'object_scale': 5.0,
            # 'noobject_scale': 1.0,
            # 'class_scale': 1.0,
            # 'coord_scale': 1.0,
            # 'thresh': 0.6,
        }),

        optimizer=(torch.optim.SGD, dict(
            lr=lr_step_points[0],
            momentum=0.9,
            weight_decay=0.0005
        )),

        initializer=(nninit.Pretrained, {
            'fpath': pretrained_fpath,
        }),

        scheduler=(ListedLR, dict(
            step_points=lr_step_points
        )),

        other=ub.dict_union({
            'nice': str(nice),
            'batch_size': loaders['train'].batch_sampler.batch_size,
        }, postproc_params),
        centering=None,

        # centering=datasets['train'].centering,
        augment=datasets['train'].augmenter,
    )

    harn = fit_harness.FitHarness(
        hyper=hyper, xpu=xpu, loaders=loaders, max_iter=max_epoch,
        workdir=workdir,
    )
    harn.postproc_params = postproc_params
    harn.nice = nice
    harn.monitor = monitor.Monitor(min_keys=['loss'],
                                   # max_keys=['global_acc', 'class_acc'],
                                   patience=max_epoch)

    @harn.set_batch_runner
    @profiler.profile
    def batch_runner(harn, inputs, labels):
        """
        Custom function to compute the output of a batch and its loss.

        In dry mode (``harn.dry``) the model is not run; a random tensor of
        the model's output shape is used instead.

        Returns:
            tuple: (outputs, loss)

        Example:
            >>> import sys
            >>> sys.path.append('/home/joncrall/code/clab/examples')
            >>> from yolo_voc2 import *
            >>> harn = setup_harness(workers=0)
            >>> harn.initialize()
            >>> batch = harn._demo_batch(0, 'vali')
            >>> inputs, labels = batch
            >>> criterion = harn.criterion
            >>> weights_fpath = light_yolo.demo_weights()
            >>> state_dict = torch.load(weights_fpath)['weights']
            >>> harn.model.module.load_state_dict(state_dict)
            >>> outputs, loss = harn._custom_run_batch(harn, inputs, labels)
        """
        if harn.dry:
            shape = harn.model.module.output_shape_for(inputs[0].shape)
            outputs = torch.rand(*shape)
        else:
            outputs = harn.model.forward(*inputs)

        target = labels[0]

        bsize = inputs[0].shape[0]

        # The criterion is told how many items have been seen so far; this is
        # estimated from the epoch, the train loader length and the current
        # train batch index.
        n_items = len(harn.loaders['train'])
        bx = harn.bxs.get('train', 0)
        seen = harn.epoch * n_items + (bx * bsize)
        loss = harn.criterion(outputs, target, seen=seen)
        return outputs, loss

    @harn.add_batch_metric_hook
    @profiler.profile
    def custom_metrics(harn, output, labels):
        """
        Report the individual loss terms stored on the criterion as floats.
        """
        criterion = harn.criterion
        metrics_dict = ub.odict()
        metrics_dict['L_bbox'] = float(criterion.loss_coord.data.cpu().numpy())
        metrics_dict['L_iou'] = float(criterion.loss_conf.data.cpu().numpy())
        metrics_dict['L_cls'] = float(criterion.loss_cls.data.cpu().numpy())
        return metrics_dict

    # Set as a harness attribute instead of using a closure
    harn.batch_confusions = []

    @harn.add_iter_callback
    @profiler.profile
    def on_batch(harn, tag, loader, bx, inputs, labels, outputs, loss):
        """
        Custom hook to run on each batch (used to compute mAP on the fly)

        Postprocesses the raw outputs into boxes, compares them against the
        truth for every image in the batch and appends the resulting
        confusions to ``harn.batch_confusions``. Does nothing when
        ``outputs`` is None.

        Example:
            >>> harn = setup_harness(workers=0)
            >>> harn.initialize()
            >>> batch = harn._demo_batch(0, 'vali')
            >>> inputs, labels = batch
            >>> criterion = harn.criterion
            >>> loader = harn.loaders['train']
            >>> weights_fpath = light_yolo.demo_weights()
            >>> state_dict = torch.load(weights_fpath)['weights']
            >>> harn.model.module.load_state_dict(state_dict)
            >>> outputs, loss = harn._custom_run_batch(harn, inputs, labels)
            >>> tag = 'train'
            >>> bx = 0
            >>> on_batch(harn, tag, loader, bx, inputs, labels, outputs, loss)

        Ignore:

            >>> target, gt_weights, batch_orig_sz, batch_index = labels
            >>> bx = 0
            >>> postout = harn.model.module.postprocess(outputs.clone())
            >>> item = postout[bx].cpu().numpy()
            >>> item = item[item[:, 4] > .6]
            >>> cxywh = util.Boxes(item[..., 0:4], 'cxywh')
            >>> orig_size = batch_orig_sz[bx].numpy().ravel()
            >>> tlbr = cxywh.scale(orig_size).asformat('tlbr').data


            >>> truth_bx = target[bx]
            >>> truth_bx = truth_bx[truth_bx[:, 0] != -1]
            >>> truth_tlbr = util.Boxes(truth_bx[..., 1:5].numpy(), 'cxywh').scale(orig_size).asformat('tlbr').data

            >>> chw = inputs[0][bx].numpy().transpose(1, 2, 0)
            >>> rgb255 = cv2.resize(chw * 255, tuple(orig_size))
            >>> mplutil.figure(fnum=1, doclf=True)
            >>> mplutil.imshow(rgb255, colorspace='rgb')
            >>> mplutil.draw_boxes(tlbr, 'tlbr')
            >>> mplutil.draw_boxes(truth_tlbr, 'tlbr', color='orange')
            >>> mplutil.show_if_requested()
        """
        # Accumulate relevant outputs to measure
        target, gt_weights, batch_orig_sz, batch_index = labels

        conf_thresh = harn.postproc_params['conf_thresh']
        nms_thresh = harn.postproc_params['nms_thresh']
        ovthresh = harn.postproc_params['ovthresh']

        if outputs is None:
            return

        # Make the postprocessor use the harness thresholds, then apply it
        postprocess = harn.model.module.postprocess
        postprocess.conf_thresh = conf_thresh
        postprocess.nms_thresh = nms_thresh
        postout = postprocess(outputs)

        # Per-image predictions: boxes are converted from cxywh to tlbr and
        # scaled by the original image size; columns 4 and 5 of each
        # postprocessed item are used as the score and the class index.
        batch_pred_boxes = []
        batch_pred_scores = []
        batch_pred_cls_inds = []
        for bx, item_ in enumerate(postout):
            item = item_.cpu().numpy()
            if len(item):
                cxywh = util.Boxes(item[..., 0:4], 'cxywh')
                orig_size = batch_orig_sz[bx].cpu().numpy().ravel()
                tlbr = cxywh.scale(orig_size).asformat('tlbr').data
                batch_pred_boxes.append(tlbr)
                batch_pred_scores.append(item[..., 4])
                batch_pred_cls_inds.append(item[..., 5])
            else:
                batch_pred_boxes.append(np.empty((0, 4)))
                batch_pred_scores.append(np.empty(0))
                batch_pred_cls_inds.append(np.empty(0))

        # Column 0 of the target is the class index, columns 1:5 the box
        batch_true_cls_inds = target[..., 0]
        batch_true_boxes = target[..., 1:5]

        batch_img_inds = batch_index

        y_batch = []
        for bx, index in enumerate(batch_img_inds.data.cpu().numpy().ravel()):
            pred_boxes = batch_pred_boxes[bx]
            pred_scores = batch_pred_scores[bx]
            pred_cxs = batch_pred_cls_inds[bx]

            # Group groundtruth boxes by class
            true_boxes_ = batch_true_boxes[bx].data.cpu().numpy()
            true_cxs = batch_true_cls_inds[bx].data.cpu().numpy()
            true_weights = gt_weights[bx].data.cpu().numpy()

            # Unnormalize the true bboxes back to orig coords
            orig_size = batch_orig_sz[bx]
            if len(true_boxes_):
                true_boxes = util.Boxes(true_boxes_, 'cxywh').scale(
                    orig_size).asformat('tlbr').data
                # The truth weight is appended as a 5th column
                true_boxes = np.hstack([true_boxes, true_weights[:, None]])
            else:
                # NOTE(review): this branch yields 4 columns while the
                # non-empty branch yields 5 -- confirm image_confusions
                # accepts both layouts.
                true_boxes = true_boxes_.reshape(-1, 4)

            y = voc.EvaluateVOC.image_confusions(true_boxes, true_cxs,
                                                 pred_boxes, pred_scores,
                                                 pred_cxs, ovthresh=ovthresh)
            y['gx'] = index
            y_batch.append(y)

        harn.batch_confusions.extend(y_batch)

    @harn.add_epoch_callback
    @profiler.profile
    def on_epoch(harn, tag, loader):
        """
        Compute and log the mean / max AP from the confusions accumulated by
        ``on_batch`` during the epoch, then reset the accumulator.
        """
        if not harn.batch_confusions:
            # on_batch skips batches without outputs, so nothing may have been
            # accumulated; pd.concat raises a ValueError on an empty list.
            return

        y = pd.concat(harn.batch_confusions)
        num_classes = len(loader.dataset.label_names)

        mean_ap, ap_list = voc.EvaluateVOC.compute_map(y, num_classes)

        harn.log_value(tag + ' epoch mAP', mean_ap, harn.epoch)
        max_ap = np.nanmax(ap_list)
        harn.log_value(tag + ' epoch max-AP', max_ap, harn.epoch)
        harn.batch_confusions.clear()

    return harn
Exemple #15
0
    def run_protocol(
        self, config: Dict[str, Any],
        extra_plugins: Dict[str, Any] = dict()) -> None:
        """Run the protocol selected by the config and log the run to a file.

        A log file named after the current time is created inside the working
        folder and attached to the root logger for the duration of the run.
        The handler is always detached and closed again, even when the run
        fails.

        Args:
            config: configuration of the run. It is passed to
                ``self.setup_experiment``, which returns the working folder,
                the path of the config written for this run, the privileged
                config and the config used from then on. The launching of the
                protocols is controlled by 3 parameters:
                - protocol: either 'ond' or 'condda' to define which protocol to run
                - harness:  either 'local' or 'par' to define which harness to use
                - workdir: a directory to save all the information from the run including
                    - Config
                    - Output of algorithm
            extra_plugins: plugins combined (via ``ub.dict_union``) with the
                ones discovered under 'tinker'. The shared default dict is
                only read here.

        Raises:
            AttributeError: if the harness is not 'local' / 'par' or the
                protocol is not 'ond' / 'condda'.

        Example:
            >>> from sailon_tinker_launcher.main import *
            >>> dpath = ub.ensure_app_cache_dir('tinker/tests')
            >>> config = get_debug_config()
            >>> self = LaunchSailonProtocol()
            >>> self.run_protocol(config)
            >>> assert(self.working_folder.exists())
            >>> ub.delete(str(self.working_folder), verbose=False)

        """

        # Setup working folder and create new config for this run
        (self.working_folder, working_config_fp,
         privileged_config, config) = self.setup_experiment(config)

        # Now experiment setup, start a new logger for this
        log_fname = f'{datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")}.log'
        fh = logging.FileHandler(self.working_folder / log_fname)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter(
            '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'))
        root_logger = logging.getLogger()
        root_logger.addHandler(fh)

        try:
            log.info(f'Config Filepath: {working_config_fp}')
            log.info(f'Config: \n{json.dumps(config, indent=4)}')

            # Load the harness
            # This config is not used but will throw error if not pointed at
            harness_config_path = Path(protocol_folder.__file__).parent
            harness_name = privileged_config['harness']
            if harness_name == 'local':
                log.info('Loading Local Harness')
                harness = LocalInterface('configuration.json',
                                         str(harness_config_path))
                harness.result_directory = config['detectors']['csv_folder']
                harness.file_provider.results_folder = config['detectors'][
                    'csv_folder']
            elif harness_name == 'par':
                log.info('Loading Par Harness')
                harness = ParInterface('configuration.json',
                                       str(harness_config_path))
                harness.folder = config['detectors']['csv_folder']
            else:
                raise AttributeError(
                    f'Valid harnesses "local" or "par".  '
                    f'Given harness "{privileged_config["harness"]}" ')

            # Get the plugins
            plugins = ub.dict_union(discoverable_plugins('tinker'),
                                    extra_plugins)

            log.debug('Plugins found:')
            log.debug(plugins)

            # Load the protocol
            protocol_name = privileged_config['protocol']
            if protocol_name == 'ond':
                log.info('Running OND Protocol')
                run_protocol = OND(discovered_plugins=plugins,
                                   algorithmsdirectory='',
                                   harness=harness,
                                   config_file=str(working_config_fp))
            elif protocol_name == 'condda':
                log.info('Running Condda Protocol')
                run_protocol = Condda(discovered_plugins=plugins,
                                      algorithmsdirectory='',
                                      harness=harness,
                                      config_file=str(working_config_fp))
            else:
                raise AttributeError(
                    f'Please set protocol to either "ond" or "condda".  '
                    f'"{privileged_config["protocol"]}" in the config files')

            # Run the protocol
            run_protocol.run_protocol()
            log.info('Protocol Finished')
        finally:
            # Always detach and close the per-run handler; otherwise a failed
            # run would leave it on the root logger and keep the file open.
            root_logger.removeHandler(fh)
            fh.close()
Exemple #16
0
def rectify_normalizer(in_channels, key=ub.NoParam, dim=2, **kwargs):
    """
    Allows dictionary based specification of a normalizing layer

    Args:
        in_channels (int): number of input channels
        key (str | dict | None): which normalizer to build. Either the type
            name ('batch', 'syncbatch', 'group' or 'batch+group') or a dict
            with a 'type' item plus extra constructor arguments for the
            layer (the dict is copied, not modified). If None, no layer is
            built and None is returned. If unspecified, defaults to 'batch'.
        dim (int): dimensionality, selects the BatchNorm variant
            (0 and 1 map to BatchNorm1d, 2 to BatchNorm2d, 3 to BatchNorm3d)
        **kwargs: extra args passed to the layer constructor; these take
            precedence over the items given in ``key``. They are not
            forwarded to the sub-layers of 'batch+group'.

    Returns:
        torch.nn.Module | None: the normalization layer, or None if
            ``key`` is None.

    Raises:
        TypeError: if ``key`` is neither a string, a dict nor None
        KeyError: if the requested type is unknown (or a dict ``key`` has
            no 'type' item)
        ValueError: if ``dim`` is not supported by the 'batch' type
        AssertionError: if ``in_channels`` is not divisible by the number of
            groups, or if ``key`` already specifies the channel argument

    Example:
        >>> rectify_normalizer(8)
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, 'batch')
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, {'type': 'batch'})
        BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, 'group')
        GroupNorm(8, 8, eps=1e-05, affine=True)
        >>> rectify_normalizer(8, {'type': 'group', 'num_groups': 2})
        GroupNorm(2, 8, eps=1e-05, affine=True)
        >>> rectify_normalizer(8, dim=3)
        BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        >>> rectify_normalizer(8, None) is None
        True
        >>> rectify_normalizer(8, key={'type': 'syncbatch'})
    """
    if key is None:
        return None

    if key is ub.NoParam:
        key = 'batch'

    if isinstance(key, six.string_types):
        key = {'type': key}
    elif isinstance(key, dict):
        # Work on a copy: items are popped from / added to the dict below
        key = key.copy()
    else:
        raise TypeError(type(key))

    # After this pop, ``key`` only holds constructor arguments
    norm_type = key.pop('type')
    if norm_type == 'batch':
        in_channels_key = 'num_features'

        if dim == 0:
            cls = torch.nn.BatchNorm1d
        elif dim == 1:
            cls = torch.nn.BatchNorm1d
        elif dim == 2:
            cls = torch.nn.BatchNorm2d
        elif dim == 3:
            cls = torch.nn.BatchNorm3d
        else:
            raise ValueError(dim)
    elif norm_type == 'syncbatch':
        in_channels_key = 'num_features'
        cls = torch.nn.SyncBatchNorm
    elif norm_type == 'group':
        in_channels_key = 'num_channels'
        if key.get('num_groups') is None:
            # Default: derive the group count from the channel count
            key['num_groups'] = ('gcd', min(in_channels, 32))

        if isinstance(key['num_groups'], tuple):
            # ('gcd', n) requests gcd(n, in_channels) groups, which always
            # divides the number of channels
            if key['num_groups'][0] == 'gcd':
                key['num_groups'] = gcd(
                    key['num_groups'][1], in_channels)
        if in_channels % key['num_groups'] != 0:
            raise AssertionError(
                'Cannot divide n_inputs {} by num groups {}'.format(
                    in_channels, key['num_groups']))
        cls = torch.nn.GroupNorm
    elif norm_type == 'batch+group':
        # Remaining items of ``key`` only configure the group layer
        return torch.nn.Sequential(
            rectify_normalizer(in_channels, 'batch', dim=dim),
            rectify_normalizer(in_channels, ub.dict_union({'type': 'group'}, key), dim=dim),
        )
    else:
        # Report the requested type itself; it was already popped from ``key``
        raise KeyError('unknown type: {}'.format(norm_type))

    # The channel argument is filled in here and must not be given by the
    # caller (raised explicitly so the check survives ``python -O``)
    if in_channels_key in key:
        raise AssertionError(
            '{!r} must not be specified in key'.format(in_channels_key))
    key[in_channels_key] = in_channels

    kw = dict(key)
    kw.update(kwargs)
    return cls(**kw)
Exemple #17
0
def train():
    """
    Example:
        >>> train()
    """
    import random
    np.random.seed(1031726816 % 4294967295)
    torch.manual_seed(137852547 % 4294967295)
    random.seed(2497950049 % 4294967295)

    xpu = xpu_device.XPU.from_argv()
    print('Chosen xpu = {!r}'.format(xpu))

    cifar_num = 10

    if ub.argflag('--lab'):
        datasets = cifar_training_datasets(output_colorspace='LAB',
                                           norm_mode='independent',
                                           cifar_num=cifar_num)
    elif ub.argflag('--rgb'):
        datasets = cifar_training_datasets(output_colorspace='RGB',
                                           norm_mode='independent',
                                           cifar_num=cifar_num)
    elif ub.argflag('--rgb-dep'):
        datasets = cifar_training_datasets(output_colorspace='RGB',
                                           norm_mode='dependant',
                                           cifar_num=cifar_num)
    else:
        raise AssertionError('specify --rgb / --lab')

    import clab.models.densenet

    # batch_size = (128 // 3) * 3
    batch_size = 64

    # initializer_ = (nninit.KaimingNormal, {
    #     'nonlinearity': 'relu',
    # })
    initializer_ = (nninit.LSUV, {})

    hyper = hyperparams.HyperParams(
        model=(
            clab.models.densenet.DenseNet,
            {
                'cifar': True,
                'block_config': (32, 32, 32),  # 100 layer depth
                'num_classes': datasets['train'].n_classes,
                'drop_rate': float(ub.argval('--drop_rate', default=.2)),
                'groups': 1,
            }),
        optimizer=(
            torch.optim.SGD,
            {
                # 'weight_decay': .0005,
                'weight_decay':
                float(ub.argval('--weight_decay', default=.0005)),
                'momentum': 0.9,
                'nesterov': True,
                'lr': 0.1,
            }),
        scheduler=(torch.optim.lr_scheduler.ReduceLROnPlateau, {
            'factor': .5,
        }),
        initializer=initializer_,
        criterion=(torch.nn.CrossEntropyLoss, {}),
        # Specify anything else that is special about your hyperparams here
        # Especially if you make a custom_batch_runner
        augment=str(datasets['train'].augmenter),
        other=ub.dict_union(
            {
                # TODO: type of augmentation as a parameter dependency
                # 'augmenter': str(datasets['train'].augmenter),
                # 'augment': datasets['train'].augment,
                'batch_size': batch_size,
                'colorspace': datasets['train'].output_colorspace,
                'n_classes': datasets['train'].n_classes,
                # 'center_inputs': datasets['train'].center_inputs,
            },
            datasets['train'].center_inputs.__dict__),
    )
    # if ub.argflag('--rgb-indie'):
    #     hyper.other['norm'] = 'dependant'
    hyper.input_ids['train'] = datasets['train'].input_id

    xpu = xpu_device.XPU.cast('auto')
    print('xpu = {}'.format(xpu))

    data_kw = {'batch_size': batch_size}
    if xpu.is_gpu():
        data_kw.update({'num_workers': 8, 'pin_memory': True})

    tags = ['train', 'vali', 'test']

    loaders = ub.odict()
    for tag in tags:
        dset = datasets[tag]
        shuffle = tag == 'train'
        data_kw_ = data_kw.copy()
        if tag != 'train':
            data_kw_['batch_size'] = max(batch_size // 4, 1)
        loader = torch.utils.data.DataLoader(dset, shuffle=shuffle, **data_kw_)
        loaders[tag] = loader

    harn = fit_harness.FitHarness(
        hyper=hyper,
        datasets=datasets,
        xpu=xpu,
        loaders=loaders,
    )
    # harn.monitor = early_stop.EarlyStop(patience=40)
    harn.monitor = monitor.Monitor(min_keys=['loss'],
                                   max_keys=['global_acc', 'class_acc'],
                                   patience=40)

    @harn.set_batch_runner
    def batch_runner(harn, inputs, labels):
        """
        Custom function to compute the output of a batch and its loss.
        """
        output = harn.model(*inputs)
        label = labels[0]
        loss = harn.criterion(output, label)
        outputs = [output]
        return outputs, loss

    task = harn.datasets['train'].task
    all_labels = task.labels
    # ignore_label = datasets['train'].ignore_label
    # from clab import metrics
    from clab.metrics import (confusion_matrix, pixel_accuracy_from_confusion,
                              perclass_accuracy_from_confusion)

    @harn.add_batch_metric_hook
    def custom_metrics(harn, outputs, labels):
        label = labels[0]
        output = outputs[0]

        y_pred = output.data.max(dim=1)[1].cpu().numpy()
        y_true = label.data.cpu().numpy()

        cfsn = confusion_matrix(y_pred, y_true, labels=all_labels)

        global_acc = pixel_accuracy_from_confusion(cfsn)  # same as acc
        perclass_acc = perclass_accuracy_from_confusion(cfsn)
        # class_accuracy = perclass_acc.fillna(0).mean()
        class_accuracy = np.nan_to_num(perclass_acc).mean()

        metrics_dict = ub.odict()
        metrics_dict['global_acc'] = global_acc
        metrics_dict['class_acc'] = class_accuracy
        return metrics_dict

    workdir = ub.ensuredir('train_cifar_work')
    harn.setup_dpath(workdir)

    harn.run()