Ejemplo n.º 1
0
def task_datasets(task, vali_frac=0, colorspace='RGB', combine=None):
    """
    Split the first cross-validation fold of ``task`` into train / vali /
    test datasets wrapped in ``SSegInputsWrapper``.

    Args:
        task: task object; only the first fold yielded by
            ``task.xval_splits()`` is used.
        vali_frac (float): fraction (in ``[0, 1]``) of the learning set held
            out for validation. Validation items are taken from the start of
            the learning set; the remainder becomes the training set.
        colorspace (str): forwarded to every ``SSegInputsWrapper``.
        combine (bool | None): if True, train on the whole learning set plus
            the test set, and replace vali / test with fresh
            ``inputs.Inputs()`` objects. If None (default), the value of the
            ``--all`` command line flag is used, matching the historical
            behavior.

    Returns:
        dict: maps ``'train'``, ``'vali'`` and ``'test'`` to their
        ``SSegInputsWrapper`` datasets.

    Raises:
        ValueError: if ``vali_frac`` is outside ``[0, 1]``.
    """
    # Out-of-range fractions would silently yield nonsense slices
    # (a negative n_vali selects the *tail* of the learning set).
    if not 0 <= vali_frac <= 1:
        raise ValueError(
            'vali_frac must be in [0, 1], got {!r}'.format(vali_frac))

    learn, test = next(task.xval_splits())
    learn.tag = 'learn'

    # Split everything in the learning set into training / validation
    n_learn = len(learn)
    n_vali = int(n_learn * vali_frac)
    train = learn[n_vali:]
    vali = learn[:n_vali]
    vali.tag = 'vali'
    train.tag = 'train'

    if combine is None:
        combine = ub.argflag('--all')

    if combine:
        # HACK EVERYTHING TOGETHER: all labeled data goes into training.
        train = learn + test
        from clab import inputs
        vali = inputs.Inputs()
        test = inputs.Inputs()

    train_dataset = SSegInputsWrapper(train, task, colorspace=colorspace)
    vali_dataset = SSegInputsWrapper(vali, task, colorspace=colorspace)
    test_dataset = SSegInputsWrapper(test, task, colorspace=colorspace)

    print('* len(train_dataset) = {}'.format(len(train_dataset)))
    print('* len(vali_dataset) = {}'.format(len(vali_dataset)))
    print('* len(test_dataset) = {}'.format(len(test_dataset)))
    datasets = {
        'train': train_dataset,
        'vali': vali_dataset,
        'test': test_dataset,
    }
    return datasets
