Example 1
def get_data(dataset,
             dataroot,
             augment,
             resize=608,
             split=0.15,
             split_idx=0,
             multinode=False,
             target_lb=-1):

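    # Base detection transforms: normalize, apply the default Augmenter, and resize so the shorter side equals `resize`.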
    transform_train = transforms.Compose(
        [Normalizer(), Augmenter(),
         Resizer(min_side=resize)])
    transform_test = transforms.Compose(
        [Normalizer(), Resizer(min_side=resize)])

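    # When a searched augmentation policy is configured, decode it and prepend it to the training transform.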
    if isinstance(C.get().aug, list):
        logger.debug('augmentation provided.')
        policies = policy_decoder(augment, augment['num_policy'],
                                  augment['num_op'])
        transform_train.transforms.insert(
            0, Augmentation(policies, detection=True))

    if dataset == 'coco':
        total_trainset = CocoDataset(dataroot,
                                     set_name='train',
                                     transform=transform_train)
        testset = CocoDataset(dataroot,
                              set_name='val',
                              transform=transform_test)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    return total_trainset, testset
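
A minimal usage sketch for get_data, assuming the project's global config C is already loaded and that augment is the policy dictionary consumed by policy_decoder (its exact layout is an assumption; only the 'num_policy' and 'num_op' keys appear above). Paths are placeholders.

# Hypothetical call; the augment contents and the COCO root are illustrative only.
augment = {'num_policy': 5, 'num_op': 2}  # plus whatever policy fields policy_decoder expects
trainset, testset = get_data('coco', '/data/coco', augment, resize=608)
print('train/test sizes:', len(trainset), len(testset))
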
Example 2
    def preprocess(self, dataset='csv', csv_train=None, csv_val=None, csv_classes=None, coco_path=None,
                   train_set_name='train2017', val_set_name='val2017', resize=608):
        self.dataset = dataset
        if self.dataset == 'coco':
            if coco_path is None:
                raise ValueError('Must provide --home_path when training on COCO.')
            self.dataset_train = CocoDataset(self.home_path, set_name=train_set_name,
                                             transform=transforms.Compose(
                                                 [Normalizer(), Augmenter(), Resizer(min_side=resize)]))
            self.dataset_val = CocoDataset(self.home_path, set_name=val_set_name,
                                           transform=transforms.Compose([Normalizer(), Resizer(min_side=resize)]))

        elif self.dataset == 'csv':
            if csv_train is None:
                raise ValueError('Must provide --csv_train when training on a CSV dataset.')
            if csv_classes is None:
                raise ValueError('Must provide --csv_classes when training on a CSV dataset.')
            self.dataset_train = CSVDataset(train_file=csv_train, class_list=csv_classes,
                                            transform=transforms.Compose(
                                                [Normalizer(), Augmenter(), Resizer(min_side=resize)])
                                            )

            if csv_val is None:
                self.dataset_val = None
                print('No validation annotations provided.')
            else:
                self.dataset_val = CSVDataset(train_file=csv_val, class_list=csv_classes,
                                              transform=transforms.Compose([Normalizer(), Resizer(min_side=resize)]))
        else:
            raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

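        # Group training images by aspect ratio so each batch contains similarly shaped images, then build the loaders.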
        sampler = AspectRatioBasedSampler(self.dataset_train, batch_size=2, drop_last=False)
        self.dataloader_train = DataLoader(self.dataset_train, num_workers=0, collate_fn=collater,
                                           batch_sampler=sampler)
        if self.dataset_val is not None:
            sampler_val = AspectRatioBasedSampler(self.dataset_val, batch_size=1, drop_last=False)
            self.dataloader_val = DataLoader(self.dataset_val, num_workers=3, collate_fn=collater,
                                             batch_sampler=sampler_val)
        else:
            self.dataloader_val = None

        print('Num training images: {}'.format(len(self.dataset_train)))
        if self.dataset_val is not None:
            if len(self.dataset_val) == 0:
                raise Exception('num val images is 0!')
            print('Num val images: {}'.format(len(self.dataset_val)))
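
A hedged sketch of how this method might be driven. The host class name RetinaNetTrainer and its home_path attribute are assumptions made for illustration; the snippet only shows the method itself.

trainer = RetinaNetTrainer(home_path='/data/coco')  # hypothetical host class that defines preprocess() and self.home_path
trainer.preprocess(dataset='coco', coco_path='/data/coco',
                   train_set_name='train2017', val_set_name='val2017', resize=608)
print('train batches:', len(trainer.dataloader_train))
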
Example 3
def get_dataloaders(dataset,
                    batch,
                    dataroot,
                    resize=608,
                    split=0.15,
                    split_idx=0,
                    multinode=False,
                    target_lb=-1):
    multilabel = False
    detection = False
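    # Build the train/test transform pipelines for the requested dataset family.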
    if 'coco' in dataset:
        transform_train = transforms.Compose(
            [Normalizer(), Augmenter(),
             Resizer(min_side=resize)])
        transform_test = transforms.Compose(
            [Normalizer(), Resizer(min_side=resize)])

        multilabel = True
        detection = True
    elif 'cifar' in dataset or 'svhn' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
    elif 'imagenet' in dataset:
        input_size = 224
        sized_size = 256

        if 'efficientnet' in C.get()['model']:
            input_size = EfficientNet.get_image_size(C.get()['model'])
            sized_size = input_size + 32  # TODO
            # sized_size = int(round(input_size / 224. * 256))
            # sized_size = input_size
            logger.info('size changed to %d/%d.' % (input_size, sized_size))

        transform_train = transforms.Compose([
            EfficientNetRandomCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
            ),
            transforms.ToTensor(),
            Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            EfficientNetCenterCrop(input_size),
            transforms.Resize((input_size, input_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    else:
        raise ValueError('dataset=%s' % dataset)

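    # Prepend the configured augmentation policy (a searched list or a named preset) to the training transform.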
    total_aug = augs = None
    if isinstance(C.get().aug, list):
        logger.debug('augmentation provided.')
        transform_train.transforms.insert(
            0, Augmentation(C.get().aug, detection=detection))
    else:
        logger.debug('augmentation: %s' % C.get().aug)
        if C.get().aug == 'fa_reduced_cifar10':
            transform_train.transforms.insert(
                0, Augmentation(fa_reduced_cifar10()))

        elif C.get().aug == 'fa_reduced_imagenet':
            transform_train.transforms.insert(
                0, Augmentation(fa_resnet50_rimagenet()))

        elif C.get().aug == 'fa_reduced_svhn':
            transform_train.transforms.insert(0,
                                              Augmentation(fa_reduced_svhn()))

        elif C.get().aug == 'arsaug':
            transform_train.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get().aug == 'autoaug_cifar10':
            transform_train.transforms.insert(
                0, Augmentation(autoaug_paper_cifar10()))
        elif C.get().aug == 'autoaug_extend':
            transform_train.transforms.insert(0,
                                              Augmentation(autoaug_policy()))
        elif C.get().aug in ['default']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get().aug)

    if C.get()['cutout'] > 0:
        transform_train.transforms.append(CutoutDefault(C.get()['cutout']))

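    # Instantiate the train and test datasets; the reduced_* variants subsample the full data with a stratified split.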
    if dataset == 'coco':
        total_trainset = CocoDataset(dataroot,
                                     set_name='train2017',
                                     transform=transform_train)
        testset = CocoDataset(dataroot,
                              set_name='val2017',
                              transform=transform_test)

    elif dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'reduced_cifar10':
        total_trainset = torchvision.datasets.CIFAR10(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=46000,
                                     random_state=0)  # 4000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.CIFAR10(root=dataroot,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(
            root=dataroot,
            train=True,
            download=True,
            transform=transform_train)
        testset = torchvision.datasets.CIFAR100(root=dataroot,
                                                train=False,
                                                download=True,
                                                transform=transform_test)
    elif dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root=dataroot,
                                             split='train',
                                             download=True,
                                             transform=transform_train)
        extraset = torchvision.datasets.SVHN(root=dataroot,
                                             split='extra',
                                             download=True,
                                             transform=transform_train)
        total_trainset = ConcatDataset([trainset, extraset])
        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'reduced_svhn':
        total_trainset = torchvision.datasets.SVHN(root=dataroot,
                                                   split='train',
                                                   download=True,
                                                   transform=transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=73257 - 1000,
                                     random_state=0)  # 1000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)
        targets = [total_trainset.targets[idx] for idx in train_idx]
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        testset = torchvision.datasets.SVHN(root=dataroot,
                                            split='test',
                                            download=True,
                                            transform=transform_test)
    elif dataset == 'imagenet':
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train,
                                  download=True)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]
    elif dataset == 'reduced_imagenet':
        # randomly chosen indices
        #         idx120 = sorted(random.sample(list(range(1000)), k=120))
        idx120 = [
            16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181,
            189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304,
            307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388,
            399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497,
            506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619,
            626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695,
            699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787,
            797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883,
            897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949,
            959
        ]
        total_trainset = ImageNet(root=os.path.join(dataroot,
                                                    'imagenet-pytorch'),
                                  transform=transform_train)
        testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'),
                           split='val',
                           transform=transform_test)

        # compatibility
        total_trainset.targets = [lb for _, lb in total_trainset.samples]

        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=len(total_trainset) - 50000,
                                     random_state=0)  # 50000 trainset
        sss = sss.split(list(range(len(total_trainset))),
                        total_trainset.targets)
        train_idx, valid_idx = next(sss)

        # filter out
        train_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, train_idx))
        valid_idx = list(
            filter(lambda x: total_trainset.labels[x] in idx120, valid_idx))
        test_idx = list(
            filter(lambda x: testset.samples[x][1] in idx120,
                   range(len(testset))))

        targets = [
            idx120.index(total_trainset.targets[idx]) for idx in train_idx
        ]
        for idx in range(len(total_trainset.samples)):
            if total_trainset.samples[idx][1] not in idx120:
                continue
            total_trainset.samples[idx] = (total_trainset.samples[idx][0],
                                           idx120.index(
                                               total_trainset.samples[idx][1]))
        total_trainset = Subset(total_trainset, train_idx)
        total_trainset.targets = targets

        for idx in range(len(testset.samples)):
            if testset.samples[idx][1] not in idx120:
                continue
            testset.samples[idx] = (testset.samples[idx][0],
                                    idx120.index(testset.samples[idx][1]))
        testset = Subset(testset, test_idx)
        print('reduced_imagenet train=', len(total_trainset))
    else:
        raise ValueError('invalid dataset name=%s' % dataset)

    if total_aug is not None and augs is not None:
        total_trainset.set_preaug(augs, total_aug)
        print('set_preaug-')

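    # Carve a validation split out of the training set and wrap the resulting index lists in samplers.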
    train_sampler = None
    if split > 0.0:
        if multilabel:
            # not sure how important stratified is, especially for val,
            # might want to test with and without and add it for multilabel in the future
            sss = ShuffleSplit(n_splits=5, test_size=split, random_state=0)
            sss = sss.split(list(range(len(total_trainset))))
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)
        else:
            sss = StratifiedShuffleSplit(n_splits=5,
                                         test_size=split,
                                         random_state=0)
            sss = sss.split(list(range(len(total_trainset))),
                            total_trainset.targets)
            for _ in range(split_idx + 1):
                train_idx, valid_idx = next(sss)

        if target_lb >= 0:
            train_idx = [
                i for i in train_idx if total_trainset.targets[i] == target_lb
            ]
            valid_idx = [
                i for i in valid_idx if total_trainset.targets[i] == target_lb
            ]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                Subset(total_trainset, train_idx),
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
    else:
        valid_sampler = SubsetSampler([])

        if multinode:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                total_trainset,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank())
            logger.info(
                f'----- dataset with DistributedSampler  {dist.get_rank()}/{dist.get_world_size()}'
            )

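    # Build the final loaders; validloader reads the training set through valid_sampler, testloader uses the held-out test set.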
    trainloader = DataLoader(total_trainset,
                             batch_size=batch,
                             shuffle=train_sampler is None,
                             num_workers=8,
                             pin_memory=True,
                             sampler=train_sampler,
                             drop_last=True,
                             collate_fn=collater)
    validloader = DataLoader(total_trainset,
                             batch_size=batch,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=True,
                             sampler=valid_sampler,
                             drop_last=False,
                             collate_fn=collater)

    testloader = DataLoader(testset,
                            batch_size=batch,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=False,
                            collate_fn=collater)

    return train_sampler, trainloader, validloader, testloader
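
An illustrative call of get_dataloaders, assuming the global config C was loaded beforehand so that C.get().aug and C.get()['cutout'] resolve; the path and batch size are placeholders.

train_sampler, trainloader, validloader, testloader = get_dataloaders(
    'coco', batch=2, dataroot='/data/coco', split=0.15, split_idx=0)
print('batches:', len(trainloader), len(validloader), len(testloader))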