Exemplo n.º 1
0
def trainloader_dct_subset(args):
    """Build the DCT-domain training loader restricted to a channel subset.

    Images are read from ``<args.data>/train`` via ``ImageFolderDCT``.
    The transform pipeline applies these steps in order:
    RandomResizedCrop(224), RandomHorizontalFlip, TransformDCT, ToTensorDCT,
    SubsetDCT(args.subset_channels), then NormalizeDCT with the module-level
    ``train_y_*`` / ``train_cb_*`` / ``train_cr_*`` mean/std constants.

    Args:
        args: namespace providing ``data``, ``subset_channels``,
            ``distributed``, ``train_batch`` and ``workers``.

    Returns:
        tuple: ``(train_loader, train_sampler, train_loader_len)``, where
        ``train_sampler`` is ``None`` unless ``args.distributed`` is truthy
        and ``train_loader_len`` is ``len(train_loader)``.
    """
    root = os.path.join(args.data, 'train')
    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(args.subset_channels),
        transforms.NormalizeDCT(
            train_y_mean, train_y_std,
            train_cb_mean, train_cb_std,
            train_cr_mean, train_cr_std),
    ])
    dataset = ImageFolderDCT(root, pipeline)

    sampler = (torch.utils.data.distributed.DistributedSampler(dataset)
               if args.distributed else None)

    # Shuffle only when no sampler is in use; PyTorch treats the two
    # DataLoader options as mutually exclusive.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=(sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
    )
    return loader, sampler, len(loader)
Exemplo n.º 2
0
def trainloader_dct_resized(args):
    """Build the training loader for flattened, 4x-upsampled DCT inputs.

    Images are read from ``<args.data>/train`` via ``ImageFolderDCT``.
    The transform pipeline applies these steps in order:
    RandomResizedCrop(224), RandomHorizontalFlip, TransformDCT, DCTFlatten2D,
    UpsampleDCT (4x in height and width), ToTensorDCT,
    SubsetDCT(args.subset) and Aggregate. It then normalises with the
    module-level ``train_dct_subset_mean`` / ``train_dct_subset_std``
    restricted to ``args.subset``.

    Args:
        args: namespace providing ``data``, ``subset``, ``distributed``,
            ``train_batch`` and ``workers``.

    Returns:
        tuple: ``(train_loader, train_sampler, train_loader_len)``, where
        ``train_sampler`` is ``None`` unless ``args.distributed`` is truthy.
    """
    root = os.path.join(args.data, 'train')
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),  # 28x28x192 (author's note)
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_dct_subset_mean,
            train_dct_subset_std,
            channels=args.subset,
        ),
    ]
    dataset = ImageFolderDCT(root, transforms.Compose(steps))

    sampler = None
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=(sampler is None),  # the sampler, when present, owns ordering
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
    )
    return loader, sampler, len(loader)
Exemplo n.º 3
0
 def cvt_transform(self, img):
     """Run ``img`` through the upscaled static-DCT preprocessing pipeline.

     The pipeline applies these steps in order: RandomResizedCrop(self.img_size),
     2x Upscale, TransformUpscaledDCT, ToTensorDCT, a 192-channel SubsetDCT and
     Aggregate. It then normalises with the module-level
     ``train_upscaled_static_mean`` / ``train_upscaled_static_std``.
     A new ``cvtransforms.Compose`` is built on every call.
     """
     # NOTE: no horizontal flip in this pipeline (RandomHorizontalFlip was left disabled).
     stages = [
         cvtransforms.RandomResizedCrop(self.img_size),
         cvtransforms.Upscale(upscale_factor=2),
         cvtransforms.TransformUpscaledDCT(),
         cvtransforms.ToTensorDCT(),
         cvtransforms.SubsetDCT(channels=192),
         cvtransforms.Aggregate(),
         cvtransforms.NormalizeDCT(train_upscaled_static_mean,
                                   train_upscaled_static_std,
                                   channels=192),
     ]
     composed = cvtransforms.Compose(stages)
     return composed(img)
