コード例 #1
0
def trainloader_dct_subset(args):
    """Create the DCT-domain training loader restricted to a channel subset.

    Images are read from ``<args.data>/train`` through ``ImageFolderDCT``.
    Each sample passes through these steps in order:

    1. random resized crop to 224;
    2. random horizontal flip;
    3. DCT transform;
    4. tensor conversion;
    5. selection of the channels given by ``args.subset_channels``;
    6. normalization with the module-level ``train_y_*``, ``train_cb_*`` and
       ``train_cr_*`` mean/std constants.

    Args:
        args: Namespace-like object. Reads ``data``, ``subset_channels``,
            ``distributed``, ``train_batch`` and ``workers``.

    Returns:
        Tuple ``(train_loader, train_sampler, train_loader_len)``:

        - ``train_sampler`` is a ``DistributedSampler`` when
          ``args.distributed`` is truthy, otherwise ``None``.
        - ``train_loader_len`` is ``len(train_loader)``.
    """
    root = os.path.join(args.data, 'train')

    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(args.subset_channels),
        transforms.NormalizeDCT(
            train_y_mean, train_y_std,
            train_cb_mean, train_cb_std,
            train_cr_mean, train_cr_std),
    ]
    dataset = ImageFolderDCT(root, transforms.Compose(steps))

    # When a sampler is supplied it controls ordering. DataLoader refuses
    # shuffle=True together with a sampler, so shuffling is enabled only
    # when no sampler is used.
    sampler = (torch.utils.data.distributed.DistributedSampler(dataset)
               if args.distributed else None)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
    )
    return loader, sampler, len(loader)
コード例 #2
0
def trainloader_dct_resized(args):
    """Create the training loader for upsampled, channel-subset DCT inputs.

    Images come from ``<args.data>/train`` via ``ImageFolderDCT``. Each
    sample passes through these steps in order:

    1. random resized crop to 224;
    2. random horizontal flip;
    3. DCT transform;
    4. ``DCTFlatten2D``;
    5. ``UpsampleDCT`` with a 4x ratio on both axes;
    6. tensor conversion;
    7. selection of ``args.subset`` channels;
    8. ``Aggregate``;
    9. normalization with the module-level ``train_dct_subset_mean`` /
       ``train_dct_subset_std`` restricted to the same ``args.subset``.

    Args:
        args: Namespace-like object. Reads ``data``, ``subset``,
            ``distributed``, ``train_batch`` and ``workers``.

    Returns:
        Tuple ``(train_loader, train_sampler, train_loader_len)``:

        - ``train_sampler`` is a ``DistributedSampler`` when
          ``args.distributed`` is truthy, otherwise ``None``.
        - ``train_loader_len`` is ``len(train_loader)``.
    """
    image_root = os.path.join(args.data, 'train')
    channel_subset = args.subset

    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.TransformDCT(),  # 28x28x192
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=channel_subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_dct_subset_mean,
            train_dct_subset_std,
            channels=channel_subset
        )
    ])
    dataset = ImageFolderDCT(image_root, pipeline)

    sampler = None
    if args.distributed:
        # Shard the dataset across processes; DataLoader shuffling is then
        # disabled below because it cannot be combined with a sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=(sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
    )
    num_batches = len(loader)
    return loader, sampler, num_batches
コード例 #3
0
def valloader_dct(args):
    """Create the deterministic DCT-domain validation loader.

    Images are read from ``<args.data>/val`` via ``ImageFolderDCT``. Each
    sample passes through these steps in order:

    1. resize to 256;
    2. center crop to 224;
    3. DCT transform;
    4. tensor conversion;
    5. normalization with the same module-level ``train_y_*``,
       ``train_cb_*`` and ``train_cr_*`` statistics used for training.

    Samples are served in a fixed order (no shuffling).

    Args:
        args: Namespace-like object. Reads ``data``, ``test_batch`` and
            ``workers``.

    Returns:
        The validation ``torch.utils.data.DataLoader``.
    """
    image_root = os.path.join(args.data, 'val')

    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.TransformDCT(),
        transforms.ToTensorDCT(),
        transforms.NormalizeDCT(
            train_y_mean, train_y_std,
            train_cb_mean, train_cb_std,
            train_cr_mean, train_cr_std),
    ])
    dataset = ImageFolderDCT(image_root, eval_transform)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