Ejemplo n.º 2
0
def load_task_dataset(taskname,
                      vali_frac=0,
                      colorspace='RGB',
                      combine=None,
                      boundary=True,
                      arch=None,
                      halfcombo=None):
    """
    Build train / vali / test ``SSegInputsWrapper`` datasets for a task.

    Args:
        taskname: task identifier passed to ``get_task``.
        vali_frac (float): fraction (in ``[0, 1]``) of the learning set held
            out for validation, taken from the start of the learning set.
        colorspace (str): forwarded to every ``SSegInputsWrapper``.
        combine (bool | None): if True, train on the whole learning set plus
            the whole test set, and replace vali / test with fresh
            ``inputs.Inputs()`` objects. Takes precedence over
            ``halfcombo``. If None, the ``--combine`` CLI flag is used.
        boundary: forwarded to ``get_task``.
        arch: forwarded to ``get_task``.
        halfcombo (bool | None): if True (and ``combine`` is not), move the
            first ``len(test) // 2`` test items into the training set; the
            resulting splits are tagged ``'train_h'`` / ``'test_h'``. If
            None, the ``--halfcombo`` CLI flag is used.

    Returns:
        dict: maps ``'train'``, ``'vali'`` and ``'test'`` to their
        ``SSegInputsWrapper`` datasets.

    Raises:
        ValueError: if ``vali_frac`` is outside ``[0, 1]``.
    """
    # Out-of-range fractions would silently yield nonsense slices
    # (a negative n_vali selects the *tail* of the learning set).
    if not 0 <= vali_frac <= 1:
        raise ValueError(
            'vali_frac must be in [0, 1], got {!r}'.format(vali_frac))

    task = get_task(taskname, boundary=boundary, arch=arch)
    learn, test = next(task.xval_splits())
    learn.tag = 'learn'

    # Split everything in the learning set into training / validation
    n_learn = len(learn)
    n_vali = int(n_learn * vali_frac)
    train = learn[n_vali:]
    vali = learn[:n_vali]
    vali.tag = 'vali'
    train.tag = 'train'

    if combine is None:
        combine = ub.argflag('--combine')

    if halfcombo is None:
        halfcombo = ub.argflag('--halfcombo')

    if combine:
        # HACK EVERYTHING TOGETHER. Uses the *full* test split: previously
        # this ran after halfcombo and silently dropped the half of the
        # test set that halfcombo had moved into training.
        train = learn + test
        from clab import inputs
        vali = inputs.Inputs()
        test = inputs.Inputs()
    elif halfcombo:
        # decrease testing for training
        n = len(test) // 2
        # Extend ``train`` rather than ``learn`` so that held-out
        # validation items do not leak into the training set.
        train = train + test[:n]
        test = test[n:]
        train.tag = 'train_h'
        test.tag = 'test_h'

    train_dataset = SSegInputsWrapper(train, task, colorspace=colorspace)
    vali_dataset = SSegInputsWrapper(vali, task, colorspace=colorspace)
    test_dataset = SSegInputsWrapper(test, task, colorspace=colorspace)

    print('* len(train_dataset) = {}'.format(len(train_dataset)))
    print('* len(vali_dataset) = {}'.format(len(vali_dataset)))
    print('* len(test_dataset) = {}'.format(len(test_dataset)))
    datasets = {
        'train': train_dataset,
        'vali': vali_dataset,
        'test': test_dataset,
    }
    return datasets
Ejemplo n.º 3
0
    def _mode_new_input(prep, mode, input, clear=False, mult=1):
        """
        Create an ``inputs.Inputs`` describing the output directories of a
        preprocessing ``mode``, reusing outputs already on disk when possible.

        Args:
            mode: preprocessing mode name; passed to ``prep._mode_paths`` and
                used as the ``tag`` of the returned object.
            input: the source inputs; ``len(input) * mult`` is the minimum
                number of existing outputs required for reuse, and its
                ``aux_paths`` keys seed the reset ``aux_paths`` dict.
            clear (bool): forwarded to ``prep._mode_paths``. When True, the
                check for existing outputs is skipped entirely.
            mult (int): multiplier on ``len(input)`` giving the minimum number
                of outputs expected per path group.

        Returns:
            tuple: ``(new_input, found)``. ``found`` is True when enough
            existing outputs were discovered (so augmentation can be short
            circuited). Otherwise the ``*_paths`` attributes corresponding to
            the available directories are reset to empty containers.
        """
        dpaths = prep._mode_paths(mode, input, clear=clear)

        result = inputs.Inputs()
        result.tag = mode
        # Point the new object at whichever output directories were returned.
        for key, attr in (('im', 'imdir'), ('gt', 'gtdir'), ('aux', 'auxdir')):
            if key in dpaths:
                setattr(result, attr, dpaths[key])

        if not clear:
            try:
                result.prepare_image_paths()
            except AssertionError:  # hack
                # Existing outputs could not be indexed; retry with clear=True.
                return prep._mode_new_input(mode, input, clear=True)

            if len(result.paths) > 0:
                n_loaded = min(len(group) for group in result.paths.values())
                min_n_expected = len(input) * mult
                print(' * n_loaded = {!r}'.format(n_loaded))
                print(' * min_n_expected = {!r}'.format(min_n_expected))
                # Short circuit augmentation if enough outputs were found.
                if n_loaded >= min_n_expected:
                    print('short circuit {}'.format(mode))
                    return result, True

        # Start with empty path containers for the caller to populate.
        if 'im' in dpaths:
            result.im_paths = []
        if 'gt' in dpaths:
            result.gt_paths = []
        if 'aux' in dpaths:
            result.aux_paths = {name: [] for name in input.aux_paths.keys()}
        return result, False