import torchvision.transforms as trns
from torchvision.datasets import MNIST

# Project-specific dataset and wrapper classes used below (OmniglotDataset,
# OxfordFlowersDataset, OxfordPetsDataset, StanfordDogsDataset, PascalVOCDataset,
# CombinedDataset, DetectionWrapper) are assumed to be importable from the
# surrounding project.


def initialize_dataset(config, dataset_name, dataset_id, split, input_size, mean, std):
    """Build and return the dataset selected by `dataset_name` and `dataset_id` for the given `split`."""

    if dataset_name == 'omniglot':
        if dataset_id == '00':  # default
            # todo add transforms here?
            return OmniglotDataset(root_dir=config.dataset.root_dir,
                                   split=split)
    elif dataset_name == 'flowers':
        if dataset_id == '00':  # default
            # Setup Transforms instead of doing in the specific dataset class
            transforms = trns.Compose([trns.Resize((input_size, input_size)),
                                       trns.RandomHorizontalFlip(),
                                       trns.ToTensor(),
                                       trns.Normalize(mean=mean, std=std)  # normalise with model zoo
                                       ])

            return OxfordFlowersDataset(root_dir=config.dataset.root_dir,
                                        split=split,
                                        transform=transforms,
                                        categories_subset=config.dataset.classes)
    elif dataset_name == 'pets':
        if dataset_id == '00':  # default
            # Setup Transforms instead of doing in the specific dataset class
            transforms = trns.Compose([trns.Resize((input_size, input_size)),
                                       trns.RandomHorizontalFlip(),
                                       trns.ToTensor(),
                                       trns.Normalize(mean=mean, std=std)  # normalise with model zoo
                                       ])
            if split == 'train':
                split = 'trainval'
            return OxfordPetsDataset(root_dir=config.dataset.root_dir,
                                     split=split,
                                     transform=transforms,
                                     categories_subset=config.dataset.classes)
    elif dataset_name == 'dogs':
        if dataset_id == '00':  # default
            # Setup Transforms instead of doing in the specific dataset class
            transforms = trns.Compose([trns.Resize((input_size, input_size)),
                                       trns.ToTensor(),
                                       trns.Normalize(mean=mean, std=std)  # normalise with model zoo
                                       ])

            return StanfordDogsDataset(root_dir=config.dataset.root_dir,
                                       split=split,
                                       transform=transforms,
                                       categories_subset=config.dataset.classes)

    elif dataset_name == 'mnist':
        if dataset_id == '00':

            transforms = trns.Compose([trns.ToTensor()])

            if split == 'train':
                dset = MNIST(root=config.dataset.root_dir,
                             train=True,
                             transform=transforms,
                             download=True)
                dset.labels = dset.train_labels

            else:
                dset = MNIST(root=config.dataset.root_dir,
                             train=False,
                             transform=transforms,
                             download=True)
                dset.labels = dset.test_labels
            return dset

    elif dataset_name == 'voc':
        if dataset_id == '00':
            transforms = None
            if split == 'train':

                pv07 = PascalVOCDataset(root_dir=config.dataset.root_dir,
                                        split='trainval',
                                        year='2007',
                                        transform=transforms,
                                        categories_subset=config.dataset.classes,
                                        use_flipped=config.dataset.use_flipped,
                                        use_difficult=config.dataset.use_difficult)

                detection_set = DetectionWrapper(dataset=pv07,
                                                 batch_size=config.train.batch_size,
                                                 max_num_box=config.model.max_n_gt_boxes,
                                                 scales=config.train.scales,
                                                 max_size=config.train.max_size,
                                                 use_all_gt=config.train.use_all_gt,
                                                 training=True)

                return detection_set

            elif split == 'val':
                pv07 = PascalVOCDataset(root_dir=config.dataset.root_dir,
                                        split='test',
                                        year='2007',
                                        transform=transforms,
                                        categories_subset=config.dataset.classes,
                                        use_flipped=False,  # config.dataset.use_flipped,
                                        use_difficult=config.dataset.use_difficult)


                detection_set = DetectionWrapper(dataset=pv07,
                                                 batch_size=config.train.batch_size,
                                                 max_num_box=config.model.max_n_gt_boxes,
                                                 scales=config.train.scales,
                                                 max_size=config.train.max_size,
                                                 use_all_gt=config.train.use_all_gt,
                                                 training=True)  # False)

                return detection_set

            else:
                raise ValueError("Split '%s' not recognised for the %s dataset (id: %s)." % (split, dataset_name, dataset_id))
        elif dataset_id == '01':
            transforms = None
            if split == 'train':

                pv07 = PascalVOCDataset(root_dir=config.dataset.root_dir,
                                        split='trainval',
                                        year='2007',
                                        transform=transforms,
                                        categories_subset=config.dataset.classes,
                                        use_flipped=config.dataset.use_flipped,
                                        use_difficult=config.dataset.use_difficult)
                pv12 = PascalVOCDataset(root_dir=config.dataset.root_dir,
                                        split='trainval',
                                        year='2012',
                                        transform=transforms,
                                        categories_subset=config.dataset.classes,
                                        use_flipped=config.dataset.use_flipped,
                                        use_difficult=config.dataset.use_difficult)

                combined_set = CombinedDataset([pv07, pv12])

                detection_set = DetectionWrapper(dataset=combined_set,
                                                 batch_size=config.train.batch_size,
                                                 max_num_box=config.model.max_n_gt_boxes,
                                                 scales=config.train.scales,
                                                 max_size=config.train.max_size,
                                                 use_all_gt=config.train.use_all_gt,
                                                 training=True)

                return detection_set

            elif split == 'val':
                pv07 = PascalVOCDataset(root_dir=config.dataset.root_dir,
                                        split='test',
                                        year='2007',
                                        transform=transforms,
                                        categories_subset=config.dataset.classes,
                                        use_flipped=False,  # config.dataset.use_flipped,
                                        use_difficult=config.dataset.use_difficult)

                detection_set = DetectionWrapper(dataset=pv07,
                                                 batch_size=config.train.batch_size,
                                                 max_num_box=config.model.max_n_gt_boxes,
                                                 scales=config.train.scales,
                                                 max_size=config.train.max_size,
                                                 use_all_gt=config.train.use_all_gt,
                                                 training=True)  # False)

                return detection_set

            else:
                raise ValueError(
                    "Split '%s' not recognised for the %s dataset (id: %s)." % (split, dataset_name, dataset_id))
    else:
        raise ValueError("Dataset '%s' not recognised." % dataset_name)

    # A recognised dataset name combined with an unrecognised id would otherwise
    # fall through and silently return None.
    raise ValueError("Dataset id '%s' not recognised for the '%s' dataset." % (dataset_id, dataset_name))
Example 2
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN

# FASHION, SubsetSequentialSamplerSPLDML, SubsetSequentialSampler and
# StratifiedSampler are project-specific classes assumed to be importable
# from the surrounding project.


def load_dataset(args):
    '''
    Loads the dataset named by args.dataset and returns the train/test loaders,
    the underlying datasets, and the number of training samples.
    '''

    # MNIST dataset
    if args.dataset == 'mnist':
        trans_img = transforms.Compose([transforms.ToTensor()])

        print("Downloading MNIST data...")
        trainset = MNIST('./data',
                         train=True,
                         transform=trans_img,
                         download=True)
        testset = MNIST('./data',
                        train=False,
                        transform=trans_img,
                        download=True)

    # CIFAR-10 dataset
    elif args.dataset == 'cifar10':
        # Data
        print('==> Preparing data..')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = CIFAR10(root='./data',
                           train=True,
                           transform=transform_train,
                           download=True)
        testset = CIFAR10(root='./data',
                          train=False,
                          transform=transform_test,
                          download=True)

    elif args.dataset == 'cifar100':
        # Data
        print('==> Preparing data..')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = CIFAR100(root='./data',
                            train=True,
                            transform=transform_train,
                            download=True)
        testset = CIFAR100(root='./data',
                           train=False,
                           transform=transform_test,
                           download=True)

    elif args.dataset == 'fashionmnist':

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        trainset = FASHION(root='./data',
                           train=True,
                           transform=transform,
                           download=True)
        testset = FASHION(root='./data',
                          train=False,
                          transform=transform,
                          download=True)

    elif args.dataset == 'svhn':
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
            std=[x / 255.0 for x in [50.1, 50.6, 50.8]])

        train_transform = transforms.Compose([transforms.ToTensor(), normalize])

        trainset = SVHN(root='./data',
                        split='train',
                        transform=train_transform,
                        download=True)

        extra_dataset = SVHN(root='./data',
                             split='extra',
                             transform=train_transform,
                             download=True)

        # Combine both training splits, as is common practice for SVHN

        data = np.concatenate([trainset.data, extra_dataset.data], axis=0)
        labels = np.concatenate([trainset.labels, extra_dataset.labels],
                                axis=0)

        trainset.data = data
        trainset.labels = labels

        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        testset = SVHN(root='./data',
                       split='test',
                       transform=test_transform,
                       download=True)

    else:
        raise ValueError("Dataset '%s' not recognised." % args.dataset)

    # Self-Paced Learning Enabled
    if args.spld:
        n_train = len(trainset)
        train_sampler = SubsetSequentialSamplerSPLDML(range(len(trainset)),
                                                      range(len(trainset)))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)
    elif args.spldml:
        n_train = len(trainset)
        train_sampler = SubsetSequentialSamplerSPLDML(range(len(trainset)),
                                                      range(args.batch_size))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=1)
    # Deep Metric Learning
    elif args.dml:
        n_train = len(trainset)
        train_sampler = SubsetSequentialSampler(range(len(trainset)),
                                                range(args.batch_size))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=1)
    elif args.stratified:
        n_train = len(trainset)
        # Newer torchvision datasets expose labels as `targets` rather than `train_labels`.
        labels = getattr(trainset, 'train_labels', getattr(trainset, 'targets', None))

        if isinstance(labels, list):
            labels = torch.FloatTensor(np.array(labels))

        train_sampler = StratifiedSampler(labels, args.batch_size)
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)
    # Random sampling
    else:
        n_train = len(trainset)
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    return trainloader, testloader, trainset, testset, n_train
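
A minimal usage sketch for load_dataset, assuming an argparse-style namespace with the fields the function reads; the chosen values are illustrative only.

# Hypothetical usage sketch: `args` stands in for the argparse namespace that
# the function expects (dataset name, batch size, and sampler flags).
from argparse import Namespace

args = Namespace(dataset='cifar10', batch_size=128,
                 spld=False, spldml=False, dml=False, stratified=False)

trainloader, testloader, trainset, testset, n_train = load_dataset(args)
print("Number of training samples:", n_train)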