def trainloader_dct_subset(args):
    """Train loader over block-DCT coefficients restricted to a channel subset."""
    traindir = os.path.join(args.data, 'train')
    train_dataset = ImageFolderDCT(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(args.subset_channels),
        transforms.NormalizeDCT(
            train_y_mean, train_y_std,
            train_cb_mean, train_cb_std,
            train_cr_mean, train_cr_std),
    ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
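
# A minimal usage sketch (not from the original repo): the loader only assumes an
# argparse-style namespace, so the field names below mirror the attributes read in
# trainloader_dct_subset, and the values are placeholders.
def _example_dct_subset_loader():
    from argparse import Namespace
    args = Namespace(data='/path/to/ILSVRC2012', subset_channels=24,
                     distributed=False, train_batch=256, workers=8)
    train_loader, train_sampler, train_loader_len = trainloader_dct_subset(args)
    images, targets = next(iter(train_loader))
    return images.shape, targets.shape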
def trainloader_dct_resized(args):
    """Train loader that flattens and upsamples the 28x28x192 DCT volume before subsetting."""
    traindir = os.path.join(args.data, 'train')
    train_dataset = ImageFolderDCT(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),  # 28x28x192
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_dct_subset_mean, train_dct_subset_std,
            channels=args.subset),
    ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
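
# Sketch of an epoch loop over the resized-DCT loader (an assumed consumption pattern,
# not code from the original repo); set_epoch is the standard DistributedSampler call
# that reshuffles shards across processes each epoch.
def _example_train_epochs(args, model, criterion, optimizer, epochs=1):
    train_loader, train_sampler, _ = trainloader_dct_resized(args)
    for epoch in range(epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for images, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()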
def trainloader_upscaled_static(args, model='mobilenet'):
    """Train loader for the upscaled-static pipeline; the crop size depends on the backbone."""
    traindir = os.path.join(args.data, 'train')

    if model == 'mobilenet':
        input_size = 896
    elif model == 'resnet':
        input_size = 448
    else:
        raise NotImplementedError

    transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset, pattern=args.pattern),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=args.subset,
            pattern=args.pattern),
    ])
    train_dataset = ImageFolderDCT(traindir, transform)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
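
# Quick sanity check (a sketch, not part of the original code): pull one batch and
# print its shape, which is determined by the crop size chosen above together with
# the channel subset and pattern in args.
def _example_inspect_upscaled_static(args, model='resnet'):
    train_loader, _, _ = trainloader_upscaled_static(args, model=model)
    images, targets = next(iter(train_loader))
    print('DCT batch shape: {}, labels: {}'.format(tuple(images.shape), tuple(targets.shape)))
    return images.shape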
def get_composed_transform_dct(self, aug=False, filter_size=8):
    """Build the DCT-domain transform pipeline, with or without training augmentation."""
    if not aug:
        transform = transforms_dct.Compose([
            transforms_dct.Resize(int(filter_size * 56 * 1.15)),
            transforms_dct.CenterCrop(filter_size * 56),
            transforms_dct.GetDCT(filter_size),
            transforms_dct.UpScaleDCT(size=56),
            transforms_dct.ToTensorDCT(),
            transforms_dct.SubsetDCT(channels=24),
            transforms_dct.Aggregate(),
            transforms_dct.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
                channels=24),
        ])
    else:
        transform = transforms_dct.Compose([
            transforms_dct.RandomResizedCrop(filter_size * 56),
            transforms_dct.ImageJitter(self.jitter_param),
            transforms_dct.RandomHorizontalFlip(),
            transforms_dct.GetDCT(filter_size),
            transforms_dct.UpScaleDCT(size=56),
            transforms_dct.ToTensorDCT(),
            transforms_dct.SubsetDCT(channels=24),
            transforms_dct.Aggregate(),
            transforms_dct.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
                channels=24),
        ])
    return transform
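
# Usage sketch for the method above, which expects to be bound to a transform-manager
# object that provides `jitter_param` (the owning class is not part of this excerpt,
# so the stand-in below and its jitter values are assumptions).
class _ExampleTransformManager:
    jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)  # assumed values
    get_composed_transform_dct = get_composed_transform_dct

def _example_build_dct_transforms(filter_size=8):
    mgr = _ExampleTransformManager()
    train_tf = mgr.get_composed_transform_dct(aug=True, filter_size=filter_size)
    eval_tf = mgr.get_composed_transform_dct(aug=False, filter_size=filter_size)
    return train_tf, eval_tf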
def trainloader_upscaled_dct_direct(args, model='mobilenet'):
    """Train loader that works on DCT coefficients read via the 'dct' backend."""
    if model == 'mobilenet':
        input_size = 112
    elif model == 'resnet':
        input_size = 56
    else:
        raise NotImplementedError

    traindir = os.path.join(args.data, 'train')
    transform = transforms.Compose([
        transforms.UpsampleCbCr(),
        transforms.SubsetDCT2(channels=args.subset, pattern=args.pattern),
        transforms.RandomResizedCropDCT(size=input_size),
        transforms.Aggregate2(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensorDCT2(),
        transforms.NormalizeDCT(
            train_upscaled_static_dct_direct_mean_interp,
            train_upscaled_static_dct_direct_std_interp,
            channels=args.subset,
            pattern=args.pattern),
    ])
    train_dataset = ImageFolderDCT(traindir, transform, backend='dct')

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
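
# Sketch of how a training script might switch between the two upscaled pipelines
# (an assumption about the calling convention, not code from the original repo):
# decode-then-recompute DCT via trainloader_upscaled_static, or the coefficient-level
# pipeline above that uses ImageFolderDCT's 'dct' backend.
def _example_select_pipeline(args, model='resnet', dct_direct=False):
    if dct_direct:
        return trainloader_upscaled_dct_direct(args, model=model)
    return trainloader_upscaled_static(args, model=model)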
def get_mean_and_std_yuv(dataset):
    """Compute the per-channel mean and std of a YUV dataset."""
    dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(dataset, transforms.Compose([
            transforms.RandomResizedCropDCT(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.AverageYUV(),
        ]), loader=yuv_loader),
        batch_size=128, shuffle=False, num_workers=16)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for idx, (inputs, targets) in enumerate(dataloader):
        mean += inputs.mean(dim=0)
        std += inputs.std(dim=0)
    # Average the per-batch statistics. This is exact for the mean only when every
    # batch has the same size, and the averaged per-batch std is an approximation
    # of the dataset-level std.
    mean.div_(idx + 1)
    std.div_(idx + 1)
    return mean, std
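
# Sketch of collecting the statistics (an illustration only; the train_* normalization
# constants referenced by the loaders in this file are presumably precomputed offline
# with helpers like this one rather than recomputed at training time).
def _example_compute_yuv_stats(data_root):
    mean, std = get_mean_and_std_yuv(os.path.join(data_root, 'train'))
    print('YUV mean: {}  std: {}'.format(mean.tolist(), std.tolist()))
    return mean, std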
            train_cb_mean_resized, train_cb_std_resized,
            train_cr_mean_resized, train_cr_std_resized),
    ])),
    batch_size=1, shuffle=False, num_workers=1, pin_memory=False)

train_dataset = ImageFolderDCT(
    '/storage-t1/user/kaixu/datasets/ILSVRC2012/train', transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToYCrCb(),
        transforms.ChromaSubsample(),
        transforms.UpsampleDCT(size=224, T=896, debug=False),
        transforms.ToTensorDCT(),
        transforms.NormalizeDCT(
            train_y_mean_resized, train_y_std_resized,
            train_cb_mean_resized, train_cb_std_resized,
            train_cr_mean_resized, train_cr_std_resized),
    ]))
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=False,
    num_workers=1, pin_memory=False)  # worker/pin_memory settings assumed to match the loader above