예제 #1
0
def get_train_loader(args):
    """Create the CIFAR-10 training DataLoader.

    ``args.view`` selects the colour pipeline:

    * ``'Lab'`` / ``'YCbCr'`` - random crop + flip, colour-space conversion,
      ``ToTensor`` and per-channel normalisation;
    * anything else - random crop + flip and ``ToTensor`` only (RGB, no
      normalisation step).

    When ``args.oracle`` is not ``'original'`` the training images are replaced
    by the array loaded from
    ``<data_folder>/CIFAR-10-C-trainval/train/<oracle>_4_images.npy``.

    Args:
        args: namespace read for ``view``, ``level`` (RGB branch, printed only),
            ``data_folder``, ``oracle``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` - the shuffling DataLoader and the
        number of training samples.
    """
    view = args.view
    if view in ('Lab', 'YCbCr'):
        if view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            channel_mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2,
                            (-107.857 + 94.478) / 2]
            channel_std = [(100 - 0) / 2, (86.183 + 98.233) / 2,
                           (107.857 + 94.478) / 2]
            colour_op = RGB2Lab()
        else:
            channel_mean = [116.151, 121.080, 132.342]
            channel_std = [109.500, 111.855, 111.964]
            colour_op = RGB2YCbCr()
        steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            colour_op,
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
    else:
        print('Use RGB images with %s level %s!' % (view, str(args.level)))
        # NOTE(review): no Normalize step is applied on the RGB path, so the
        # tensors keep ToTensor's value range - confirm this is intended.
        steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]

    train_dataset = datasets.CIFAR10(root=args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose(steps))

    if args.oracle != 'original':
        npy_path = args.data_folder + (
            '/CIFAR-10-C-trainval/train/%s_4_images.npy' % args.oracle)
        train_dataset.data = np.load(npy_path)

    sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))
    return train_loader, n_data
