def task_datasets(task, vali_frac=0, colorspace='RGB'):
    """
    Build train/vali/test SSegInputsWrapper datasets from an existing task.

    Uses the first cross-validation split of the task; `vali_frac` of the
    learning set is held out for validation.
    """
    learn, test = next(task.xval_splits())
    learn.tag = 'learn'

    # Split everything in the learning set into training / validation
    n_learn = len(learn)
    n_vali = int(n_learn * vali_frac)

    train = learn[n_vali:]
    vali = learn[:n_vali]
    vali.tag = 'vali'
    train.tag = 'train'

    if ub.argflag('--all'):
        # HACK EVERYTHING TOGETHER
        train = learn + test
        from clab import inputs
        vali = inputs.Inputs()
        test = inputs.Inputs()

    train_dataset = SSegInputsWrapper(train, task, colorspace=colorspace)
    vali_dataset = SSegInputsWrapper(vali, task, colorspace=colorspace)
    test_dataset = SSegInputsWrapper(test, task, colorspace=colorspace)
    # train_dataset.augment = True

    print('* len(train_dataset) = {}'.format(len(train_dataset)))
    print('* len(vali_dataset) = {}'.format(len(vali_dataset)))
    print('* len(test_dataset) = {}'.format(len(test_dataset)))

    datasets = {
        'train': train_dataset,
        'vali': vali_dataset,
        'test': test_dataset,
    }
    return datasets
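# Hypothetical usage sketch (not part of the original module): task_datasets
# expects an already-constructed task object; get_task (used by
# load_task_dataset below) is assumed to provide one. The 'camvid' task name
# and vali_frac value are illustrative assumptions.
def _example_task_datasets():
    task = get_task('camvid')
    datasets = task_datasets(task, vali_frac=0.1, colorspace='RGB')
    return datasets['train'], datasets['vali'], datasets['test']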
def load_task_dataset(taskname, vali_frac=0, colorspace='RGB', combine=None,
                      boundary=True, arch=None, halfcombo=None):
    """
    Look up a task by name and build train/vali/test SSegInputsWrapper datasets.

    If `combine` is set (or --combine is given), the full test set is folded
    into training. If `halfcombo` is set (or --halfcombo is given), half of
    the test set is moved into training.
    """
    task = get_task(taskname, boundary=boundary, arch=arch)
    learn, test = next(task.xval_splits())
    learn.tag = 'learn'

    # Split everything in the learning set into training / validation
    n_learn = len(learn)
    n_vali = int(n_learn * vali_frac)

    train = learn[n_vali:]
    vali = learn[:n_vali]
    vali.tag = 'vali'
    train.tag = 'train'

    if combine is None:
        combine = ub.argflag('--combine')
    if halfcombo is None:
        halfcombo = ub.argflag('--halfcombo')

    if halfcombo:
        # Move half of the test set into training (shrinks the test set)
        n = len(test) // 2
        new_test = test[n:]
        train = learn + test[:n]
        test = new_test
        train.tag = 'train_h'
        test.tag = 'test_h'

    if combine:
        # HACK EVERYTHING TOGETHER
        train = learn + test
        from clab import inputs
        vali = inputs.Inputs()
        test = inputs.Inputs()

    train_dataset = SSegInputsWrapper(train, task, colorspace=colorspace)
    vali_dataset = SSegInputsWrapper(vali, task, colorspace=colorspace)
    test_dataset = SSegInputsWrapper(test, task, colorspace=colorspace)
    # train_dataset.augment = True

    print('* len(train_dataset) = {}'.format(len(train_dataset)))
    print('* len(vali_dataset) = {}'.format(len(vali_dataset)))
    print('* len(test_dataset) = {}'.format(len(test_dataset)))

    datasets = {
        'train': train_dataset,
        'vali': vali_dataset,
        'test': test_dataset,
    }
    return datasets
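# Hypothetical usage sketch (not part of the original module): shows how the
# returned dataset dict might be wrapped in DataLoaders. Assumes
# SSegInputsWrapper behaves like a torch.utils.data.Dataset; the 'camvid'
# task name, vali_frac, and batch size are illustrative assumptions.
def _example_make_loaders():
    import torch.utils.data as torch_data
    datasets = load_task_dataset('camvid', vali_frac=0.1)
    loaders = {
        tag: torch_data.DataLoader(dset, batch_size=4, shuffle=(tag == 'train'))
        for tag, dset in datasets.items()
    }
    return loaders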
def _mode_new_input(prep, mode, input, clear=False, mult=1):
    """
    Create a new Inputs object for `mode` rooted at the prep output dirs.

    Returns (new_input, found), where `found` is True if enough existing
    outputs were discovered on disk to skip regeneration (short circuit).
    """
    out_dpaths = prep._mode_paths(mode, input, clear=clear)

    new_input = inputs.Inputs()
    new_input.tag = mode

    if 'im' in out_dpaths:
        new_input.imdir = out_dpaths['im']
    if 'gt' in out_dpaths:
        new_input.gtdir = out_dpaths['gt']
    if 'aux' in out_dpaths:
        new_input.auxdir = out_dpaths['aux']

    if not clear:
        try:
            new_input.prepare_image_paths()
        except AssertionError:
            # hack
            return prep._mode_new_input(mode, input, clear=True)

        if len(new_input.paths) > 0:
            n_loaded = min(map(len, new_input.paths.values()))
            min_n_expected = len(input) * mult
            print(' * n_loaded = {!r}'.format(n_loaded))
            print(' * min_n_expected = {!r}'.format(min_n_expected))
            # Short circuit augmentation if we found enough existing outputs
            if n_loaded >= min_n_expected:
                print('short circuit {}'.format(mode))
                return new_input, True

    if 'im' in out_dpaths:
        new_input.im_paths = []
    if 'gt' in out_dpaths:
        new_input.gt_paths = []
    if 'aux' in out_dpaths:
        new_input.aux_paths = {k: [] for k in input.aux_paths.keys()}
    return new_input, False
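# Hypothetical caller pattern (not part of the original module): illustrates
# the (new_input, found) contract of _mode_new_input. The prep object,
# raw_input, mode string, and mult value are illustrative assumptions.
def _example_mode_new_input(prep, raw_input):
    new_input, found = prep._mode_new_input('train', raw_input, clear=False, mult=1)
    if not found:
        # Existing outputs were missing or incomplete; the caller would
        # regenerate data into new_input's directories at this point.
        pass
    return new_input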