Code Example #1
    def load_data(self, data_path):
        """Load the data from Fluent export. Checks file or dir"""

        smoke_logger.info(f"Beginning to load {data_path}")
        XYZ_cacher = ub.Cacher(f'{data_path}_XYZ',
                               cfgstr=ub.hash_data('dependencies'))
        concentrations_cacher = ub.Cacher(f'{data_path}_concentration',
                                          cfgstr=ub.hash_data('dependencies'))

        self.XYZ = XYZ_cacher.tryload()
        self.concentrations = concentrations_cacher.tryload()

        if self.XYZ is None or self.concentrations is None:
            if os.path.isfile(data_path):
                self.load_file(data_path)
            elif os.path.isdir(data_path):
                self.load_directory(data_path)
            else:
                raise ValueError(
                    f"data path {data_path} was niether a directory nor a file."
                )
            assert self.concentrations is not None
            assert self.XYZ is not None
            XYZ_cacher.save(self.XYZ)
            concentrations_cacher.save(self.concentrations)

        smoke_logger.info(f"done loading {data_path}")
Code Example #2
    def _cached_pairwise_features(extr, edges):
        """
        Create pairwise features for annotations in a test inference object
        based on the features used to learn here

        TODO: need a more systematic way of specifying which feature dimensions
        need to be computed

        Notes:
            Given an edge (u, v), we need to:
            * Check which classifiers we have
            * Check which feat-cols the classifier needs,
               and construct a configuration that can achieve that.
                * Construct the chip/feat config
                * Construct the vsone config
                * Additional LNBNN enriching config
                * Pairwise feature construction config
            * Then we can apply the feature to the classifier

        edges = [(1, 2)]
        """
        edges = list(edges)
        if extr.verbose:
            print('[pairfeat] Requesting {} cached pairwise features'.format(
                len(edges)))

        # TODO: use object properties
        if len(edges) == 0:
            assert extr.feat_dims is not None, 'no edges and unset feat dims'
            index = nxu.ensure_multi_index([], ('aid1', 'aid2'))
            feats = pd.DataFrame(columns=extr.feat_dims, index=index)
            return feats
        else:
            use_cache = not extr.need_lnbnn and len(edges) > 2
            cache_dir = join(extr.ibs.get_cachedir(), 'infr_bulk_cache')
            feat_cfgstr = extr._make_cfgstr(edges)
            cacher = ub.Cacher('bulk_pairfeats_v3',
                               feat_cfgstr,
                               enabled=use_cache,
                               dpath=cache_dir,
                               verbose=extr.verbose - 3)

            # if cacher.exists() and extr.verbose > 3:
            #     fpath = cacher.get_fpath()
            #     print('Load match cache size: {}'.format(
            #         ut.get_file_nBytes_str(fpath)))

            data = cacher.tryload()
            if data is None:
                data = extr._make_pairwise_features(edges)
                cacher.save(data)

                # if cacher.enabled and extr.verbose > 3:
                #     fpath = cacher.get_fpath()
                #     print('Save match cache size: {}'.format(
                #         ut.get_file_nBytes_str(fpath)))

            matches, feats = data
            feats = extr._postprocess_feats(feats)
        return feats
Code Example #3
File: sseg_camvid.py Project: Erotemic/netharn
def _cached_class_frequency(dset):
    import ubelt as ub
    import copy
    # Copy the dataset so we can muck with it
    dset_copy = copy.copy(dset)

    dset_copy._build_sliders(input_overlap=0)
    dset_copy.augmenter = None

    cfgstr = '_'.join([dset_copy.sampler.dset.hashid, 'v1'])
    cacher = ub.Cacher('class_freq', cfgstr=cfgstr)
    total_freq = cacher.tryload()
    if total_freq is None:

        total_freq = np.zeros(len(dset_copy.classes), dtype=np.int64)
        if True:  # parallel path: load batches with a DataLoader (else branch is a serial fallback)
            loader = torch_data.DataLoader(dset_copy, batch_size=16,
                                           num_workers=7, shuffle=False,
                                           pin_memory=True)

            prog = ub.ProgIter(loader, desc='computing (par) class freq')
            for batch in prog:
                class_idxs = batch['class_idxs'].data.numpy()
                item_freq = np.histogram(class_idxs, bins=len(dset_copy.classes))[0]
                total_freq += item_freq
        else:
            prog = ub.ProgIter(range(len(dset_copy)), desc='computing (ser) class freq')
            for index in prog:
                item = dset_copy[index]
                class_idxs = item['class_idxs'].data.numpy()
                item_freq = np.histogram(class_idxs, bins=len(dset_copy.classes))[0]
                total_freq += item_freq
        cacher.save(total_freq)
    return total_freq
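
The cfgstr in this example joins a dataset hashid with a manual version tag, so the cache is invalidated either when the data changes or when the tag is bumped after a code change. A tiny illustration of that idiom (the hashid value and version constant are placeholders):

import ubelt as ub

dataset_hashid = 'abc123'    # placeholder for dset_copy.sampler.dset.hashid
CODE_VERSION = 'v2'          # bump this when the cached computation changes
cacher = ub.Cacher('class_freq', cfgstr='_'.join([dataset_hashid, CODE_VERSION]))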
Code Example #4
File: test_cache.py Project: hack121/ubelt
def test_corrupt():
    """
    What no errors happen when an external processes removes meta

    python ubelt/tests/test_cache.py test_corrupt
    """
    def func():
        return ['expensive result']

    cacher = ub.Cacher('name', 'params', verbose=10)
    cacher.clear()

    data = cacher.ensure(func)

    data2 = cacher.tryload()

    assert data2 == data

    # Overwrite the data with junk
    with open(cacher.get_fpath(), 'wb') as file:
        file.write(''.encode('utf8'))

    assert cacher.tryload() is None
    with pytest.raises(IOError):
        cacher.load()

    assert cacher.tryload() is None
    with open(cacher.get_fpath(), 'wb') as file:
        file.write(':junkdata:'.encode('utf8'))
    with pytest.raises(Exception):
        cacher.load()
Code Example #5
    def after_initialize(harn):
        # Prepare structures we will use to measure and quantify quality
        for tag, voc_dset in harn.datasets.items():
            cacher = ub.Cacher('dmet2', cfgstr=tag, appname='netharn')
            dmet = cacher.tryload()
            if dmet is None:
                dmet = nh.metrics.detections.DetectionMetrics()
                dmet.true = voc_dset.to_coco()
                # Truth and predictions share the same images and categories
                dmet.pred.dataset['images'] = dmet.true.dataset['images']
                dmet.pred.dataset['categories'] = dmet.true.dataset[
                    'categories']
                dmet.pred.dataset['annotations'] = []  # start empty
                dmet.true.dataset['annotations'] = []
                dmet.pred._clear_index()
                dmet.true._clear_index()
                dmet.true._build_index()
                dmet.pred._build_index()
                dmet.true._ensure_imgsize()
                dmet.pred._ensure_imgsize()
                cacher.save(dmet)

            dmet.true._build_index()
            dmet.pred._build_index()
            harn.dmets[tag] = dmet
Code Example #6
File: data.py Project: Sandy4321/baseline-viame-2018
def load_coco_datasets():
    import wrangle
    # annot_globstr = ub.truepath('~/data/viame-challenge-2018/phase0-annotations/*.json')
    # annot_globstr = ub.truepath('~/data/viame-challenge-2018/phase0-annotations/mbari_seq0.mscoco.json')
    # img_root = ub.truepath('~/data/viame-challenge-2018/phase0-imagery')

    # Contest training data on hermes
    annot_globstr = ub.truepath(
        '~/data/noaa/training_data/annotations/*/*-coarse-bbox-only*.json')
    img_root = ub.truepath('~/data/noaa/training_data/imagery/')

    fpaths = sorted(glob.glob(annot_globstr))
    # Remove keypoints annotation data (hack)
    fpaths = [p for p in fpaths if not ('nwfsc' in p or 'afsc' in p)]

    cacher = ub.Cacher('coco_dsets',
                       cfgstr=ub.hash_data(fpaths),
                       appname='viame')
    coco_dsets = cacher.tryload()
    if coco_dsets is None:
        print('Reading raw mscoco files')
        import os
        dsets = []
        for fpath in sorted(fpaths):
            print('reading fpath = {!r}'.format(fpath))
            dset = coco_api.CocoDataset(fpath, tag='', img_root=img_root)
            try:
                assert not dset.missing_images()
            except AssertionError:
                print('fixing image file names')
                hack = os.path.basename(fpath).split('-')[0].split('.')[0]
                dset = coco_api.CocoDataset(fpath,
                                            tag=hack,
                                            img_root=join(img_root, hack))
                assert not dset.missing_images(), ub.repr2(
                    dset.missing_images()) + 'MISSING'
            print(ub.repr2(dset.basic_stats()))
            dsets.append(dset)

        print('Merging')
        merged = coco_api.CocoDataset.union(*dsets)
        merged.img_root = img_root

        # HACK: won't need to do this for the released challenge data
        # probably won't hurt though
        # if not REAL_RUN:
        #     merged._remove_keypoint_annotations()
        #     merged._run_fixes()

        train_dset, vali_dset = wrangle.make_train_vali(merged)

        coco_dsets = {
            'train': train_dset,
            'vali': vali_dset,
        }

        cacher.save(coco_dsets)

    return coco_dsets
Code Example #7
File: test_cache.py Project: hack121/ubelt
def test_cache_hit():
    cacher = ub.Cacher('name', 'params', verbose=2)
    cacher.clear()
    assert not cacher.exists()
    cacher.save(['some', 'data'])
    assert cacher.exists()
    data = cacher.load()
    assert data == ['some', 'data']
Code Example #8
File: test_cache.py Project: Kulbear/ubelt
def test_cache_depends():
    """
    What no errors happen when an external processes removes meta
    """
    cacher = ub.Cacher('name',
                       depends=['a', 'b', 'c'],
                       verbose=10,
                       enabled=False)
    cfgstr = cacher._rectify_cfgstr()
    assert cfgstr.startswith('8a82eef87cb905220841f95')
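
The `depends` argument shown here is the newer alternative to building a cfgstr by hand: the list is hashed internally. A hedged sketch of how it pairs with `ensure` (the `compute` function and 'demo_depends' name are hypothetical):

import ubelt as ub

def compute():
    # placeholder for the expensive computation being cached
    return {'answer': 42}

cacher = ub.Cacher('demo_depends', depends=['a', 'b', 'c'])
data = cacher.ensure(compute)   # loads if cached, otherwise computes and saves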
Code Example #9
def test_clear_quiet():
    """
    What no errors happen when an external processes removes meta
    """
    def func():
        return 'expensive result'
    cacher = ub.Cacher('name', 'params', verbose=0)
    cacher.clear()
    cacher.clear()
    cacher.ensure(func)
    cacher.clear()
Code Example #10
def _setup_corrupt_cacher(verbose=0):
    def func():
        return ['expensive result']
    cacher = ub.Cacher('name', 'params', verbose=verbose)
    cacher.clear()
    cacher.ensure(func)
    # Write junk data that will cause a non-io error
    with open(cacher.get_fpath(), 'wb') as file:
        file.write(':junkdata:'.encode('utf8'))
    with pytest.raises(Exception):
        assert cacher.tryload(on_error='raise')
    assert exists(cacher.get_fpath())
    return cacher
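
The `on_error='raise'` flag above re-raises non-IO errors from a corrupt cache. Assuming the installed ubelt version also supports `on_error='clear'` (its documented alternative), the corrupt file can instead be discarded and treated as a miss; a sketch reusing the helper above:

def func():
    return ['expensive result']

cacher = _setup_corrupt_cacher()
# Assumed behavior of on_error='clear': the corrupt file is deleted and
# treated as a cache miss instead of raising.
data = cacher.tryload(on_error='clear')
if data is None:
    data = func()
    cacher.save(data)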
Code Example #11
def stitched_predictions(dataset, arches, xpu, arch_to_train_dpath, workdir,
                         _epochs, tag):

    dataset.inputs.input_id
    print('dataset.inputs.input_id = {!r}'.format(dataset.inputs.input_id))

    # Predict probabilities for each model in the ensemble
    arch_to_paths = {}
    for arch in arches:
        train_dpath = arch_to_train_dpath[arch]
        epoch = _epochs[arch]
        load_path = fit_harn2.get_snapshot(train_dpath, epoch=epoch)

        pharn = UrbanPredictHarness(dataset, xpu)
        dataset.center_inputs = pharn.load_normalize_center(train_dpath)

        pharn.test_dump_dpath = ub.ensuredir(
            (workdir, tag, dataset.inputs.input_id, arch,
             'epoch{}'.format(epoch)))

        stitched_dpath = join(pharn.test_dump_dpath, 'stitched')

        cfgstr = util.hash_data([
            # depend on both the inputs and the exact model specs
            dataset.inputs.input_id,
            util.hash_file(load_path)
        ])

        # predict the whole scene
        cacher = ub.Cacher('prediction_stamp',
                           cfgstr=cfgstr,
                           dpath=stitched_dpath)
        if cacher.tryload() is None:
            # Only execute this if we haven't done so
            pharn.load_snapshot(load_path)
            pharn.run()
            cacher.save(True)

        paths = {
            'probs': glob.glob(join(stitched_dpath, 'probs', '*.h5')),
            'probs1': glob.glob(join(stitched_dpath, 'probs1', '*.h5')),
        }
        arch_to_paths[arch] = paths
    return arch_to_paths
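
Here the Cacher only stores `True` as a completion marker for outputs written elsewhere on disk. ubelt provides CacheStamp for exactly that stamp-style use (it also appears in a later example); a rough equivalent sketch with placeholder paths:

import ubelt as ub

stitched_dpath = ub.ensuredir('stitched_demo')   # hypothetical output directory
stamp = ub.CacheStamp('prediction_stamp', dpath=stitched_dpath,
                      cfgstr='model-and-input-hash')
if stamp.expired():
    # run the expensive prediction step that writes into stitched_dpath, then:
    stamp.renew()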
Code Example #12
File: util_tensorboard.py Project: Cookt2/netharn
def read_tensorboard_scalars(train_dpath, verbose=1, cache=1):
    """
    Reads all tensorboard scalar events in a directory.
    Caches them because reading events of interest from protobuf can be slow.
    """
    import glob
    from os.path import join
    try:
        from tensorboard.backend.event_processing import event_accumulator
    except ImportError:
        raise ImportError('tensorboard is not installed')
    event_paths = sorted(glob.glob(join(train_dpath, 'events.out.tfevents*')))
    # hash the event files so the cache is re-read when they change
    cfgstr = ub.hash_data(list(map(ub.hash_file, event_paths)))
    # cfgstr = ub.hash_data(list(map(basename, event_paths)))
    cacher = ub.Cacher('tb_scalars',
                       enabled=cache,
                       dpath=ub.ensuredir((train_dpath, '_cache')),
                       cfgstr=cfgstr)
    datas = cacher.tryload()
    if datas is None:
        datas = {}
        for p in ub.ProgIter(list(reversed(event_paths)),
                             desc='read tensorboard',
                             enabled=verbose):
            ea = event_accumulator.EventAccumulator(p)
            ea.Reload()
            for key in ea.scalars.Keys():
                if key not in datas:
                    datas[key] = {'xdata': [], 'ydata': [], 'wall': []}
                subdatas = datas[key]
                events = ea.scalars.Items(key)
                for e in events:
                    subdatas['xdata'].append(int(e.step))
                    subdatas['ydata'].append(float(e.value))
                    subdatas['wall'].append(float(e.wall_time))

        # Order all information by its wall time
        for key, subdatas in datas.items():
            sortx = ub.argsort(subdatas['wall'])
            for d, vals in subdatas.items():
                subdatas[d] = list(ub.take(vals, sortx))
        cacher.save(datas)
    return datas
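
A hedged usage sketch of the helper above (the training directory path is a placeholder):

# Read (and cache) every scalar series logged under a training directory
scalars = read_tensorboard_scalars('/path/to/train_dpath', verbose=1, cache=1)
for key, subdata in scalars.items():
    print(key, len(subdata['xdata']), 'events')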
Code Example #13
    def _cacher(self, fname, extra_deps=None, disable=False, verbose=None):
        """
        Create a cacher for a known lazy computation using a common hashid.

        If `self.workdir` or `self.hashid` is None, then caches are disabled by
        default. Caches can be explicitly disabled by setting the appropriate
        value in the `self._enabled_caches` dictionary.

        Args:
            fname (str): name of the property we are caching
            extra_deps (OrderedDict): extra data to contribute to the hashid
            disable (bool): explicitly disable cache if True, otherwise do
                normal checks to see if enabled.
            verbose (bool, default=None): if specified overrides `self.verbose`.

        Returns:
            ub.Cacher: cacher - if enabled this cacher will minimally depend
                on the `self.hashid`, but may also depend on extra info.
        """
        if verbose is None:
            verbose = self.verbose

        if not disable and self.hashid and self.workdir:
            enabled = self._enabled_caches.get(fname, True)
            dpath = ub.ensuredir((self.workdir, '_cache', fname))
        else:
            dpath = None
            enabled = False  # forced disable

        cfgstr = None

        if enabled:
            if extra_deps is None:
                extra_deps = ub.odict()
            elif not isinstance(extra_deps, ub.odict):
                raise TypeError('Extra dependencies must be an OrderedDict')
            # always include `self.hashid`
            extra_deps['self_hashid'] = self.hashid
            cfgstr = ub.hash_data(extra_deps)

        cacher = ub.Cacher(fname, cfgstr=cfgstr, dpath=dpath,
                           verbose=self.verbose, enabled=enabled)
        return cacher
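
A hypothetical caller of a helper like `_cacher` (the method and attribute names here are illustrative, not taken from the original class):

    def _cached_statistics(self):
        # Sketch: cache a derived property keyed on self.hashid plus a version tag
        cacher = self._cacher('statistics', extra_deps=ub.odict(version='v1'))
        stats = cacher.tryload()
        if stats is None:
            stats = self._compute_statistics()   # hypothetical expensive step
            cacher.save(stats)
        return stats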
Code Example #14
def svd_orthonormal(shape, rng=None, cache_key=None):
    """
    If cache_key is specified, then the result will be cached, and subsequent
    calls with the same key and shape will return the same result.

    References:
        Orthonormal init code is taken from Lasagne
        https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
    """
    rng = util.ensure_rng(rng)

    if len(shape) < 2:
        raise RuntimeError("Only shapes of length 2 or more are supported.")
    flat_shape = (shape[0], np.prod(shape[1:]))

    # NOTE: the leading "False and" hard-disables caching; remove it to re-enable
    enabled = False and cache_key is not None
    if enabled:
        rand_sequence = rng.randint(0, 2**16)
        depends = [shape, cache_key, rand_sequence]
        cfgstr = ub.hash_data(depends)
    else:
        cfgstr = ''

    # this process can be expensive, cache it

    # TODO: only cache very large matrices (4096x4096)
    # TODO: only cache very large matrices, not (256,256,3,3)
    cacher = ub.Cacher('svd_orthonormal',
                       appname='netharn',
                       enabled=enabled,
                       cfgstr=cfgstr)
    q = cacher.tryload()
    if q is None:
        # print('Compute orthonormal matrix with shape ' + str(shape))
        a = rng.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v
        # print(shape, flat_shape)
        q = q.reshape(shape)
        q = q.astype(np.float32)
        cacher.save(q)
    return q
Code Example #15
    def _load_sized_image(self, index, inp_size):
        # load the raw data from VOC

        cacher = ub.Cacher('voc_img', cfgstr=ub.repr2([index, inp_size]),
                           appname='clab')
        data = cacher.tryload()
        if data is None:
            image = self._load_image(index)
            orig_size = np.array(image.shape[0:2][::-1])
            factor = inp_size / orig_size
            # squish the image into network input coordinates
            interpolation = (cv2.INTER_AREA if factor.sum() <= 2 else
                             cv2.INTER_CUBIC)
            hwc255 = cv2.resize(image, tuple(inp_size),
                                interpolation=interpolation)
            data = hwc255, orig_size, factor
            cacher.save(data)

        hwc255, orig_size, factor = data
        return hwc255, orig_size, factor
Code Example #16
File: cacheUtil_test.py Project: duolabmeng6/pyefun
    def test_1(self):
        cacher = ub.Cacher(
            fname="mycache",
            depends=ub.hash_data('test'),
            dpath="./cache/",
            appname="my",
            ext=".pkl",
            verbose=3,
            enabled=True,
            hasher='sha1',
            protocol=3,
        )
        print(cacher.get_fpath())
        data = cacher.tryload()
        if data is None:
            myvar1 = 'result of expensive process'
            myvar2 = 'another result'
            data = myvar1, myvar2
            cacher.save(data)
        myvar1, myvar2 = data


        from ubelt.util_cache import Cacher
        from os.path import basename
        # Ensure that some data exists
        known_fpaths = set()
        # cacher = Cacher('versioned_data_v2', depends='1')
        # cacher.ensure(lambda: 'data1')
        # known_fpaths.add(cacher.get_fpath())
        # cacher = Cacher('versioned_data_v2', depends='2')
        # cacher.ensure(lambda: 'data2')
        # known_fpaths.add(cacher.get_fpath())
        # # List previously computed configs for this type
        # cacher = ub.Cacher('versioned_data_v2', depends='2')
        exist_fpaths = set(cacher.existing_versions())
        print('exist_fpaths = {!r}'.format(exist_fpaths))

        exist_fnames = list(map(basename, exist_fpaths))
        print('exist_fnames = {!r}'.format(exist_fnames))
Code Example #17
def test_disable():
    """
    What no errors happen when an external processes removes meta
    """
    nonlocal_var = [0]

    def func():
        nonlocal_var[0] += 1
        return ['expensive result']
    cacher = ub.Cacher('name', 'params', verbose=10, enabled=False)

    assert nonlocal_var[0] == 0
    cacher.ensure(func)
    assert nonlocal_var[0] == 1
    cacher.ensure(func)
    assert nonlocal_var[0] == 2
    cacher.ensure(func)

    with pytest.raises(IOError):
        cacher.load()

    assert cacher.tryload(func) is None
Code Example #18
def test_noexist_meta_clear():
    """
    What no errors happen when an external processes removes meta
    """
    def func():
        return 'expensive result'
    cacher = ub.Cacher('name', 'params', verbose=10)
    cacher.clear()

    cacher.ensure(func)

    data_fpath = cacher.get_fpath()
    meta_fpath = data_fpath + '.meta'
    assert exists(data_fpath)
    assert exists(meta_fpath)

    ub.delete(meta_fpath)
    assert not exists(meta_fpath)
    cacher.clear()

    assert not exists(meta_fpath)
    assert not exists(data_fpath)
Code Example #19
    def __init__(self,
                 sampler,
                 coco_dset,
                 workdir=None,
                 augment=False,
                 dim=416):
        print('make MatchingCocoDataset')

        self.sampler = sampler
        cacher = ub.Cacher('pccs', cfgstr=coco_dset.tag, verbose=True)
        pccs = cacher.tryload()
        if pccs is None:
            import graphid
            graph = graphid.api.GraphID()
            graph.add_annots_from(coco_dset.annots().aids)
            infr = graph.infr
            infr.params['inference.enabled'] = False
            all_aids = list(coco_dset.annots().aids)
            aids_set = set(all_aids)
            for aid1 in ub.ProgIter(all_aids, desc='construct graph'):
                annot = coco_dset.anns[aid1]
                for review in annot['review_ids']:
                    aid2, decision = review
                    if aid2 not in aids_set:
                        # hack because the data is set up wrong
                        continue
                    edge = (aid1, aid2)
                    if decision == 'positive':
                        infr.add_feedback(edge,
                                          evidence_decision=graphid.core.POSTV)
                    elif decision == 'negative':
                        infr.add_feedback(edge,
                                          evidence_decision=graphid.core.NEGTV)
                    elif decision == 'incomparable':
                        infr.add_feedback(edge,
                                          evidence_decision=graphid.core.INCMP)
                    else:
                        raise KeyError(decision)
            infr.params['inference.enabled'] = True
            infr.apply_nondynamic_update()
            print('status = ' + ub.repr2(infr.status(True)))
            pccs = list(map(frozenset, infr.positive_components()))
            for pcc in pccs:
                for aid in pcc:
                    print('aid = {!r}'.format(aid))
                    assert aid in coco_dset.anns
            cacher.save(pccs)

        print('target index')
        self.aid_to_tx = {
            aid: tx
            for tx, aid in enumerate(sampler.regions.targets['aid'])
        }

        self.coco_dset = coco_dset
        print('Find Samples')
        self.samples = sample_labeled_pairs(pccs,
                                            max_num=1e5,
                                            pos_neg_ratio=1.0)

        self.samples = nh.util.shuffle(self.samples, rng=0)
        print('Finished sampling')
        self.dim = dim

        self.rng = nh.util.ensure_rng(0)
        if augment:
            import imgaug.augmenters as iaa
            # NOTE: we are only using `self.augmenter` to make a hyper hashid
            # in __getitem__ we invoke transform explicitly for fine control
            # self.hue = nh.data.transforms.HSVShift(hue=0.1, sat=1.5, val=1.5)
            self.crop = iaa.Crop(percent=(0, .2))
            self.flip = iaa.Fliplr(p=.5)
            self.augmenter = iaa.Sequential([
                # self.hue,
                self.crop,
                self.flip
            ])
        else:
            self.augmenter = None
        self.letterbox = nh.data.transforms.Resize(target_size=(dim, dim),
                                                   mode='letterbox')
Code Example #20
    def __init__(self, fname, dpath, cfgstr=None, product=None, robust=True):
        self.cacher = ub.Cacher(fname, cfgstr=cfgstr, dpath=dpath)
        self.product = product
        self.robust = robust
Code Example #21
def setup_harn(cmdline=True, **kwargs):
    """
    cmdline, kwargs = False, {}
    """
    import sys
    import ndsampler

    config = ImageClfConfig(default=kwargs)
    config.load(cmdline=cmdline)
    nh.configure_hacks(config)  # fix opencv bugs

    cacher = ub.Cacher('tiny-imagenet', cfgstr='v4', verbose=3)
    data = cacher.tryload()
    if data is None:
        data = grab_tiny_imagenet_as_coco()
        cacher.save(data)
    coco_datasets = data  # setup_coco_datasets()
    dset = coco_datasets['train']
    print('train dset = {!r}'.format(dset))

    workdir = ub.ensuredir(ub.expandpath(config['workdir']))
    samplers = {
        # tag: ndsampler.CocoSampler(dset, workdir=workdir, backend='cog')
        tag: ndsampler.CocoSampler(dset, workdir=workdir, backend='npy')
        for tag, dset in coco_datasets.items()
    }
    torch_datasets = {
        tag: ImagClfDataset(
            sampler, config['input_dims'],
            augmenter=((tag == 'train') and config['augmenter']),
        )
        for tag, sampler in samplers.items()
    }
    torch_loaders = {
        tag: torch_data.DataLoader(dset,
                                   batch_size=config['batch_size'],
                                   num_workers=config['workers'],
                                   shuffle=(tag == 'train'),
                                   pin_memory=True)
        for tag, dset in torch_datasets.items()
    }

    import torchvision
    # TODO: netharn should allow for this
    model_ = torchvision.models.resnet50(pretrained=False)

    # model_ = (, {
    #     'classes': torch_datasets['train'].classes,
    #     'in_channels': 3,
    # })
    initializer_ = nh.Initializer.coerce(config)

    hyper = nh.HyperParams(
        nice=config['nice'],
        workdir=config['workdir'],
        xpu=nh.XPU.coerce(config['xpu']),

        datasets=torch_datasets,
        loaders=torch_loaders,

        model=model_,
        initializer=initializer_,

        scheduler=nh.Scheduler.coerce(config),
        optimizer=nh.Optimizer.coerce(config),
        dynamics=nh.Dynamics.coerce(config),

        criterion=(nh.criterions.FocalLoss, {
            'focus': 0.0,
        }),

        monitor=(nh.Monitor, {
            'minimize': ['loss'],
            'patience': config['patience'],
            'max_epoch': config['max_epoch'],
            'smoothing': .6,
        }),

        other={
            'batch_size': config['batch_size'],
        },
        extra={
            'argv': sys.argv,
            'config': ub.repr2(config.asdict()),
        }
    )

    # Create harness
    harn = ImageClfHarn(hyper=hyper)
    harn.classes = torch_datasets['train'].classes
    harn.preferences.update({
        'num_keep': 5,
        'keyboard_debug': True,
        # 'export_modules': ['netharn'],
    })
    harn.intervals.update({
        'vali': 1,
        'test': 10,
    })
    harn.script_config = config
    return harn
Code Example #22
File: classification.py Project: Kitware/netharn
def setup_harn(cmdline=True, **kw):
    """
    This creates the "The Classification Harness" (i.e. core ClfHarn object).
    This is where we programmatically connect our program arguments with the
    netharn HyperParameter standards. We are using :mod:`scriptconfig` to
    capture these, but you could use click / argparse / etc.

    This function has the responsibility of creating our torch datasets,
    lazy computing input statistics, specifying our model architecture,
    schedule, initialization, optimizer, dynamics, XPU etc. These can usually
    be coerced using netharn API helpers and a "standardized" config dict. See
    the function code for details.

    Args:
        cmdline (bool, default=True):
            if True, behavior will be modified based on ``sys.argv``.
            Note this will activate the scriptconfig ``--help``, ``--dump`` and
            ``--config`` interactions.

    Kwargs:
        **kw: overrides to the default config for :class:`ClfConfig`.
            Note, command line flags have precedence if cmdline=True.

    Returns:
        ClfHarn: a fully-defined, but uninitialized custom :class:`FitHarn`
            object.

    Example:
        >>> # xdoctest: +SKIP
        >>> kw = {'datasets': 'special:shapes256'}
        >>> cmdline = False
        >>> harn = setup_harn(cmdline, **kw)
        >>> harn.initialize()
    """
    import ndsampler
    config = ClfConfig(default=kw)
    config.load(cmdline=cmdline)
    print('config = {}'.format(ub.repr2(config.asdict())))

    nh.configure_hacks(config)
    coco_datasets = nh.api.Datasets.coerce(config)

    print('coco_datasets = {}'.format(ub.repr2(coco_datasets, nl=1)))
    for tag, dset in coco_datasets.items():
        dset._build_hashid(hash_pixels=False)

    workdir = ub.ensuredir(ub.expandpath(config['workdir']))
    samplers = {
        tag: ndsampler.CocoSampler(dset,
                                   workdir=workdir,
                                   backend=config['sampler_backend'])
        for tag, dset in coco_datasets.items()
    }

    for tag, sampler in ub.ProgIter(list(samplers.items()),
                                    desc='prepare frames'):
        sampler.frames.prepare(workers=config['workers'])

    torch_datasets = {
        'train':
        ClfDataset(
            samplers['train'],
            input_dims=config['input_dims'],
            augmenter=config['augmenter'],
        ),
        'vali':
        ClfDataset(samplers['vali'],
                   input_dims=config['input_dims'],
                   augmenter=False),
    }

    if config['normalize_inputs']:
        # Get stats on the dataset (todo: turn off augmentation for this)
        _dset = torch_datasets['train']
        stats_idxs = kwarray.shuffle(np.arange(len(_dset)),
                                     rng=0)[0:min(1000, len(_dset))]
        stats_subset = torch.utils.data.Subset(_dset, stats_idxs)

        cacher = ub.Cacher('dset_mean', cfgstr=_dset.input_id + 'v3')
        input_stats = cacher.tryload()

        channels = ChannelSpec.coerce(config['channels'])

        if input_stats is None:
            # Use parallel workers to load data faster
            from netharn.data.data_containers import container_collate
            from functools import partial
            collate_fn = partial(container_collate, num_devices=1)

            loader = torch.utils.data.DataLoader(
                stats_subset,
                collate_fn=collate_fn,
                num_workers=config['workers'],
                shuffle=True,
                batch_size=config['batch_size'])

            # Track moving average of each fused channel stream
            channel_stats = {
                key: nh.util.RunningStats()
                for key in channels.keys()
            }
            assert len(channel_stats) == 1, (
                'only support one fused stream for now')
            for batch in ub.ProgIter(loader, desc='estimate mean/std'):
                for key, val in batch['inputs'].items():
                    try:
                        for part in val.numpy():
                            channel_stats[key].update(part)
                    except ValueError:  # final batch broadcast error
                        pass

            perchan_input_stats = {}
            for key, running in channel_stats.items():
                running = ub.peek(channel_stats.values())
                perchan_stats = running.simple(axis=(1, 2))
                perchan_input_stats[key] = {
                    'mean': perchan_stats['mean'].round(3),
                    'std': perchan_stats['std'].round(3),
                }

            input_stats = ub.peek(perchan_input_stats.values())
            cacher.save(input_stats)
    else:
        input_stats = {}

    torch_loaders = {
        tag: dset.make_loader(
            batch_size=config['batch_size'],
            num_batches=config['num_batches'],
            num_workers=config['workers'],
            shuffle=(tag == 'train'),
            balance=(config['balance'] if tag == 'train' else None),
            pin_memory=True)
        for tag, dset in torch_datasets.items()
    }

    initializer_ = None
    classes = torch_datasets['train'].classes

    modelkw = {
        'arch': config['arch'],
        'input_stats': input_stats,
        'classes': classes.__json__(),
        'channels': channels,
    }
    model = ClfModel(**modelkw)
    model._initkw = modelkw

    if initializer_ is None:
        initializer_ = nh.Initializer.coerce(config)

    hyper = nh.HyperParams(name=config['name'],
                           workdir=config['workdir'],
                           xpu=nh.XPU.coerce(config['xpu']),
                           datasets=torch_datasets,
                           loaders=torch_loaders,
                           model=model,
                           criterion=None,
                           optimizer=nh.Optimizer.coerce(config),
                           dynamics=nh.Dynamics.coerce(config),
                           scheduler=nh.Scheduler.coerce(config),
                           initializer=initializer_,
                           monitor=(nh.Monitor, {
                               'minimize': ['loss'],
                               'patience': config['patience'],
                               'max_epoch': config['max_epoch'],
                               'smoothing': 0.0,
                           }),
                           other={
                               'name': config['name'],
                               'batch_size': config['batch_size'],
                               'balance': config['balance'],
                           },
                           extra={
                               'argv': sys.argv,
                               'config': ub.repr2(config.asdict()),
                           })
    harn = ClfHarn(hyper=hyper)
    harn.preferences.update({
        'num_keep': 3,
        'keep_freq': 10,
        'tensorboard_groups': ['loss'],
        'eager_dump_tensorboard': True,
    })
    harn.intervals.update({})
    harn.script_config = config
    return harn
Code Example #23
File: grab_camvid.py Project: Kitware/kwcoco
def grab_coco_camvid():
    """
    Example:
        >>> # xdoctest: +REQUIRES(--download)
        >>> dset = grab_coco_camvid()
        >>> print('dset = {!r}'.format(dset))
        >>> # xdoctest: +REQUIRES(--show)
        >>> import kwplot
        >>> plt = kwplot.autoplt()
        >>> plt.clf()
        >>> dset.show_image(gid=1)

    Ignore:
        import xdev
        gid_list = list(dset.imgs)
        for gid in xdev.InteractiveIter(gid_list):
            dset.show_image(gid)
            xdev.InteractiveIter.draw()
    """
    import kwcoco
    cache_dpath = ub.ensure_app_cache_dir('kwcoco', 'camvid')
    coco_fpath = join(cache_dpath, 'camvid.mscoco.json')

    # Need to manually bump this if you make a change to loading
    SCRIPT_VERSION = 'v4'

    # Ubelt's stamp-based caches are super cheap and let you take control of
    # the data format.
    stamp = ub.CacheStamp('camvid_coco',
                          cfgstr=SCRIPT_VERSION,
                          dpath=cache_dpath,
                          product=coco_fpath,
                          hasher='sha1',
                          verbose=3)
    if stamp.expired():
        camvid_raw_info = grab_raw_camvid()
        dset = convert_camvid_raw_to_coco(camvid_raw_info)
        with ub.Timer('dumping MS-COCO dset to: {}'.format(coco_fpath)):
            dset.dump(coco_fpath)
        # Mark this process as completed by saving a small file containing the
        # hash of the "product" you are stamping.
        stamp.renew()

    # We can also cache the index build step independently. This uses
    # ubelt.Cacher, which is pickle based, and writes the actual object to
    # disk. Each type of caching has its own uses and tradeoffs.
    cacher = ub.Cacher('prebuilt-coco',
                       cfgstr=SCRIPT_VERSION,
                       dpath=cache_dpath,
                       verbose=3)
    dset = cacher.tryload()
    if dset is None:
        print('Reading coco_fpath = {!r}'.format(coco_fpath))
        dset = kwcoco.CocoDataset(coco_fpath, tag='camvid')
        # Directly save the file to disk.
        dset._build_index()
        dset._build_hashid()
        cacher.save(dset)

    camvid_dset = dset
    print('Loaded camvid_dset = {!r}'.format(camvid_dset))
    return camvid_dset
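
The two mechanisms in this example are complementary: CacheStamp records that a product file you wrote yourself is up to date, while Cacher pickles the Python object for you. A condensed, hedged side-by-side (paths and values are placeholders):

import json
import ubelt as ub
from os.path import join

dpath = ub.ensure_app_cache_dir('demo_cache')
product = join(dpath, 'result.json')

# CacheStamp: you control the on-disk format; the stamp only marks completion.
stamp = ub.CacheStamp('demo_stamp', cfgstr='v1', dpath=dpath, product=product)
if stamp.expired():
    with open(product, 'w') as f:
        json.dump({'value': 42}, f)
    stamp.renew()

# Cacher: the object itself is pickled to disk for you.
cacher = ub.Cacher('demo_obj', cfgstr='v1', dpath=dpath)
obj = cacher.tryload()
if obj is None:
    obj = {'value': 42}
    cacher.save(obj)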
Code Example #24
File: pvpoke_driver.py Project: Erotemic/pypogo
def run_pvpoke_simulation(mons, league='auto'):
    """
    Args:
        mons (List[pypogo.Pokemon]): pokemon to simulate.
            Must have IVS, movesets, level, etc... fields populated.
    """
    from selenium import webdriver
    from selenium.webdriver.common.keys import Keys
    # from selenium.webdriver.support.ui import Select
    import pandas as pd
    # import pypogo

    if league == 'auto':
        for mon in mons:
            if mon.cp <= 1500:
                league = 'great'
            elif mon.cp <= 2500:
                league = 'ultra'
            elif mon.level <= 41:
                league = 'master-classic'
            elif mon.level <= 51:
                league = 'master'
            else:
                raise AssertionError
            break
    # for mon in mons:
    #     mon.populate_all
    mon_cachers = {}
    have_results = {}
    to_check_mons = []
    for mon in mons:
        mon._slug = mon.slug()
        mon_cachers[mon._slug] = cacher = ub.Cacher(
            'pvpoke_sim', depends=[mon._slug, league], appname='pypogo')
        mon_results = cacher.tryload()
        if mon_results is None:
            to_check_mons.append(mon)
        else:
            have_results[mon._slug] = mon_results

    if to_check_mons:
        # Requires the driver be in the PATH
        ensure_selenium_chromedriver()

        url = 'https://pvpoke.com/battle/matrix/'
        driver = webdriver.Chrome()
        driver.get(url)
        time.sleep(2.0)

        if league == 'great':
            league_box_target = 'Great League (CP 1500)'
            meta_text = 'Great League Meta'
        elif league == 'ultra':
            league_box_target = 'Ultra League (Level 50)'
            meta_text = 'Ultra League Meta'
            # meta_text = 'Premier Cup Meta'
            # meta_text = 'Remix Cup Meta'
            # meta_text = 'Premier Classic Cup Meta'
        elif league == 'master-classic':
            league_box_target = 'Master League (Level 40)'
            meta_text = 'Master League Meta'
        elif league == 'master':
            league_box_target = 'Master League (Level 50)'
            meta_text = 'Master League Meta'
        else:
            raise NotImplementedError

        league_select = driver.find_elements_by_class_name('league-select')[0]
        league_select.click()
        league_select.send_keys(league_box_target)
        league_select.click()
        league_select.send_keys(Keys.ENTER)

        # league_select.text.split('\n')
        # league_select.send_keys('\n')
        # league_select.send_keys('\n')

        def add_pokemon(mon):
            add_poke1_button = driver.find_elements_by_class_name(
                'add-poke-btn')[0]
            add_poke1_button.click()

            select_drop = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/select')

            if 1:
                import xdev
                all_names = select_drop.text.split('\n')
                distances = xdev.edit_distance(mon.display_name(), all_names)
                chosen_name = all_names[ub.argmin(distances)]
            else:
                chosen_name = mon.name

            search_box = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/input')
            search_box.send_keys(chosen_name)

            advanced_ivs_arrow = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/a/span[1]')
            advanced_ivs_arrow.click()

            level40_cap = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[2]/div[2]/div[2]'
            )
            level41_cap = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[2]/div[2]/div[3]'
            )
            level50_cap = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[2]/div[2]/div[4]'
            )
            level51_cap = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[2]/div[2]/div[5]'
            )

            if mon.level >= 51:
                level51_cap.click()
            elif mon.level >= 50:
                level50_cap.click()
            elif mon.level >= 41:
                level41_cap.click()
            elif mon.level >= 40:
                level40_cap.click()

            level_box = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[1]/input'
            )
            level_box.click()
            level_box.clear()
            level_box.clear()
            level_box.send_keys(str(mon.level))

            iv_a = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[1]/div/input[1]'
            )
            iv_d = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[1]/div/input[2]'
            )
            iv_s = driver.find_element_by_xpath(
                '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[9]/div/div[1]/div/input[3]'
            )

            # TODO
            # driver.find_elements_by_class_name('move-select')

            iv_a.clear()
            iv_a.send_keys(str(mon.ivs[0]))

            iv_d.clear()
            iv_d.send_keys(str(mon.ivs[1]))

            iv_s.clear()
            iv_s.send_keys(str(mon.ivs[2]))

            # USE_MOVES = 1
            if mon.moves is not None:
                # mon.populate_all()

                fast_select = driver.find_element_by_xpath(
                    '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[10]/select[1]'
                )
                fast_select.click()
                fast_select.send_keys(mon.pvp_fast_move['name'])
                fast_select.send_keys(Keys.ENTER)

                charge1_select = driver.find_element_by_xpath(
                    '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[10]/select[2]'
                )
                charge1_select.click()
                charge1_select.send_keys(mon.pvp_charge_moves[0]['name'])
                charge1_select.send_keys(Keys.ENTER)

                charge2_select = driver.find_element_by_xpath(
                    '/html/body/div[5]/div/div[3]/div[1]/div[2]/div[10]/select[3]'
                )
                charge2_select.click()
                charge2_select.send_keys(mon.pvp_charge_moves[1]['name'])
                charge2_select.send_keys(Keys.ENTER)

            save_button = driver.find_elements_by_class_name('save-poke')[0]
            save_button.click()

        quickfills = driver.find_elements_by_class_name('quick-fill-select')
        quickfill = quickfills[1]
        quickfill.text.split('\n')
        quickfill.click()
        quickfill.send_keys(meta_text)
        quickfill.click()

        for mon in to_check_mons:
            add_pokemon(mon)

        shield_num_to_text = {
            0: 'No shields',
            1: '1 shield',
            2: '2 shields',
        }

        shield_case_to_data = {}

        for atk_num_shields, def_num_shields in it.product(
                shield_num_to_text, shield_num_to_text):
            shield_selectors = driver.find_elements_by_class_name(
                'shield-select')
            shield_selectors[2].click()
            shield_selectors[2].send_keys(shield_num_to_text[atk_num_shields])
            shield_selectors[2].send_keys(Keys.ENTER)

            shield_selectors[3].click()
            shield_selectors[3].send_keys(shield_num_to_text[def_num_shields])
            shield_selectors[3].send_keys(Keys.ENTER)

            #shield_selectors[0].click()

            battle_btn = driver.find_elements_by_class_name('battle-btn')[0]
            battle_btn.click()

            # Clear previous downloaded files
            dlfolder = pathlib.Path(ub.expandpath('$HOME/Downloads'))
            for old_fpath in list(dlfolder.glob('_vs*.csv')):
                old_fpath.unlink()

            time.sleep(2.0)

            # Download new data
            dl_btn = driver.find_element_by_xpath(
                '//*[@id="main"]/div[4]/div[9]/div/a')
            dl_btn.click()

            while len(list(dlfolder.glob('_vs*.csv'))) < 1:
                pass

            new_fpaths = list(dlfolder.glob('_vs*.csv'))
            assert len(new_fpaths) == 1
            fpath = new_fpaths[0]

            data = pd.read_csv(fpath, header=0, index_col=0)
            shield_case_to_data[(atk_num_shields, def_num_shields)] = data

        for idx, mon in enumerate(to_check_mons):
            mon_results = {
                ss: scores.iloc[idx]
                for ss, scores in shield_case_to_data.items()
            }
            cacher = mon_cachers[mon._slug]
            cacher.save(mon_results)
            have_results[mon._slug] = mon_results

    _tojoin = ub.ddict(list)
    _joined = ub.ddict(list)
    for mon_results in have_results.values():
        for ss, scores in mon_results.items():
            _tojoin[ss].append(scores)

    for ss, vals in _tojoin.items():
        _joined[ss] = pd.concat([v.to_frame().T for v in vals])
    _joined.default_factory = None
    results = _joined
    return results
Code Example #25
def main(bib_fpath=None):
    r"""
    Entry point for the fixbib script.

    CommandLine:
        fixbib
        python -m fixtex bib
        python -m fixtex bib --dryrun
        python -m fixtex bib --dryrun --debug
    """

    if bib_fpath is None:
        bib_fpath = 'My Library.bib'

    # DEBUG = ub.argflag('--debug')
    # Read in text and ensure ascii format
    dirty_text = ut.readfrom(bib_fpath)

    from fixtex.fix_tex import find_used_citations, testdata_fpaths

    if exists('custom_extra.bib'):
        extra_parser = bparser.BibTexParser(ignore_nonstandard_types=False)
        parser = bparser.BibTexParser()
        ut.delete_keys(parser.alt_dict, ['url', 'urls'])
        print('Parsing extra bibtex file')
        extra_text = ut.readfrom('custom_extra.bib')
        extra_database = extra_parser.parse(extra_text, partial=False)
        print('Finished parsing extra')
        extra_dict = extra_database.get_entry_dict()
    else:
        extra_dict = None

    #udata = dirty_text.decode("utf-8")
    #dirty_text = udata.encode("ascii", "ignore")
    #dirty_text = udata

    # parser = bparser.BibTexParser()
    # bib_database = parser.parse(dirty_text)
    # d = bib_database.get_entry_dict()

    print('BIBTEXPARSER LOAD')
    parser = bparser.BibTexParser(ignore_nonstandard_types=False,
                                  common_strings=True)
    ut.delete_keys(parser.alt_dict, ['url', 'urls'])
    print('Parsing bibtex file')
    bib_database = parser.parse(dirty_text, partial=False)
    print('Finished parsing')

    bibtex_dict = bib_database.get_entry_dict()
    old_keys = list(bibtex_dict.keys())
    new_keys = []
    for key in ub.ProgIter(old_keys, label='fixing keys'):
        new_key = key
        new_key = new_key.replace(':', '')
        new_key = new_key.replace('-', '_')
        new_key = re.sub('__*', '_', new_key)
        new_keys.append(new_key)

    # assert len(ut.find_duplicate_items(new_keys)) == 0, 'new keys created conflict'
    assert len(ub.find_duplicates(new_keys)) == 0, 'new keys created conflict'

    for key, new_key in zip(old_keys, new_keys):
        if key != new_key:
            entry = bibtex_dict[key]
            entry['ID'] = new_key
            bibtex_dict[new_key] = entry
            del bibtex_dict[key]

    # The bibtext is now clean. Print it to stdout
    #print(clean_text)
    verbose = None
    if verbose is None:
        verbose = 1

    # Find citations from the tex documents
    key_list = None
    if key_list is None:
        cacher = ub.Cacher('texcite1', enabled=0)
        data = cacher.tryload()
        if data is None:
            fpaths = testdata_fpaths()
            key_list, inverse = find_used_citations(fpaths,
                                                    return_inverse=True)
            # ignore = ['JP', '?', 'hendrick']
            # for item in ignore:
            #     try:
            #         key_list.remove(item)
            #     except ValueError:
            #         pass
            if verbose:
                print('Found %d citations used in the document' %
                      (len(key_list), ))
            data = key_list, inverse
            cacher.save(data)
        key_list, inverse = data

    # else:
    #     key_list = None

    unknown_pubkeys = []
    debug_author = ub.argval('--debug-author', default=None)
    # ./fix_bib.py --debug_author=Kappes

    if verbose:
        print('Fixing %d/%d bibtex entries' %
              (len(key_list), len(bibtex_dict)))

    # debug = True
    debug = False
    if debug_author is not None:
        debug = False

    known_keys = list(bibtex_dict.keys())
    missing_keys = set(key_list) - set(known_keys)
    if extra_dict is not None:
        missing_keys.difference_update(set(extra_dict.keys()))

    if missing_keys:
        print('The library is missing keys found in tex files %s' %
              (ub.repr2(missing_keys), ))

    # Search for possible typos:
    candidate_typos = {}
    sedlines = []
    for key in missing_keys:
        candidates = ut.closet_words(key, known_keys, num=3, subset=True)
        if len(candidates) > 1:
            top = candidates[0]
            if ut.edit_distance(key, top) == 1:
                # "sed -i -e 's/{}/{}/g' *.tex".format(key, top)
                import os
                replpaths = ' '.join(
                    [relpath(p, os.getcwd()) for p in inverse[key]])
                sedlines.append("sed -i -e 's/{}/{}/g' {}".format(
                    key, top, replpaths))
        candidate_typos[key] = candidates
        print('Cannot find key = %r' % (key, ))
        print('Did you mean? %r' % (candidates, ))

    print('Quick fixes')
    print('\n'.join(sedlines))

    # group by file
    just = max([0] + list(map(len, missing_keys)))
    missing_fpaths = [inverse[key] for key in missing_keys]
    for fpath in sorted(set(ub.flatten(missing_fpaths))):
        # ut.fix_embed_globals()
        subkeys = [k for k in missing_keys if fpath in inverse[k]]
        print('')
        ut.cprint('--- Missing Keys ---', 'blue')
        ut.cprint('fpath = %r' % (fpath, ), 'blue')
        ut.cprint('{} | {}'.format('Missing'.ljust(just), 'Did you mean?'),
                  'blue')
        for key in subkeys:
            print('{} | {}'.format(ut.highlight_text(key.ljust(just), 'red'),
                                   ' '.join(candidate_typos[key])))

    # for key in list(bibtex_dict.keys()):

    if extra_dict is not None:
        # Extra database takes precedence over regular
        key_list = list(ut.unique(key_list + list(extra_dict.keys())))
        for k, v in extra_dict.items():
            bibtex_dict[k] = v

    full = ub.argflag('--full')

    for key in key_list:
        try:
            entry = bibtex_dict[key]
        except KeyError:
            continue
        self = BibTexCleaner(key, entry, full=full)

        if debug_author is not None:
            debug = debug_author in entry.get('author', '')

        if debug:
            ut.cprint(' --- ENTRY ---', 'yellow')
            print(ub.repr2(entry, nl=1))

        entry = self.fix()
        # self.clip_abstract()
        # self.shorten_keys()
        # self.fix_authors()
        # self.fix_year()
        # old_pubval = self.fix_pubkey()
        # if old_pubval:
        #     unknown_pubkeys.append(old_pubval)
        # self.fix_arxiv()
        # self.fix_general()
        # self.fix_paper_types()

        if debug:
            print(ub.repr2(entry, nl=1))
            ut.cprint(' --- END ENTRY ---', 'yellow')
        bibtex_dict[key] = entry

    unwanted_keys = set(bibtex_dict.keys()) - set(key_list)
    if verbose:
        print('Removing %d unwanted entries' % (len(unwanted_keys)))
    ut.delete_dict_keys(bibtex_dict, unwanted_keys)

    if 0:
        d1 = bibtex_dict.copy()
        full = True
        for key, entry in d1.items():
            self = BibTexCleaner(key, entry, full=full)
            pub = self.publication()
            if pub is None:
                print(self.entry['ENTRYTYPE'])

            old = self.fix_pubkey()
            x1 = self._pubval()
            x2 = self.standard_pubval(full=full)
            # if x2 is not None and len(x2) > 5:
            #     print(ub.repr2(self.entry))

            if x1 != x2:
                print('x2 = %r' % (x2, ))
                print('x1 = %r' % (x1, ))
                print(ub.repr2(self.entry))

            # if 'CVPR' in self.entry.get('booktitle', ''):
            #     if 'CVPR' != self.entry.get('booktitle', ''):
            #         break
            if old:
                print('old = %r' % (old, ))
            d1[key] = self.entry

    if full:
        d1 = bibtex_dict.copy()

        import numpy as np
        import pandas as pd
        df = pd.DataFrame.from_dict(d1, orient='index')

        paged_items = df[~pd.isnull(df['pub_accro'])]
        has_pages = ~pd.isnull(paged_items['pages'])
        print('have pages {} / {}'.format(has_pages.sum(), len(has_pages)))
        print(ub.repr2(paged_items[~has_pages]['title'].values.tolist()))

        entrytypes = dict(list(df.groupby('pub_type')))
        if False:
            # entrytypes['misc']
            g = entrytypes['online']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            entrytypes['book']
            entrytypes['thesis']
            g = entrytypes['article']
            g = entrytypes['incollection']
            g = entrytypes['conference']

        def lookup_pub(e):
            if e == 'article':
                return 'journal', 'journal'
            elif e == 'incollection':
                return 'booksection', 'booktitle'
            elif e == 'conference':
                return 'conference', 'booktitle'
            return None, None

        for e, g in entrytypes.items():
            print('e = %r' % (e, ))
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]
            if 'pub_full' in g.columns:
                place_title = g['pub_full'].tolist()
                print(ub.repr2(ub.dict_hist(place_title)))
            else:
                print('Unknown publications')

        if 'report' in entrytypes:
            g = entrytypes['report']
            missing = g[pd.isnull(g['title'])]
            if len(missing):
                print('Missing Title')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'journal' in entrytypes:
            g = entrytypes['journal']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['journal'])]
            if len(missing):
                print('Missing Journal')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'conference' in entrytypes:
            g = entrytypes['conference']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['booktitle'])]
            if len(missing):
                print('Missing Booktitle')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'incollection' in entrytypes:
            g = entrytypes['incollection']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['booktitle'])]
            if len(missing):
                print('Missing Booktitle')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'thesis' in entrytypes:
            g = entrytypes['thesis']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]
            missing = g[pd.isnull(g['institution'])]
            if len(missing):
                print('Missing Institution')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        # import utool
        # utool.embed()

    # Overwrite BibDatabase structure
    bib_database._entries_dict = bibtex_dict
    bib_database.entries = list(bibtex_dict.values())

    #conftitle_to_types_set_hist = {key: set(val) for key, val in conftitle_to_types_hist.items()}
    #print(ub.repr2(conftitle_to_types_set_hist))

    print('Unknown conference keys:')
    print(ub.repr2(sorted(unknown_pubkeys)))
    print('len(unknown_pubkeys) = %r' % (len(unknown_pubkeys), ))

    writer = BibTexWriter()
    writer.contents = ['comments', 'entries']
    writer.indent = '  '
    writer.order_entries_by = ('type', 'author', 'year')

    new_bibtex_str = bibtexparser.dumps(bib_database, writer)

    # Need to check
    #jegou_aggregating_2012

    # Fix the Journal Abbreviations
    # References:
    # https://www.ieee.org/documents/trans_journal_names.pdf

    # Write out clean bibfile in ascii format
    clean_bib_fpath = ub.augpath(bib_fpath.replace(' ', '_'), suffix='_clean')

    if not ub.argflag('--dryrun'):
        ut.writeto(clean_bib_fpath, new_bibtex_str)
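
The caching idiom used throughout these examples is the same three-step pattern: build a ub.Cacher keyed by a configuration string, tryload(), and save() on a miss. A minimal self-contained sketch of that pattern (the function and file names here are illustrative, not taken from any of the projects in these examples):

import ubelt as ub

def parse_files(paths):
    # Stand-in for the expensive work each example caches (parsing, stats, ...)
    return {p: p.upper() for p in paths}

paths = ['a.xml', 'b.xml']
# The cfgstr ties the cache entry to its inputs; different inputs hash to a
# different entry, so stale results are not silently reused.
cacher = ub.Cacher('parse_files_demo', cfgstr=ub.hash_data(paths))
data = cacher.tryload()
if data is None:
    data = parse_files(paths)
    cacher.save(data)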
Code Example #26
File: segmentation.py  Project: Erotemic/netharn
def setup_harn(cmdline=True, **kw):
    """
    CommandLine:
        xdoctest -m netharn.examples.segmentation setup_harn

    Example:
        >>> # xdoctest: +REQUIRES(--slow)
        >>> kw = {'workers': 0, 'xpu': 'cpu', 'batch_size': 2}
        >>> cmdline = False
        >>> # Just sets up the harness, does not do any heavy lifting
        >>> harn = setup_harn(cmdline=cmdline, **kw)
        >>> #
        >>> harn.initialize()
        >>> #
        >>> batch = harn._demo_batch(tag='train')
        >>> epoch_metrics = harn._demo_epoch(tag='vali', max_iter=2)
    """
    import sys
    import ndsampler
    import kwarray
    # kwarray.seed_global(2108744082)

    config = SegmentationConfig(default=kw)
    config.load(cmdline=cmdline)
    nh.configure_hacks(config)  # fix opencv bugs

    coco_datasets = nh.api.Datasets.coerce(config)
    print('coco_datasets = {}'.format(ub.repr2(coco_datasets)))
    for tag, dset in coco_datasets.items():
        dset._build_hashid(hash_pixels=False)

    workdir = ub.ensuredir(ub.expandpath(config['workdir']))
    samplers = {
        tag: ndsampler.CocoSampler(dset,
                                   workdir=workdir,
                                   backend=config['backend'])
        for tag, dset in coco_datasets.items()
    }

    for tag, sampler in ub.ProgIter(list(samplers.items()),
                                    desc='prepare frames'):
        try:
            sampler.frames.prepare(workers=config['workers'])
        except AttributeError:
            pass

    torch_datasets = {
        tag: SegmentationDataset(
            sampler,
            config['input_dims'],
            input_overlap=((tag == 'train') and config['input_overlap']),
            augmenter=((tag == 'train') and config['augmenter']),
        )
        for tag, sampler in samplers.items()
    }
    torch_loaders = {
        tag: torch_data.DataLoader(dset,
                                   batch_size=config['batch_size'],
                                   num_workers=config['workers'],
                                   shuffle=(tag == 'train'),
                                   drop_last=True,
                                   pin_memory=True)
        for tag, dset in torch_datasets.items()
    }

    if config['class_weights']:
        mode = config['class_weights']
        dset = torch_datasets['train']
        class_weights = _precompute_class_weights(dset,
                                                  mode=mode,
                                                  workers=config['workers'])
        class_weights = torch.FloatTensor(class_weights)
        class_weights[dset.classes.index('background')] = 0
    else:
        class_weights = None

    if config['normalize_inputs']:
        stats_dset = torch_datasets['train']
        stats_idxs = kwarray.shuffle(np.arange(len(stats_dset)),
                                     rng=0)[0:min(1000, len(stats_dset))]
        stats_subset = torch.utils.data.Subset(stats_dset, stats_idxs)
        cacher = ub.Cacher('dset_mean', cfgstr=stats_dset.input_id + 'v3')
        input_stats = cacher.tryload()
        if input_stats is None:
            loader = torch.utils.data.DataLoader(
                stats_subset,
                num_workers=config['workers'],
                shuffle=True,
                batch_size=config['batch_size'])
            running = nh.util.RunningStats()
            for batch in ub.ProgIter(loader, desc='estimate mean/std'):
                try:
                    running.update(batch['im'].numpy())
                except ValueError:  # final batch broadcast error
                    pass
            simple_stats = running.simple(axis=None)
            input_stats = {
                'mean': simple_stats['mean'].round(3),
                'std': simple_stats['std'].round(3),
            }
            cacher.save(input_stats)
    else:
        input_stats = {}

    print('input_stats = {!r}'.format(input_stats))

    # TODO: infer the number of input channels
    model_ = (SegmentationModel, {
        'arch': config['arch'],
        'input_stats': input_stats,
        'classes': torch_datasets['train'].classes.__json__(),
        'in_channels': 3,
    })

    initializer_ = nh.Initializer.coerce(config)
    # if config['init'] == 'cls':
    #     initializer_ = model_[0]._initializer_cls()

    # Create hyperparameters
    hyper = nh.HyperParams(
        nice=config['nice'],
        workdir=config['workdir'],
        xpu=nh.XPU.coerce(config['xpu']),
        datasets=torch_datasets,
        loaders=torch_loaders,
        model=model_,
        initializer=initializer_,
        scheduler=nh.Scheduler.coerce(config),
        optimizer=nh.Optimizer.coerce(config),
        dynamics=nh.Dynamics.coerce(config),
        criterion=(
            nh.criterions.FocalLoss,
            {
                'focus': config['focus'],
                'weight': class_weights,
                # 'reduction': 'none',
            }),
        monitor=(nh.Monitor, {
            'minimize': ['loss'],
            'patience': config['patience'],
            'max_epoch': config['max_epoch'],
            'smoothing': .6,
        }),
        other={
            'batch_size': config['batch_size'],
        },
        extra={
            'argv': sys.argv,
            'config': ub.repr2(config.asdict()),
        })

    # Create harness
    harn = SegmentationHarn(hyper=hyper)
    harn.classes = torch_datasets['train'].classes
    harn.preferences.update({
        'num_keep': 2,
        'keyboard_debug': True,
        # 'export_modules': ['netharn'],
    })
    harn.intervals.update({
        'vali': 1,
        'test': 10,
    })
    harn.script_config = config
    return harn
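
A note on the 'v3' suffix in the cfgstr above: appending a manual version tag is a lightweight way to invalidate old cache entries whenever the cached computation changes. A hedged sketch of that convention, with illustrative names:

import ubelt as ub

STATS_VERSION = 'v3'  # bump this string whenever the statistics computation changes

def cached_input_stats(input_id, compute_stats):
    # The cache key combines the dataset identity with the code version.
    cacher = ub.Cacher('dset_mean', cfgstr=input_id + STATS_VERSION)
    stats = cacher.tryload()
    if stats is None:
        stats = compute_stats()
        cacher.save(stats)
    return stats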
Code Example #27
def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             cachedir,
             ovthresh=0.5,
             use_07_metric=False,
             bias=1):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])
    Top level function that does the PASCAL VOC evaluation.
    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name (duh)
    cachedir: Directory for caching the annotations (unused in this variant;
        caching is handled by ub.Cacher below)
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name
    # cachedir caches the annotations in a pickle file

    # first load gt
    # if not os.path.isdir(cachedir):
    #     os.mkdir(cachedir)
    # cachefile = os.path.join(cachedir, 'annots.pkl')
    # read list of images
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
    imagenames = [x.strip().split(' ')[0] for x in lines]

    # not os.path.isfile(cachefile):
    # load annots
    import ubelt as ub
    cacher = ub.Cacher('voc_cachefile', cfgstr=ub.hash_data(imagenames))
    recs = cacher.tryload()
    if recs is None:
        recs = {}
        for i, imagename in enumerate(ub.ProgIter(imagenames, desc='reading')):
            recs[imagename] = parse_rec(annopath.format(imagename))
        cacher.save(recs)
        # save
        # print('Saving cached annotations to {:s}'.format(cachefile))
        # with open(cachefile, 'w') as f:
        #     cPickle.dump(recs, f)
    # else:
    #     # load
    #     with open(cachefile, 'r') as f:
    #         recs = cPickle.load(f)

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(bool)  # np.bool is removed in modern numpy
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {
            'bbox': bbox,
            'difficult': difficult,
            'det': det
        }

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        lines = f.readlines()

    splitlines = [x.strip().split(' ') for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    # sorted_scores = np.sort(-confidence)  #
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + bias, 0.)
            ih = np.maximum(iymax - iymin + bias, 0.)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + bias) * (bb[3] - bb[1] + bias) +
                   (BBGT[:, 2] - BBGT[:, 0] + bias) *
                   (BBGT[:, 3] - BBGT[:, 1] + bias) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
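
For reference, when use_07_metric is True, voc_ap applies VOC2007's 11-point interpolation. A simplified re-implementation of that interpolation (a sketch, not the exact voc_ap called above):

import numpy as np

def voc_ap_11point(rec, prec):
    # Average the best precision achievable at recall >= t over 11 thresholds.
    ap = 0.0
    for t in np.arange(0.0, 1.1, 0.1):
        mask = rec >= t
        p = float(np.max(prec[mask])) if mask.any() else 0.0
        ap += p / 11.0
    return ap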
Code Example #28
File: object_detection.py  Project: Erotemic/netharn
def setup_harn(cmdline=True, **kw):
    """
    Ignore:
        >>> from object_detection import *  # NOQA
        >>> cmdline = False
        >>> kw = {
        >>>     'train_dataset': '~/data/VOC/voc-trainval.mscoco.json',
        >>>     'vali_dataset': '~/data/VOC/voc-test-2007.mscoco.json',
        >>> }
        >>> harn = setup_harn(**kw)
    """
    import ndsampler
    from ndsampler import coerce_data
    # Seed other global rngs just in case something uses them under the hood
    kwarray.seed_global(1129989262, offset=1797315558)

    config = DetectFitConfig(default=kw, cmdline=cmdline)

    nh.configure_hacks(config)  # fix opencv bugs
    ub.ensuredir(config['workdir'])

    # Load ndsampler.CocoDataset objects from info in the config
    subsets = coerce_data.coerce_datasets(config)

    samplers = {}
    for tag, subset in subsets.items():
        print('subset = {!r}'.format(subset))
        sampler = ndsampler.CocoSampler(subset, workdir=config['workdir'])
        samplers[tag] = sampler

    torch_datasets = {
        tag: DetectDataset(
            sampler,
            input_dims=config['input_dims'],
            augment=config['augment'] if (tag == 'train') else False,
        )
        for tag, sampler in samplers.items()
    }

    print('make loaders')
    loaders_ = {
        tag:
        torch.utils.data.DataLoader(dset,
                                    batch_size=config['batch_size'],
                                    num_workers=config['workers'],
                                    shuffle=(tag == 'train'),
                                    collate_fn=nh.data.collate.padded_collate,
                                    pin_memory=True)
        for tag, dset in torch_datasets.items()
    }
    # for x in ub.ProgIter(loaders_['train']):
    #     pass

    if config['normalize_inputs']:
        # Get stats on the dataset (todo: turn off augmentation for this)
        _dset = torch_datasets['train']
        stats_idxs = kwarray.shuffle(np.arange(len(_dset)),
                                     rng=0)[0:min(1000, len(_dset))]
        stats_subset = torch.utils.data.Subset(_dset, stats_idxs)
        cacher = ub.Cacher('dset_mean', cfgstr=_dset.input_id + 'v2')
        input_stats = cacher.tryload()
        if input_stats is None:
            # Use parallel workers to load data faster
            loader = torch.utils.data.DataLoader(
                stats_subset,
                collate_fn=nh.data.collate.padded_collate,
                num_workers=config['workers'],
                shuffle=True,
                batch_size=config['batch_size'])
            # Track moving average
            running = nh.util.RunningStats()
            for batch in ub.ProgIter(loader, desc='estimate mean/std'):
                try:
                    running.update(batch['im'].numpy())
                except ValueError:  # final batch broadcast error
                    pass
            simple_stats = running.simple(axis=None)
            input_stats = {
                'mean': simple_stats['mean'].round(3),
                'std': simple_stats['std'].round(3),
            }
            cacher.save(input_stats)
    else:
        input_stats = None
    print('input_stats = {!r}'.format(input_stats))

    initializer_ = nh.Initializer.coerce(config, leftover='kaiming_normal')
    print('initializer_ = {!r}'.format(initializer_))

    arch = config['arch']
    if arch == 'yolo2':

        if False:
            dset = samplers['train'].dset
            print('dset = {!r}'.format(dset))
            # anchors = yolo2.find_anchors(dset)

        anchors = np.array([(1.3221, 1.73145), (3.19275, 4.00944),
                            (5.05587, 8.09892), (9.47112, 4.84053),
                            (11.2364, 10.0071)])

        classes = samplers['train'].classes
        model_ = (yolo2.Yolo2, {
            'classes': classes,
            'anchors': anchors,
            'conf_thresh': 0.001,
            'nms_thresh': 0.5 if not ub.argflag('--eav') else 0.4
        })
        model = model_[0](**model_[1])
        model._initkw = model_[1]

        criterion_ = (
            yolo2.YoloLoss,
            {
                'coder': model.coder,
                'seen': 0,
                'coord_scale': 1.0,
                'noobject_scale': 1.0,
                'object_scale': 5.0,
                'class_scale': 1.0,
                'thresh': 0.6,  # iou_thresh
                # 'seen_thresh': 12800,
            })
    else:
        raise KeyError(arch)

    scheduler_ = nh.Scheduler.coerce(config)
    print('scheduler_ = {!r}'.format(scheduler_))

    optimizer_ = nh.Optimizer.coerce(config)
    print('optimizer_ = {!r}'.format(optimizer_))

    dynamics_ = nh.Dynamics.coerce(config)
    print('dynamics_ = {!r}'.format(dynamics_))

    xpu = nh.XPU.coerce(config['xpu'])
    print('xpu = {!r}'.format(xpu))

    import sys

    hyper = nh.HyperParams(**{
        'nice': config['nice'],
        'workdir': config['workdir'],
        'datasets': torch_datasets,
        'loaders': loaders_,
        'xpu': xpu,
        'model': model,
        'criterion': criterion_,
        'initializer': initializer_,
        'optimizer': optimizer_,
        'dynamics': dynamics_,
        # 'optimizer': (torch.optim.SGD, {
        #     'lr': lr_step_points[0],
        #     'momentum': 0.9,
        #     'dampening': 0,
        #     # multiplying by batch size was one of those unpublished details
        #     'weight_decay': decay * simulated_bsize,
        # }),
        'scheduler': scheduler_,
        'monitor': (nh.Monitor, {
            'minimize': ['loss'],
            # 'maximize': ['mAP'],
            'patience': config['patience'],
            'max_epoch': config['max_epoch'],
            'smoothing': .6,
        }),
        'other': {
            # Other params are not used internally, so you are free to set any
            # extra params specific to your algorithm, and still have them
            # logged in the hyperparam structure. For YOLO this is `ovthresh`.
            'batch_size': config['batch_size'],
            'nice': config['nice'],
            'ovthresh': config['ovthresh'],  # used in mAP computation
        },
        'extra': {
            'config': ub.repr2(config.asdict()),
            'argv': sys.argv,
        },
    })
    print('hyper = {!r}'.format(hyper))
    print('make harn')
    harn = DetectHarn(hyper=hyper)
    harn.preferences.update({
        'num_keep': 2,
        'keep_freq': 30,
        'export_modules': ['netharn'],  # TODO
        'prog_backend': 'progiter',  # alternative: 'tqdm'
        'keyboard_debug': True,
    })
    harn.intervals.update({
        'log_iter_train': 50,
    })
    harn.fit_config = config
    print('harn = {!r}'.format(harn))
    print('samplers = {!r}'.format(samplers))
    return harn
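
In the netharn examples a setup_harn like this is typically paired with a small main() that builds the harness and starts training; a hedged sketch (the exact entry point can differ between netharn versions):

def main():
    harn = setup_harn(cmdline=True)
    harn.run()  # FitHarn drives initialization, training, and validation

if __name__ == '__main__':
    main()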
Code Example #29
File: api.py  Project: Kitware/netharn
def _coerce_datasets(config):
    import netharn as nh
    import ndsampler
    import numpy as np
    from torchvision import transforms
    coco_datasets = nh.api.Datasets.coerce(config)
    print('coco_datasets = {}'.format(ub.repr2(coco_datasets, nl=1)))
    for tag, dset in coco_datasets.items():
        dset._build_hashid(hash_pixels=False)

    workdir = ub.ensuredir(ub.expandpath(config['workdir']))
    samplers = {
        tag: ndsampler.CocoSampler(dset, workdir=workdir, backend=config['sampler_backend'])
        for tag, dset in coco_datasets.items()
    }

    for tag, sampler in ub.ProgIter(list(samplers.items()), desc='prepare frames'):
        sampler.frames.prepare(workers=config['workers'])

    # TODO: basic ndsampler torch dataset, likely has to support the transforms
    # API, bleh.

    transform = transforms.Compose([
        transforms.Resize(config['input_dims']),
        transforms.CenterCrop(config['input_dims']),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    torch_datasets = {
        key: SamplerDataset(
            sampler, transform=transform,
            # input_dims=config['input_dims'],
            # augmenter=config['augmenter'] if key == 'train' else None,
        )
        for key, sampler in samplers.items()
    }
    # self = torch_dset = torch_datasets['train']

    if config['normalize_inputs']:
        # Get stats on the dataset (todo: turn off augmentation for this)
        import kwarray
        _dset = torch_datasets['train']
        stats_idxs = kwarray.shuffle(np.arange(len(_dset)), rng=0)[0:min(1000, len(_dset))]
        stats_subset = torch.utils.data.Subset(_dset, stats_idxs)

        cacher = ub.Cacher('dset_mean', cfgstr=_dset.input_id + 'v3')
        input_stats = cacher.tryload()

        from netharn.data.channel_spec import ChannelSpec
        channels = ChannelSpec.coerce(config['channels'])

        if input_stats is None:
            # Use parallel workers to load data faster
            from netharn.data.data_containers import container_collate
            from functools import partial
            collate_fn = partial(container_collate, num_devices=1)

            loader = torch.utils.data.DataLoader(
                stats_subset,
                collate_fn=collate_fn,
                num_workers=config['workers'],
                shuffle=True,
                batch_size=config['batch_size'])

            # Track moving average of each fused channel stream
            channel_stats = {key: nh.util.RunningStats()
                             for key in channels.keys()}
            assert len(channel_stats) == 1, (
                'only support one fused stream for now')
            for batch in ub.ProgIter(loader, desc='estimate mean/std'):
                if isinstance(batch, (tuple, list)):
                    inputs = {'rgb': batch[0]}  # make assumption
                else:
                    inputs = batch['inputs']

                for key, val in inputs.items():
                    try:
                        for part in val.numpy():
                            channel_stats[key].update(part)
                    except ValueError:  # final batch broadcast error
                        pass

            perchan_input_stats = {}
            for key, running in channel_stats.items():
                # ``running`` already holds the stats for this channel stream
                perchan_stats = running.simple(axis=(1, 2))
                perchan_input_stats[key] = {
                    'mean': perchan_stats['mean'].round(3),
                    'std': perchan_stats['std'].round(3),
                }

            input_stats = ub.peek(perchan_input_stats.values())
            cacher.save(input_stats)
    else:
        input_stats = {}

    torch_loaders = {
        tag: dset.make_loader(
            batch_size=config['batch_size'],
            num_batches=config['num_batches'],
            num_workers=config['workers'],
            shuffle=(tag == 'train'),
            balance=(config['balance'] if tag == 'train' else None),
            pin_memory=True)
        for tag, dset in torch_datasets.items()
    }

    dataset_info = {
        'torch_datasets': torch_datasets,
        'torch_loaders': torch_loaders,
        'input_stats': input_stats
    }
    return dataset_info
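
For intuition, running.simple(axis=(1, 2)) reduces each (C, H, W) sample over its spatial axes, yielding one mean and one std per channel. A plain-numpy equivalent over a toy batch (illustrative only, not part of the project code):

import numpy as np

imgs = np.random.rand(8, 3, 32, 32)  # toy (N, C, H, W) batch
input_stats = {
    'mean': imgs.mean(axis=(0, 2, 3)).round(3),  # one value per channel
    'std': imgs.std(axis=(0, 2, 3)).round(3),
}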
Code Example #30
File: ggr_matching.py  Project: Kitware/netharn
    def __init__(self, sampler, workdir=None, augment=False, dim=416):
        print('make AnnotCocoDataset')

        cacher = ub.Cacher('aid_pccs_v2',
                           cfgstr=sampler.dset.tag,
                           verbose=True)
        aid_pccs = cacher.tryload()
        if aid_pccs is None:
            aid_pccs = extract_ggr_pccs(sampler.dset)
            cacher.save(aid_pccs)
        self.aid_pccs = aid_pccs
        self.sampler = sampler

        self.aids = sorted(ub.flatten(self.aid_pccs))
        self.aid_to_index = aid_to_index = {
            aid: index
            for index, aid in enumerate(self.aids)
        }

        # index pccs
        self.index_pccs = [
            frozenset(aid_to_index[aid] for aid in pcc)
            for pcc in self.aid_pccs
        ]

        self.nx_to_aidpcc = {nx: pcc for nx, pcc in enumerate(self.aid_pccs)}
        self.nx_to_indexpcc = {
            nx: pcc
            for nx, pcc in enumerate(self.index_pccs)
        }

        self.aid_to_nx = {
            aid: nx
            for nx, pcc in self.nx_to_aidpcc.items() for aid in pcc
        }
        self.index_to_nx = {
            index: nx
            for nx, pcc in self.nx_to_indexpcc.items() for index in pcc
        }

        self.aid_to_tx = {
            aid: tx
            for tx, aid in enumerate(sampler.regions.targets['aid'])
        }

        window_dim = dim
        self.dim = window_dim
        self.window_dim = window_dim
        self.dims = (window_dim, window_dim)

        self.rng = kwarray.ensure_rng(0)
        if augment:
            import imgaug.augmenters as iaa
            self.independent = iaa.Sequential([
                iaa.Sometimes(
                    0.2, nh.data.transforms.HSVShift(hue=0.1, sat=1.5,
                                                     val=1.5)),
                iaa.Crop(percent=(0, .2)),
            ])
            self.dependent = iaa.Sequential([iaa.Fliplr(p=.5)])
            # NOTE: we are only using `self.augmenter` to make a hyper hashid
            # in __getitem__ we invoke transform explicitly for fine control
            self.augmenter = iaa.Sequential([
                self.independent,
                self.dependent,
            ])
        else:
            self.augmenter = None
        self.letterbox = nh.data.transforms.Resize(target_size=self.dims,
                                                   fill_color=0,
                                                   mode='letterbox')