예제 #2
0
def get_train_loader(args):
    """Build the training DataLoader for an ImageNet-style folder or STL-10.

    If ``args.dataset`` contains ``'imagenet'``, images are read from
    ``<args.data_folder>/train`` with ``ImageFolderInstance`` and converted to
    the colour space given by ``args.view`` (``'Lab'`` or ``'YCbCr'``).
    Otherwise the dataset must be ``'stl10'`` with the ``'Lab'`` view; the
    ``train+unlabeled`` split is loaded and wrapped in ``DatasetInstance``.

    Args:
        args: namespace read for ``dataset``, ``data_folder``, ``view``,
            ``crop_low``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` where ``n_data`` is the dataset length.

    Raises:
        NotImplementedError: if the ImageNet ``view`` is not ``'Lab'``/``'YCbCr'``.
        AssertionError: if the non-ImageNet dataset is not ``'stl10'`` or its
            view is not ``'Lab'``.
    """
    if 'imagenet' in args.dataset:
        data_folder = os.path.join(args.data_folder, 'train')

        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
            std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
            color_transfer = RGB2Lab()
        elif args.view == 'YCbCr':
            mean = [116.151, 121.080, 132.342]
            std = [109.500, 111.855, 111.964]
            color_transfer = RGB2YCbCr()
        else:
            # ``NotImplemented`` is a sentinel constant, not an exception class;
            # calling it raised TypeError. NotImplementedError is intended.
            raise NotImplementedError('view not implemented {}'.format(args.view))
        normalize = transforms.Normalize(mean=mean, std=std)

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
            transforms.RandomHorizontalFlip(),
            color_transfer,
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    else:
        assert args.dataset == 'stl10'
        assert args.view == 'Lab'

        mean = [(0 + 100) / 2,
                (-86.183 + 98.233) / 2,
                (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2,
               (86.183 + 98.233) / 2,
               (107.857 + 94.478) / 2]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64, scale=(args.crop_low, 1)),
            transforms.RandomHorizontalFlip(),
            RGB2Lab(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        train_dataset = datasets.STL10(
            args.data_folder, 'train+unlabeled',
            transform=train_transform, download=True)
        train_dataset = DatasetInstance(train_dataset)

    train_sampler = None
    # No sampler is supplied, so the loader shuffles.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data
예제 #3
0
def get_val_loader(args):
    """Build the CIFAR-10 test-split DataLoader, optionally with CIFAR-10-C images.

    ``args.view`` selects the per-sample pipeline. ``'Lab'``/``'YCbCr'`` convert
    the colour space, apply ``ToTensor`` and use per-channel mean/std
    normalisation. Any other view keeps RGB and normalises with mean = std = 0.5.

    When ``args.corruption`` is one of the 15 common corruptions listed below,
    the test images are replaced by the array in
    ``<data_folder>/CIFAR-10-C-trainval/val/<corruption>_<level - 1>_images.npy``.

    Args:
        args: namespace read for ``view``, ``data_folder``, ``corruption``,
            ``level``, ``batch_size`` and ``num_workers``.

    Returns:
        torch.utils.data.DataLoader: non-shuffled loader over the test split.
    """
    common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                          'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                          'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
    print('Use %s!' % (args.view))
    if args.view == 'Lab' or args.view == 'YCbCr':
        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
            std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
            color_transfer = RGB2Lab()
        else:
            mean = [116.151, 121.080, 132.342]
            std = [109.500, 111.855, 111.964]
            color_transfer = RGB2YCbCr()
        normalize = transforms.Normalize(mean=mean, std=std)
        te_transform = transforms.Compose([
            color_transfer,
            transforms.ToTensor(),
            normalize,
        ])
    else:
        NORM = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        normalize = transforms.Normalize(*NORM)
        te_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    val_dataset = datasets.CIFAR10(
        root=args.data_folder,
        train=False,
        transform=te_transform
    )
    if args.corruption in common_corruptions:
        print('Test on %s!' % (args.corruption))
        # Store the raw images for every view. The colour conversion already
        # happens per sample inside ``te_transform``. The previous extra
        # ``color_transfer(Image.fromarray(...))`` call on the whole stacked
        # array could not work, because PIL cannot build a single image from a
        # batch of images. Even if it had worked, it would have converted
        # twice.
        teset_raw = np.load(
            args.data_folder + '/CIFAR-10-C-trainval/val/%s_%s_images.npy'
            % (args.corruption, str(args.level - 1)))
        val_dataset.data = teset_raw

    print('number of val: {}'.format(len(val_dataset)))

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    return val_loader
예제 #4
0
def get_train_val_loader(args):
    """Build train/val DataLoaders over ``train`` and ``val_in_folders`` image folders.

    ``args.view`` selects the colour pipeline:

    * ``'Lab'`` -> ``RGB2Lab``;
    * ``'YCbCr'`` -> ``RGB2YCbCr``;
    * ``'temporal'`` -> ``get_color_distortion()`` when ``args.distort == True``,
      otherwise ``RGB2Lab``.

    Args:
        args: namespace read for ``data_folder``, ``view``, ``distort`` (temporal
            view only), ``crop_low``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler)``. ``train_sampler``
        is always ``None``, so the training loader shuffles.

    Raises:
        NotImplementedError: for any other ``args.view``.
    """
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val_in_folders')

    if args.view == 'Lab':
        # Mid-point / half-width of each Lab channel's value range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    elif args.view == 'temporal':  # Use Lab for comparison
        # NOTE(review): these statistics look [0, 1]-scaled, yet they are paired
        # with RGB2Lab when distort is off - confirm they match the conversion.
        mean = [0.4493, 0.4348, 0.3970]
        std = [0.3030, 0.3001, 0.3016]
        if args.distort == True:
            color_transfer = get_color_distortion()
        else:
            color_transfer = RGB2Lab()
    else:
        # ``NotImplemented`` is a sentinel constant, not an exception class;
        # calling it raised TypeError. NotImplementedError is intended.
        raise NotImplementedError('view not implemented {}'.format(args.view))

    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(train_folder,
                                         transform=train_transform)

    # NOTE(review): validation reuses the *random* training transform, so val
    # metrics are not deterministic - confirm this is intended.
    val_dataset = datasets.ImageFolder(val_folder, transform=train_transform)

    print('number of train: {}'.format(len(train_dataset)))
    print('number of val: {}'.format(len(val_dataset)))

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    return train_loader, val_loader, train_sampler
예제 #5
0
def get_train_loader(args):
    """Create the CIFAR-10 training DataLoader.

    For ``'Lab'``/``'YCbCr'`` views each sample is randomly cropped to 32x32
    (4-pixel padding), randomly flipped, converted to the colour space,
    turned into a tensor and normalised per channel. For any other view the
    only transform is ``ToTensor``.

    Args:
        args: namespace read for ``data_folder``, ``view``, ``batch_size`` and
            ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` - the shuffling DataLoader and the
        number of training samples.
    """
    # NOTE(review): this path is computed but never used; CIFAR10 below reads
    # from ``args.data_folder`` directly.
    _unused_train_dir = os.path.join(args.data_folder, 'train')

    if args.view not in ('Lab', 'YCbCr'):
        print('Use RGB images with %s!' % (args.view))
        # NOTE(review): both objects below are built but not used - the RGB
        # pipeline applies no crop, flip or normalisation.
        _unused_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        _unused_rgb_op = RGB()
        train_transform = transforms.Compose([transforms.ToTensor()])
    else:
        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            centre = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
            spread = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
            colour_op = RGB2Lab()
        else:
            centre = [116.151, 121.080, 132.342]
            spread = [109.500, 111.855, 111.964]
            colour_op = RGB2YCbCr()
        channel_norm = transforms.Normalize(mean=centre, std=spread)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            colour_op,
            transforms.ToTensor(),
            channel_norm,
        ])

    train_dataset = datasets.CIFAR10(
        root=args.data_folder,
        train=True,
        download=True,
        transform=train_transform,
    )

    sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))
    return train_loader, n_data