コード例 #4
0
def valloader_dct_resized(args):
    """Create the validation loader for upsampled, channel-subset DCT inputs.

    Images are read from ``<args.data>/val`` via ``ImageFolderDCT``. Each
    sample passes through these steps in order:

    1. resize to 256;
    2. center crop to 224;
    3. DCT transform;
    4. ``DCTFlatten2D``;
    5. ``UpsampleDCT`` with a 4x ratio on both axes;
    6. tensor conversion;
    7. selection of ``args.subset`` channels;
    8. ``Aggregate``;
    9. normalization with ``train_dct_subset_mean`` /
       ``train_dct_subset_std`` restricted to the same ``args.subset``.

    Samples are served in a fixed order (no shuffling).

    Args:
        args: Namespace-like object. Reads ``data``, ``subset``,
            ``test_batch`` and ``workers``.

    Returns:
        The validation ``torch.utils.data.DataLoader``.
    """
    image_root = os.path.join(args.data, 'val')
    channel_subset = args.subset

    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.TransformDCT(),  # 28x28x192
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(upscale_ratio_h=4, upscale_ratio_w=4, debug=False),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=channel_subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_dct_subset_mean,
            train_dct_subset_std,
            channels=channel_subset
        ),
    ]
    dataset = ImageFolderDCT(image_root, transforms.Compose(eval_steps))

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    return loader
コード例 #5
0
        # Excerpt from inside a larger function whose header is not shown.
        # `input_normalize` (a list) and the `train_*_resized` statistics
        # are defined outside this excerpt.
        #
        # Build one Normalize transform per component (Y, Cb, Cr, judging by
        # the names) from the "resized" statistics and append them to
        # `input_normalize`. The val_loader pipeline below does not use these
        # objects; it normalizes with NormalizeDCT instead. Their consumer is
        # outside this excerpt.
        input_normalize_y = transforms.Normalize(mean=train_y_mean_resized,
                                                 std=train_y_std_resized)
        input_normalize_cb = transforms.Normalize(mean=train_cb_mean_resized,
                                                  std=train_cb_std_resized)
        input_normalize_cr = transforms.Normalize(mean=train_cr_mean_resized,
                                                  std=train_cr_std_resized)
        input_normalize.append(input_normalize_y)
        input_normalize.append(input_normalize_cb)
        input_normalize.append(input_normalize_cr)
        # Validation loader: single-sample batches, one worker, fixed order,
        # and no pinned memory.
        # NOTE(review): the dataset path is hard-coded and machine-specific.
        # Consider taking it from configuration (e.g. an args.data-style
        # parameter) instead.
        val_loader = torch.utils.data.DataLoader(
            # Alternate (commented-out) dataset location:
            # ImageFolderDCT('/mnt/ssd/kai.x/dataset/ILSVRC2012/val', transforms.Compose([
            ImageFolderDCT(
                '/storage-t1/user/kaixu/datasets/ILSVRC2012/val',
                transforms.Compose([
                    # NOTE(review): the conversion is named YCrCb, but
                    # NormalizeDCT below receives statistics in Y, Cb, Cr
                    # order. Confirm that the channel order matches what
                    # NormalizeDCT expects.
                    transforms.ToYCrCb(),
                    transforms.TransformDCT(),
                    transforms.UpsampleDCT(T=896, debug=False),
                    transforms.CenterCropDCT(112),
                    transforms.ToTensorDCT(),
                    transforms.NormalizeDCT(train_y_mean_resized,
                                            train_y_std_resized,
                                            train_cb_mean_resized,
                                            train_cb_std_resized,
                                            train_cr_mean_resized,
                                            train_cr_std_resized),
                ])),
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=False)