def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=None,
) -> ImageDataBunch:
    """Build the image-to-image databunch pairing degraded and good images.

    Inputs are read from ``crappy_path`` (opened in RGB mode). The target of
    each input is the path with the same relative location under
    ``good_path``, so both folders are expected to share one layout.

    Args:
        sz: Size handed to the transform step (applied to inputs and targets).
        bs: Batch size of the resulting databunch.
        crappy_path: Root folder of the input images.
        good_path: Root folder of the target images.
        random_seed: Seed forwarded to both the sub-sampling step and the
            random train/validation split.
        keep_pct: Fraction of the input images to keep (``sample_pct``).
        num_workers: Worker count forwarded to ``databunch``.
        stats: Normalization statistics, applied to inputs and targets.
        xtra_tfms: Optional extra transforms forwarded to ``get_transforms``;
            ``None`` (the default) means no extra transforms.

    Returns:
        The normalized databunch, with its ``c`` attribute set to 3.
    """
    # Build a fresh list per call: a list literal in the signature would be a
    # single object shared by every call that relies on the default.
    extra_tfms = [] if xtra_tfms is None else xtra_tfms

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(0.1, seed=random_seed)
    )

    tfms = get_transforms(
        max_zoom=1.2,
        max_lighting=0.5,
        max_warp=0.25,
        xtra_tfms=extra_tfms,
    )

    data = (
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        # tfm_y=True: targets receive the same transforms as their inputs.
        .transform(tfms, size=sz, tfm_y=True)
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )

    # Presumably the channel count of the RGB targets -- confirm with the
    # model code that reads ``data.c``.
    data.c = 3
    return data
def get_DIV2k_data(pLow, pFull, bs: int, sz: int):
    """Build a databunch from a DIV2K low-resolution folder and its targets.

    The low-resolution folder name selects the file-name suffix that has to be
    stripped from each input name to obtain the matching full-resolution file
    name under ``pFull``.

    Args:
        pLow: Path (``str`` or ``Path``) of the low-resolution images. Its
            final component must be one of ``DIV2K_train_LR_x8``,
            ``DIV2K_train_LR_difficult`` or ``DIV2K_train_LR_mild``.
        pFull: Path (``str`` or ``Path``) of the full-resolution images.
        bs: Batch size of the resulting databunch.
        sz: Size handed to the transform step (applied to inputs and targets).

    Returns:
        The normalized databunch, with its ``c`` attribute set to 3.

    Raises:
        KeyError: If the folder name of ``pLow`` is not a known DIV2K
            low-resolution folder.
    """
    suffix_by_folder = {
        "DIV2K_train_LR_x8": "x8",
        "DIV2K_train_LR_difficult": "x4d",
        "DIV2K_train_LR_mild": "x4m",
    }

    low_dir = Path(pLow)
    # Coerce so that a plain string works with the ``/`` operator below.
    full_dir = Path(pFull)

    # Key on the folder name only: comparing the whole path string breaks on
    # absolute paths, "./" prefixes, trailing slashes and Windows separators.
    try:
        low_res_suffix = suffix_by_folder[low_dir.name]
    except KeyError:
        raise KeyError(
            f"Unknown DIV2K low-resolution folder {low_dir.name!r}; "
            f"expected one of {sorted(suffix_by_folder)}"
        ) from None

    # Fixed split by position in the presorted file list: items 0-799 train,
    # items 800-899 validate.
    src = ImageImageList.from_folder(low_dir, presort=True).split_by_idxs(
        train_idx=list(range(0, 800)),
        valid_idx=list(range(800, 900)),
    )

    tfms = get_transforms(
        max_rotate=30,
        max_zoom=3.,
        max_lighting=.4,
        max_warp=.4,
        p_affine=.85,
    )

    data = (
        src.label_from_func(
            lambda x: full_dir / x.name.replace(low_res_suffix, ''))
        # tfm_y=True: targets receive the same transforms as their inputs.
        .transform(tfms, size=sz, tfm_y=True)
        .databunch(bs=bs, num_workers=8, no_check=True)
        .normalize(imagenet_stats, do_y=True)
    )

    # Presumably the channel count of the RGB targets -- confirm with the
    # model code that reads ``data.c``.
    data.c = 3
    return data
def get_dummy_databunch(bs: int, sz: int):
    """Build a databunch from the images in ``./dataset/dummy/``.

    No validation split is made. The target of each input is the path in the
    same folder whose file name has ``.jpg`` replaced by ``.png``.

    Args:
        bs: Batch size of the resulting databunch.
        sz: Size handed to the transform step (applied to inputs and targets).

    Returns:
        The normalized databunch, with its ``c`` attribute set to 3.
    """
    dummy_dir = Path('./dataset/dummy/')

    def target_of(item):
        # Same folder, same file name, ".jpg" swapped for ".png".
        return dummy_dir / (item.name.replace(".jpg", ".png"))

    unsplit = ImageImageList.from_folder(dummy_dir).split_none()
    labelled = unsplit.label_from_func(target_of)
    # Only resizing is requested here; tfm_y=True applies it to targets too.
    resized = labelled.transform(size=sz, tfm_y=True)
    bunch = resized.databunch(bs=bs, num_workers=1, no_check=True)
    bunch = bunch.normalize(imagenet_stats, do_y=True)

    bunch.c = 3
    return bunch
def get_DIV2k_data_QF(pLow, pFull, bs: int, sz: int):
    """Build a databunch pairing ``.jpg`` inputs with ``.png`` targets.

    The target of each input under ``pLow`` is the path under ``pFull`` whose
    file name is the input's name with ``.jpg`` replaced by ``.png``.

    Args:
        pLow: Path (``str`` or ``Path``) of the input images.
        pFull: Path (``str`` or ``Path``) of the target images.
        bs: Batch size of the resulting databunch.
        sz: Size handed to the transform step (applied to inputs and targets).

    Returns:
        The normalized databunch, with its ``c`` attribute set to 3.
    """
    # Coerce so that a plain string works with the ``/`` operator below;
    # without this a ``str`` argument raises TypeError when labelling.
    full_dir = Path(pFull)

    # Fixed split by position in the presorted file list: items 0-799 train,
    # items 800-899 validate.
    src = ImageImageList.from_folder(pLow, presort=True).split_by_idxs(
        train_idx=list(range(0, 800)),
        valid_idx=list(range(800, 900)),
    )

    data = (
        src.label_from_func(
            lambda x: full_dir / (x.name.replace(".jpg", ".png")))
        # tfm_y=True: targets receive the same transforms as their inputs.
        .transform(get_transforms(max_zoom=2.), size=sz, tfm_y=True)
        .databunch(bs=bs, num_workers=8, no_check=True)
        .normalize(imagenet_stats, do_y=True)
    )

    # Presumably the channel count of the RGB targets -- confirm with the
    # model code that reads ``data.c``.
    data.c = 3
    return data