예제 #6
0
파일: train_CMC.py 프로젝트: mhw32/CMC
def get_train_loader(args):
    """Build the ImageFolder training loader together with the ordered labels.

    Images are read from ``<args.data_folder>/train`` with
    ``ImageFolderInstance``. They are randomly resized-cropped to 224, flipped,
    converted to the colour space given by ``args.view``, and normalised.

    Args:
        args: namespace read for ``data_folder``, ``view``, ``crop_low``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, train_ordered_labels, n_data)``, where
        ``train_ordered_labels`` is a NumPy array holding element ``[1]`` of
        each entry of ``train_dataset.dataset.samples``, in dataset order.

    Raises:
        NotImplementedError: if ``args.view`` is neither ``'Lab'`` nor ``'YCbCr'``.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        # Mid-point / half-width of each Lab channel's value range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # ``NotImplemented`` is a sentinel constant, not an exception class;
        # calling it raised TypeError. NotImplementedError is intended.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None

    # Second element of every sample entry (the label, presumably the class
    # index as in torchvision's ImageFolder - confirm), kept in dataset order.
    train_samples = train_dataset.dataset.samples
    train_ordered_labels = np.array([sample[1] for sample in train_samples])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, train_ordered_labels, n_data
예제 #7
0
파일: train_CMC.py 프로젝트: amsword/CMC
def get_train_loader(args):
    """Build the training DataLoader over a TSV multi-view dataset.

    Every sample goes through ``BGR2RGB`` and ``ToPILImage`` before the usual
    random-resized-crop (224) / flip / colour-space / normalise pipeline.
    This suggests the TSV loader yields BGR arrays - TODO confirm.

    Args:
        args: namespace read for ``view``, ``crop_low``, ``data_folder``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` - the shuffling DataLoader and the
        number of samples.

    Raises:
        NotImplementedError: if ``args.view`` is neither ``'Lab'`` nor ``'YCbCr'``.
    """
    if args.view not in ('Lab', 'YCbCr'):
        raise NotImplementedError('view not implemented {}'.format(args.view))

    if args.view == 'Lab':
        # Mid-point / half-width of each Lab channel's value range.
        centre = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        spread = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        colour_op = RGB2Lab()
    else:
        centre = [116.151, 121.080, 132.342]
        spread = [109.500, 111.855, 111.964]
        colour_op = RGB2YCbCr()
    channel_norm = transforms.Normalize(mean=centre, std=spread)

    # Project-local imports stay function-scoped, as in the original.
    from qd.qd_pytorch import BGR2RGB
    from dataset import TSVMultiviewDataset

    pipeline = transforms.Compose([
        BGR2RGB(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        colour_op,
        transforms.ToTensor(),
        channel_norm,
    ])
    dataset = TSVMultiviewDataset(args.data_folder,
                                  split='train',
                                  version=0,
                                  transform=pipeline)

    sampler = None
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    n_data = len(dataset)
    print('number of samples: {}'.format(n_data))
    return loader, n_data
예제 #8
0
파일: train_CMC.py 프로젝트: mhw32/CMC
def get_test_loader(args, image_size=224):
    """Build a DataLoader over the ``<args.data_folder>/validation`` image folder.

    Images are resized so the shorter side is ``image_size / 0.875`` pixels
    (256 for the default 224). They are then centre-cropped to ``image_size``,
    converted to the colour space chosen by ``args.view``, and normalised.

    Args:
        args: namespace read for ``data_folder``, ``view``, ``batch_size`` and
            ``num_workers``.
        image_size: side length of the centre crop. The original code used an
            ``image_size`` name that this function never defined. It is now an
            explicit parameter whose default matches the 224 crop that the
            ``Resize(256)`` FIXME referred to.

    Returns:
        tuple: ``(test_loader, n_data)`` where ``n_data`` is the dataset length.

    Raises:
        NotImplementedError: if ``args.view`` is neither ``'Lab'`` nor ``'YCbCr'``.
    """
    data_folder = os.path.join(args.data_folder, 'validation')

    if args.view == 'Lab':
        # Mid-point / half-width of each Lab channel's value range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # ``NotImplemented`` is a sentinel constant, not an exception class;
        # calling it raised TypeError. NotImplementedError is intended.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    # Keep the 256/224 resize-to-crop ratio for any crop size (224 -> 256).
    resize_size = int(round(image_size / 0.875))
    test_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(image_size),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    test_dataset = ImageFolderInstance(data_folder, transform=test_transform)
    test_sampler = None

    # NOTE(review): with no sampler this loader shuffles (unchanged behaviour);
    # evaluation order is therefore random between runs.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=(test_sampler is None),
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              sampler=test_sampler)

    n_data = len(test_dataset)
    print('number of samples: {}'.format(n_data))

    return test_loader, n_data
예제 #9
0
def get_train_val_loader(args):
    """Build train/val DataLoaders for ImageNet-style folders or STL-10.

    ImageNet path (``'imagenet' in args.dataset``): ``train``/``val`` folders.
    Training uses random-resized-crop 224 + flip; validation uses resize 256 +
    centre crop 224. Both use the colour conversion for ``args.view``.

    STL-10 path: requires ``args.dataset == 'stl10'`` and ``args.view == 'Lab'``.
    Training uses random-resized-crop 64 + flip on the ``train`` split;
    validation uses resize 70 + centre crop 64 on the ``test`` split.

    Args:
        args: namespace read for ``dataset``, ``data_folder``, ``view``,
            ``crop_low`` (ImageNet only), ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler)``. ``train_sampler``
        is always ``None``, so the training loader shuffles.

    Raises:
        NotImplementedError: if the ImageNet ``view`` is not ``'Lab'``/``'YCbCr'``.
        AssertionError: if the non-ImageNet dataset/view is not ``stl10``/``Lab``.
    """
    if 'imagenet' in args.dataset:
        train_folder = os.path.join(args.data_folder, 'train')
        val_folder = os.path.join(args.data_folder, 'val')

        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2,
                    (-107.857 + 94.478) / 2]
            std = [(100 - 0) / 2, (86.183 + 98.233) / 2,
                   (107.857 + 94.478) / 2]
            color_transfer = RGB2Lab()
        elif args.view == 'YCbCr':
            mean = [116.151, 121.080, 132.342]
            std = [109.500, 111.855, 111.964]
            color_transfer = RGB2YCbCr()
        else:
            # ``NotImplemented`` is a sentinel constant, not an exception class;
            # calling it raised TypeError. NotImplementedError is intended.
            raise NotImplementedError('view not implemented {}'.format(args.view))

        normalize = transforms.Normalize(mean=mean, std=std)
        train_dataset = datasets.ImageFolder(
            train_folder,
            transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.0)),
                transforms.RandomHorizontalFlip(),
                color_transfer,
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            val_folder,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                color_transfer,
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        assert args.dataset == 'stl10'
        assert args.view == 'Lab'

        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64,
                                         scale=(0.3, 1.0),
                                         ratio=(0.7, 1.4)),
            transforms.RandomHorizontalFlip(),
            RGB2Lab(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(70),
            transforms.CenterCrop(64),
            RGB2Lab(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

        train_dataset = datasets.STL10(args.data_folder,
                                       'train',
                                       transform=train_transform,
                                       download=True)
        val_dataset = datasets.STL10(args.data_folder,
                                     'test',
                                     transform=test_transform,
                                     download=True)

    print('number of train: {}'.format(len(train_dataset)))
    print('number of val: {}'.format(len(val_dataset)))

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    return train_loader, val_loader, train_sampler
예제 #10
0
def get_train_val_loader(args):
    """Create CIFAR-10 train and test DataLoaders for the selected view.

    Training applies a random horizontal flip. Both splits then apply the
    optional colour-space conversion (``'Lab'``/``'YCbCr'`` only), ``ToTensor``
    and normalisation. Colour views use per-channel statistics; any other view
    keeps RGB and uses mean = std = 0.5.

    Args:
        args: namespace read for ``view``, ``data_folder``, ``batch_size`` and
            ``num_workers``.

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler)``. ``train_sampler``
        is always ``None``, so the training loader shuffles.
    """
    if args.view in ('Lab', 'YCbCr'):
        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            centre = [(0 + 100) / 2, (-86.183 + 98.233) / 2,
                      (-107.857 + 94.478) / 2]
            spread = [(100 - 0) / 2, (86.183 + 98.233) / 2,
                      (107.857 + 94.478) / 2]
            colour_steps = [RGB2Lab()]
        else:
            centre = [116.151, 121.080, 132.342]
            spread = [109.500, 111.855, 111.964]
            colour_steps = [RGB2YCbCr()]
        channel_norm = transforms.Normalize(mean=centre, std=spread)
    else:
        print('Use RGB images with %s!' % (args.view))
        colour_steps = []
        channel_norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def build_pipeline(*leading):
        # ``leading`` holds split-specific augmentation placed before the
        # shared colour / tensor / normalisation steps.
        return transforms.Compose(
            [*leading, *colour_steps, transforms.ToTensor(), channel_norm])

    train_dataset = datasets.CIFAR10(
        root=args.data_folder,
        train=True,
        transform=build_pipeline(transforms.RandomHorizontalFlip()),
    )
    val_dataset = datasets.CIFAR10(
        root=args.data_folder,
        train=False,
        transform=build_pipeline(),
    )
    print('number of train: {}'.format(len(train_dataset)))
    print('number of val: {}'.format(len(val_dataset)))

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader, train_sampler
예제 #11
0
def get_train_val_loader(args):
    """Create CIFAR-10 train and test DataLoaders for the selected view.

    Colour views (``'Lab'``/``'YCbCr'``): training uses random crop (32, pad 4)
    + flip, and testing uses the plain conversion. If ``args.oracle`` is not
    ``'original'``, both splits' images are replaced by the ``_4_images.npy``
    arrays under ``<data_folder>/CIFAR-10-C-trainval/``. The oracle ``'scale'``
    maps to the ``upsample`` files.

    Any other view keeps RGB. After ``ToTensor`` it runs the augmentation
    returned by ``create_augmentation(args.view, args.level)`` and then
    normalises every element that augmentation returns (mean = std = 0.5).

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler)``. ``train_sampler``
        is always ``None``, so the training loader shuffles.
    """
    if args.view in ('Lab', 'YCbCr'):
        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            centre = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
            spread = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
            colour_op = RGB2Lab()
        else:
            centre = [116.151, 121.080, 132.342]
            spread = [109.500, 111.855, 111.964]
            colour_op = RGB2YCbCr()
        channel_norm = transforms.Normalize(mean=centre, std=spread)

        train_dataset = datasets.CIFAR10(
            root=args.data_folder,
            train=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                colour_op,
                transforms.ToTensor(),
                channel_norm,
            ]),
        )
        val_dataset = datasets.CIFAR10(
            root=args.data_folder,
            train=False,
            transform=transforms.Compose([
                colour_op,
                transforms.ToTensor(),
                channel_norm,
            ]),
        )

        if args.oracle != 'original':
            stem = 'upsample' if args.oracle == 'scale' else args.oracle
            cifar_c_root = args.data_folder + '/CIFAR-10-C-trainval/'
            train_dataset.data = np.load(cifar_c_root + 'train/%s_4_images.npy' % stem)
            val_dataset.data = np.load(cifar_c_root + 'val/%s_4_images.npy' % stem)
    else:
        print('Use RGB images with %s level %s!' % (args.view, str(args.level)))
        half = (0.5, 0.5, 0.5)

        def normalize_lst(views):
            # Presumably ``views`` is the sequence produced by the augmentation
            # (one tensor per view) - TODO confirm.
            norm = transforms.Normalize(half, half)
            return list(map(norm, views))

        augmentation = create_augmentation(args.view, args.level)

        train_dataset = datasets.CIFAR10(
            root=args.data_folder,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                augmentation,
                normalize_lst,
            ]),
        )
        val_dataset = datasets.CIFAR10(
            root=args.data_folder,
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                augmentation,
                normalize_lst,
            ]),
        )

    print('number of train: {}'.format(len(train_dataset)))
    print('number of val: {}'.format(len(val_dataset)))

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader, train_sampler
예제 #12
0
def get_val_loader(args):
    """Build the CIFAR-10 test-split DataLoader, optionally with CIFAR-10-C images.

    Colour views (``'Lab'``/``'YCbCr'``) convert the colour space and normalise
    per channel. Any other view keeps RGB, runs
    ``create_augmentation(args.view, args.level)`` after ``ToTensor``, and
    normalises each returned element (mean = std = 0.5).

    When ``args.corruption`` is in the known list, the test images are replaced
    by ``<data_folder>/CIFAR-10-C-trainval/val/<name>_<idx>_images.npy``. For
    ``'scale'``, ``name`` is ``'upsample'`` and ``idx`` is ``args.test_level``;
    otherwise ``name`` is the corruption and ``idx`` is ``args.test_level - 1``.

    Returns:
        torch.utils.data.DataLoader: non-shuffled loader over the test split.
    """
    known_corruptions = [
        'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
        'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
        'brightness', 'contrast', 'elastic_transform', 'pixelate',
        'jpeg_compression', 'scale'
    ]
    print('Use %s %s!' % (args.view, str(args.level)))

    def load_corrupted_images():
        # Announce and load the replacement test images for args.corruption.
        print('Test on %s!' % (args.corruption))
        if args.corruption == 'scale':
            name, idx = 'upsample', str(args.test_level)
        else:
            name, idx = args.corruption, str(args.test_level - 1)
        return np.load(args.data_folder
                       + '/CIFAR-10-C-trainval/val/%s_%s_images.npy' % (name, idx))

    if args.view in ('Lab', 'YCbCr'):
        if args.view == 'Lab':
            # Mid-point / half-width of each Lab channel's value range.
            centre = [(0 + 100) / 2, (-86.183 + 98.233) / 2,
                      (-107.857 + 94.478) / 2]
            spread = [(100 - 0) / 2, (86.183 + 98.233) / 2,
                      (107.857 + 94.478) / 2]
            colour_op = RGB2Lab()
        else:
            centre = [116.151, 121.080, 132.342]
            spread = [109.500, 111.855, 111.964]
            colour_op = RGB2YCbCr()
        channel_norm = transforms.Normalize(mean=centre, std=spread)
        pipeline = transforms.Compose([
            colour_op,
            transforms.ToTensor(),
            channel_norm,
        ])
        val_dataset = datasets.CIFAR10(root=args.data_folder,
                                       train=False,
                                       transform=pipeline)
    else:
        half = (0.5, 0.5, 0.5)

        def normalize_lst(views):
            # Presumably ``views`` is the augmentation's per-view output - TODO confirm.
            norm = transforms.Normalize(half, half)
            return list(map(norm, views))

        augmentation = create_augmentation(args.view, args.level)
        pipeline = transforms.Compose(
            [transforms.ToTensor(), augmentation, normalize_lst])
        val_dataset = datasets.CIFAR10(root=args.data_folder,
                                       train=False,
                                       download=True,
                                       transform=pipeline)

    if args.corruption in known_corruptions:
        val_dataset.data = load_corrupted_images()

    print('number of val: {}'.format(len(val_dataset)))

    return torch.utils.data.DataLoader(val_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
예제 #13
0
def get_train_loader(args):
    """Build the ImageFolder training loader with an optional dataset wrapper.

    Images are read from ``<args.data_folder>/train`` with
    ``ImageFolderInstance``. They are randomly resized-cropped to 224, flipped,
    converted to the colour space given by ``args.view``, and normalised.

    When ``args.IM`` is truthy, the dataset is wrapped according to
    ``args.IM_type``: ``'IM'``, ``'global'``, ``'region'``, ``'Cutout'`` or
    ``'RandomErasing'``. Each wrapper reads its own hyper-parameters from
    ``args``. Any other ``IM_type`` leaves the dataset unwrapped.

    Args:
        args: namespace read for ``data_folder``, ``view``, ``crop_low``,
            ``IM``, ``IM_type`` (plus wrapper-specific fields), ``batch_size``
            and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` where ``n_data`` is the length of the
        (possibly wrapped) dataset.

    Raises:
        NotImplementedError: if ``args.view`` is neither ``'Lab'`` nor ``'YCbCr'``.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        # Mid-point / half-width of each Lab channel's value range.
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # ``NotImplemented`` is a sentinel constant, not an exception class;
        # calling it raised TypeError. NotImplementedError is intended.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None
    if args.IM:
        # IM_type takes exactly one value, so at most one wrapper applies.
        if args.IM_type == 'IM':
            print("using IM space.................")
            train_dataset = IM(train_dataset,
                               g_alpha=args.g_alpha,
                               g_num_mix=args.g_num,
                               g_prob=args.g_prob,
                               r_beta=args.r_beta,
                               r_prob=args.r_prob,
                               r_num_mix=args.r_num,
                               r_decay=args.r_pixel_decay)
        elif args.IM_type == 'global':
            print("using global space.................")
            train_dataset = global_(train_dataset,
                                    g_alpha=args.g_alpha,
                                    g_num_mix=args.g_num,
                                    g_prob=args.g_prob)
        elif args.IM_type == 'region':
            print("using region space.................")
            train_dataset = region(train_dataset,
                                   r_beta=args.r_beta,
                                   r_prob=args.r_prob,
                                   r_num_mix=args.r_num,
                                   r_decay=args.r_pixel_decay)
        elif args.IM_type == 'Cutout':
            print("using Cutout aug.................")
            train_dataset = Cutout(train_dataset,
                                   mask_size=args.mask_size,
                                   p=args.cutout_p,
                                   cutout_inside=args.cutout_inside,
                                   mask_color=args.mask_color)
        elif args.IM_type == 'RandomErasing':
            print("using RandomErasing aug.................")
            train_dataset = RandomErasing(
                train_dataset,
                p=args.random_erasing_prob,
                area_ratio_range=args.area_ratio_range,
                min_aspect_ratio=args.min_aspect_ratio,
                max_attempt=args.max_attempt)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data