def trainloader_dct_subset(args):
    """Build the training DataLoader over DCT features restricted to a
    channel subset (``args.subset_channels``).

    Returns a ``(train_loader, train_sampler, train_loader_len)`` tuple.
    The sampler is a ``DistributedSampler`` when ``args.distributed`` is
    set, otherwise ``None`` (in which case the loader shuffles itself).
    """
    traindir = os.path.join(args.data, 'train')

    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(args.subset_channels),
        transforms.NormalizeDCT(
            train_y_mean, train_y_std,
            train_cb_mean, train_cb_std,
            train_cr_mean, train_cr_std),
    ])
    dataset = ImageFolderDCT(traindir, pipeline)

    # Distributed training shards the dataset via the sampler; shuffling
    # is then the sampler's job, so the loader must not shuffle too.
    sampler = (torch.utils.data.distributed.DistributedSampler(dataset)
               if args.distributed else None)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler)

    return loader, sampler, len(loader)
def trainloader_dct_resized(args):
    """Build the training DataLoader over flattened, upsampled DCT features.

    The pipeline DCT-transforms the crop (28x28x192), flattens the 2D
    blocks, upsamples 4x in both spatial dimensions, keeps the
    ``args.subset`` channels, and normalizes with the static subset
    statistics.

    Returns a ``(train_loader, train_sampler, train_loader_len)`` tuple;
    the sampler is a ``DistributedSampler`` when ``args.distributed`` is
    set, otherwise ``None``.
    """
    traindir = os.path.join(args.data, 'train')

    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),  # 28x28x192
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_dct_subset_mean,
            train_dct_subset_std,
            channels=args.subset
        )
    ])
    dataset = ImageFolderDCT(traindir, pipeline)

    # Under distributed training the sampler shards and shuffles; the
    # loader only shuffles when there is no sampler.
    sampler = (torch.utils.data.distributed.DistributedSampler(dataset)
               if args.distributed else None)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler)

    return loader, sampler, len(loader)
def cvt_transform(self, img):
    """Run *img* through the upscaled-DCT preprocessing pipeline.

    Crops to ``self.img_size``, upscales 2x, converts to DCT, keeps all
    192 channels, aggregates, and normalizes with the static upscaled
    statistics. Returns the transformed tensor.
    """
    steps = [
        cvtransforms.RandomResizedCrop(self.img_size),
        # Horizontal flip deliberately disabled (kept here as a toggle):
        # cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Upscale(upscale_factor=2),
        cvtransforms.TransformUpscaledDCT(),
        cvtransforms.ToTensorDCT(),
        cvtransforms.SubsetDCT(channels=192),
        cvtransforms.Aggregate(),
        cvtransforms.NormalizeDCT(train_upscaled_static_mean,
                                  train_upscaled_static_std,
                                  channels=192),
    ]
    return cvtransforms.Compose(steps)(img)
def trainloader_upscaled_static(args, model='mobilenet'):
    """Build the training DataLoader for the upscaled-DCT pipeline.

    The crop size depends on the backbone: 896 for ``'mobilenet'``,
    448 for ``'resnet'``. The pipeline upscales 2x, converts to DCT,
    keeps the ``args.subset`` channels in ``args.pattern`` order, and
    normalizes with the static upscaled statistics.

    Returns a ``(train_loader, train_sampler, train_loader_len)`` tuple;
    the sampler is a ``DistributedSampler`` when ``args.distributed`` is
    set, otherwise ``None``.

    Raises:
        NotImplementedError: if *model* is not a supported backbone name.
    """
    # Input resolution expected by each supported backbone.
    input_sizes = {'mobilenet': 896, 'resnet': 448}
    if model not in input_sizes:
        # Fail fast with a diagnostic instead of a bare NotImplementedError.
        raise NotImplementedError(
            "unsupported model '%s'; expected one of %s"
            % (model, sorted(input_sizes)))
    input_size = input_sizes[model]

    traindir = os.path.join(args.data, 'train')
    transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset, pattern=args.pattern),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=args.subset,
            pattern=args.pattern
        )
    ])
    train_dataset = ImageFolderDCT(traindir, transform)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)
    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
def get_composed_transform_dct(self, aug=False, filter_size=8):
    """Build the DCT preprocessing pipeline.

    Args:
        aug: when True, use training-time augmentation (random resized
            crop, color jitter from ``self.jitter_param``, horizontal
            flip); when False, deterministic resize + center crop.
        filter_size: DCT filter size; the spatial crop is
            ``filter_size * 56`` pixels.

    Returns:
        A ``transforms_dct.Compose`` over the chosen head followed by the
        shared DCT tail (GetDCT -> UpScaleDCT -> ToTensorDCT ->
        SubsetDCT(24) -> Aggregate -> NormalizeDCT).
    """
    crop_size = filter_size * 56
    if aug:
        head = [
            transforms_dct.RandomResizedCrop(crop_size),
            transforms_dct.ImageJitter(self.jitter_param),
            transforms_dct.RandomHorizontalFlip(),
        ]
    else:
        head = [
            # Resize ~15% larger than the crop, then center-crop.
            transforms_dct.Resize(int(crop_size * 1.15)),
            transforms_dct.CenterCrop(crop_size),
        ]
    # Shared DCT tail — previously duplicated verbatim in both branches.
    tail = [
        transforms_dct.GetDCT(filter_size),
        transforms_dct.UpScaleDCT(size=56),
        transforms_dct.ToTensorDCT(),
        transforms_dct.SubsetDCT(channels=24),
        transforms_dct.Aggregate(),
        transforms_dct.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=24),
    ]
    return transforms_dct.Compose(head + tail)
train_y_std_resized, train_cb_mean_resized, train_cb_std_resized, train_cr_mean_resized, train_cr_std_resized), ])), batch_size=1, shuffle=False, num_workers=1, pin_memory=False) # train_dataset = ImageFolderDCT('/mnt/ssd/kai.x/dataset/ILSVRC2012/train', transforms.Compose([ train_dataset = ImageFolderDCT( '/storage-t1/user/kaixu/datasets/ILSVRC2012/train', transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToYCrCb(), transforms.ChromaSubsample(), transforms.UpsampleDCT(size=224, T=896, debug=False), transforms.ToTensorDCT(), transforms.NormalizeDCT(train_y_mean_resized, train_y_std_resized, train_cb_mean_resized, train_cb_std_resized, train_cr_mean_resized, train_cr_std_resized), ])) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1,
# transforms.NormalizeDCT( # train_y_mean_resized, train_y_std_resized, # train_cb_mean_resized, train_cb_std_resized, # train_cr_mean_resized, train_cr_std_resized), # ]) # transform3 =transforms.Compose([ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(), # transforms.ResizedTransformDCT(), # transforms.ToTensorDCT(), # transforms.SubsetDCT(32), # ]) transform4 = transforms.Compose([ transforms.RandomResizedCrop(896), transforms.RandomHorizontalFlip(), transforms.Upscale(upscale_factor=2), transforms.TransformUpscaledDCT(), transforms.ToTensorDCT(), transforms.SubsetDCT(channels='24'), transforms.Aggregate(), transforms.NormalizeDCT( train_upscaled_static_mean, train_upscaled_static_std, channels='24' ) ]) transform5 = transforms.Compose([ transforms.DCTFlatten2D(),