Exemplo n.º 4
0
def trainloader_upscaled_static(args, model='mobilenet'):
    """Build the training loader for the upscaled static-DCT pipeline.

    Images are read from ``<args.data>/train`` via ``ImageFolderDCT``.
    The transform pipeline applies these steps in order:
    RandomResizedCrop(input_size), RandomHorizontalFlip, 2x Upscale,
    TransformUpscaledDCT, ToTensorDCT, SubsetDCT and Aggregate. It then applies
    NormalizeDCT with the module-level ``train_upscaled_static_mean`` /
    ``train_upscaled_static_std``. Both SubsetDCT and NormalizeDCT are driven
    by ``args.subset`` and ``args.pattern``.

    Args:
        args: namespace providing ``data``, ``subset``, ``pattern``,
            ``distributed``, ``train_batch`` and ``workers``.
        model: backbone name selecting the crop size:
            ``'mobilenet'`` -> 896, ``'resnet'`` -> 448.

    Returns:
        tuple: ``(train_loader, train_sampler, train_loader_len)``, where
        ``train_sampler`` is ``None`` unless ``args.distributed`` is truthy.

    Raises:
        NotImplementedError: if ``model`` is not ``'mobilenet'`` or
            ``'resnet'``; the message names the rejected value.
    """
    traindir = os.path.join(args.data, 'train')

    if model == 'mobilenet':
        input_size = 896
    elif model == 'resnet':
        input_size = 448
    else:
        # Name the offending value so a typo in the model flag is obvious.
        raise NotImplementedError(
            "trainloader_upscaled_static: unsupported model {!r} "
            "(expected 'mobilenet' or 'resnet')".format(model))

    transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset, pattern=args.pattern),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=args.subset,
            pattern=args.pattern
        )
    ])

    train_dataset = ImageFolderDCT(traindir, transform)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # Shuffle only when no sampler is in use; the sampler owns ordering otherwise.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    train_loader_len = len(train_loader)

    return train_loader, train_sampler, train_loader_len
Exemplo n.º 5
0
 def get_composed_transform_dct(self, aug=False, filter_size=8):
     """Build the DCT-domain preprocessing pipeline.

     The front end depends on ``aug``. Both front ends feed the same DCT back end:
     GetDCT(filter_size), UpScaleDCT(size=56), ToTensorDCT, a 24-channel
     SubsetDCT and Aggregate. The back end then applies NormalizeDCT with the
     module-level ``train_upscaled_static_mean`` / ``train_upscaled_static_std``.

     Args:
         aug: if truthy, use the augmenting front end: RandomResizedCrop,
             ImageJitter(self.jitter_param) and RandomHorizontalFlip.
             Otherwise use a deterministic Resize plus CenterCrop.
         filter_size: passed to ``GetDCT``; the crop side is
             ``filter_size * 56`` pixels.

     Returns:
         A ``transforms_dct.Compose`` of the front-end and back-end steps.
     """
     # Crop side in pixels. Presumably GetDCT uses filter_size x filter_size
     # blocks, giving a 56x56 block grid to match UpScaleDCT(size=56); confirm.
     crop_size = filter_size * 56
     # Single source of truth so SubsetDCT and NormalizeDCT stay in sync.
     n_channels = 24

     # Truthiness test (not `aug == False`) so None/empty values mean "no augmentation".
     if aug:
         front = [
             transforms_dct.RandomResizedCrop(crop_size),
             transforms_dct.ImageJitter(self.jitter_param),
             transforms_dct.RandomHorizontalFlip(),
         ]
     else:
         # Resize 15% larger than the crop, then take the central crop_size window.
         front = [
             transforms_dct.Resize(int(crop_size * 1.15)),
             transforms_dct.CenterCrop(crop_size),
         ]

     # Shared DCT back end; previously duplicated verbatim in both branches.
     back = [
         transforms_dct.GetDCT(filter_size),
         transforms_dct.UpScaleDCT(size=56),
         transforms_dct.ToTensorDCT(),
         transforms_dct.SubsetDCT(channels=n_channels),
         transforms_dct.Aggregate(),
         transforms_dct.NormalizeDCT(
             train_upscaled_static_mean,
             train_upscaled_static_std,
             channels=n_channels),
     ]
     return transforms_dct.Compose(front + back)
                                            train_y_std_resized,
                                            train_cb_mean_resized,
                                            train_cb_std_resized,
                                            train_cr_mean_resized,
                                            train_cr_std_resized),
                ])),
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=False)

        # train_dataset = ImageFolderDCT('/mnt/ssd/kai.x/dataset/ILSVRC2012/train', transforms.Compose([
        train_dataset = ImageFolderDCT(
            '/storage-t1/user/kaixu/datasets/ILSVRC2012/train',
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToYCrCb(),
                transforms.ChromaSubsample(),
                transforms.UpsampleDCT(size=224, T=896, debug=False),
                transforms.ToTensorDCT(),
                transforms.NormalizeDCT(train_y_mean_resized,
                                        train_y_std_resized,
                                        train_cb_mean_resized,
                                        train_cb_std_resized,
                                        train_cr_mean_resized,
                                        train_cr_std_resized),
            ]))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=1,
Exemplo n.º 7
0
    #     transforms.NormalizeDCT(
    #         train_y_mean_resized, train_y_std_resized,
    #         train_cb_mean_resized, train_cb_std_resized,
    #         train_cr_mean_resized, train_cr_std_resized),
    # ])

    # transform3 =transforms.Compose([
    #     transforms.RandomResizedCrop(224),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.ResizedTransformDCT(),
    #     transforms.ToTensorDCT(),
    #     transforms.SubsetDCT(32),
    # ])

    transform4 = transforms.Compose([
        transforms.RandomResizedCrop(896),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels='24'),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels='24'
        )
        ])

    transform5 = transforms.Compose([
        transforms.DCTFlatten2D(),