Example #1
0
    def __init__(self,
                 root='',
                 sources=None,
                 targets=None,
                 height=256,
                 width=128,
                 transforms='random_flip',
                 norm_mean=None,
                 norm_std=None,
                 use_gpu=True,
                 split_id=0,
                 combineall=False,
                 batch_size_train=3,
                 batch_size_test=3,
                 workers=4,
                 num_instances=4,
                 num_cams=1,
                 num_datasets=1,
                 train_sampler='RandomSampler',
                 seq_len=15,
                 sample_method='evenly'):
        """Build the training loader and per-target query/gallery loaders
        for video re-id datasets, then print a dataset summary.
        """
        super(VideoDataManager, self).__init__(sources=sources,
                                               targets=targets,
                                               height=height,
                                               width=width,
                                               transforms=transforms,
                                               norm_mean=norm_mean,
                                               norm_std=norm_std,
                                               use_gpu=use_gpu)

        print('=> Loading train (source) dataset')
        # Concatenate every source dataset into one training set; sum()
        # relies on the dataset type defining addition as concatenation.
        train_set = sum([
            init_video_dataset(name,
                               transform=self.transform_tr,
                               mode='train',
                               combineall=combineall,
                               root=root,
                               split_id=split_id,
                               seq_len=seq_len,
                               sample_method=sample_method)
            for name in self.sources
        ])

        self._num_train_pids = train_set.num_train_pids
        self._num_train_cams = train_set.num_train_cams

        # Identity-aware sampler used instead of DataLoader shuffling.
        sampler = build_train_sampler(train_set.train,
                                      train_sampler,
                                      batch_size=batch_size_train,
                                      num_instances=num_instances,
                                      num_cams=num_cams,
                                      num_datasets=num_datasets)

        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            sampler=sampler,
            batch_size=batch_size_train,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True)

        print('=> Loading test (target) dataset')
        self.test_loader = {}
        self.test_dataset = {}
        for name in self.targets:
            self.test_loader[name] = {'query': None, 'gallery': None}
            self.test_dataset[name] = {'query': None, 'gallery': None}

        for name in self.targets:
            # Query split (evaluation transform).
            query_set = init_video_dataset(name,
                                           transform=self.transform_te,
                                           mode='query',
                                           combineall=combineall,
                                           root=root,
                                           split_id=split_id,
                                           seq_len=seq_len,
                                           sample_method=sample_method)
            self.test_loader[name]['query'] = torch.utils.data.DataLoader(
                query_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False)

            # Gallery split; verbose=False so the dataset summary is not
            # printed a second time for the same target.
            gallery_set = init_video_dataset(name,
                                             transform=self.transform_te,
                                             mode='gallery',
                                             combineall=combineall,
                                             verbose=False,
                                             root=root,
                                             split_id=split_id,
                                             seq_len=seq_len,
                                             sample_method=sample_method)
            self.test_loader[name]['gallery'] = torch.utils.data.DataLoader(
                gallery_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False)

            # Raw sample lists are kept for evaluation code.
            self.test_dataset[name]['query'] = query_set.query
            self.test_dataset[name]['gallery'] = gallery_set.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  source             : {}'.format(self.sources))
        print('  # source datasets  : {}'.format(len(self.sources)))
        print('  # source ids       : {}'.format(self.num_train_pids))
        print('  # source tracklets : {}'.format(len(train_set)))
        print('  # source cameras   : {}'.format(self.num_train_cams))
        print('  target             : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')
Example #2
0
    def __init__(self,
                 root='',
                 sources=None,
                 targets=None,
                 height=256,
                 width=128,
                 transforms='random_flip',
                 k_tfm=1,
                 norm_mean=None,
                 norm_std=None,
                 use_gpu=True,
                 split_id=0,
                 combineall=False,
                 load_train_targets=False,
                 batch_size_train=32,
                 batch_size_test=32,
                 workers=4,
                 num_instances=4,
                 num_cams=1,
                 num_datasets=1,
                 train_sampler='RandomSampler',
                 train_sampler_t='RandomSampler',
                 cuhk03_labeled=False,
                 cuhk03_classic_split=False,
                 market1501_500k=False):
        """Build the source training loader, an optional unlabeled target
        training loader, and per-target query/gallery test loaders for
        image re-id datasets, then print a dataset summary.
        """
        super(ImageDataManager, self).__init__(sources=sources,
                                               targets=targets,
                                               height=height,
                                               width=width,
                                               transforms=transforms,
                                               norm_mean=norm_mean,
                                               norm_std=norm_std,
                                               use_gpu=use_gpu)

        print('=> Loading train (source) dataset')
        # Concatenate all source datasets; sum() relies on the dataset type
        # defining addition as concatenation.
        train_set = sum([
            init_image_dataset(
                name,
                transform=self.transform_tr,
                k_tfm=k_tfm,
                mode='train',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k)
            for name in self.sources
        ])

        self._num_train_pids = train_set.num_train_pids
        self._num_train_cams = train_set.num_train_cams

        # Identity-aware sampler replaces plain shuffling.
        source_sampler = build_train_sampler(train_set.train,
                                             train_sampler,
                                             batch_size=batch_size_train,
                                             num_instances=num_instances,
                                             num_cams=num_cams,
                                             num_datasets=num_datasets)
        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            sampler=source_sampler,
            batch_size=batch_size_train,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True,
            collate_fn=train_collate_fn)

        self.train_loader_t = None
        if load_train_targets:
            # Sources and targets must be disjoint when target training
            # data is also loaded.
            assert len(set(self.sources) & set(self.targets)) == 0, \
                'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)

            print('=> Loading train (target) dataset')
            target_train_set = sum([
                init_image_dataset(
                    name,
                    transform=self.transform_tr,
                    k_tfm=k_tfm,
                    mode='train',
                    combineall=False,  # only use the training data
                    root=root,
                    split_id=split_id,
                    cuhk03_labeled=cuhk03_labeled,
                    cuhk03_classic_split=cuhk03_classic_split,
                    market1501_500k=market1501_500k)
                for name in self.targets
            ])

            target_sampler = build_train_sampler(target_train_set.train,
                                                 train_sampler_t,
                                                 batch_size=batch_size_train,
                                                 num_instances=num_instances,
                                                 num_cams=num_cams,
                                                 num_datasets=num_datasets)
            self.train_loader_t = torch.utils.data.DataLoader(
                target_train_set,
                sampler=target_sampler,
                batch_size=batch_size_train,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=True,
                collate_fn=train_collate_fn)

        print('=> Loading test (target) dataset')
        self.test_loader = {}
        self.test_dataset = {}
        for name in self.targets:
            self.test_loader[name] = {'query': None, 'gallery': None}
            self.test_dataset[name] = {'query': None, 'gallery': None}

        for name in self.targets:
            # Query split (evaluation transform).
            query_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k)
            self.test_loader[name]['query'] = torch.utils.data.DataLoader(
                query_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False,
                collate_fn=val_collate_fn)

            # Gallery split; verbose=False to avoid a duplicate summary.
            gallery_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                verbose=False,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k)
            self.test_loader[name]['gallery'] = torch.utils.data.DataLoader(
                gallery_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False,
                collate_fn=val_collate_fn)

            # Raw sample lists are kept for evaluation code.
            self.test_dataset[name]['query'] = query_set.query
            self.test_dataset[name]['gallery'] = gallery_set.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  source            : {}'.format(self.sources))
        print('  # source datasets : {}'.format(len(self.sources)))
        print('  # source ids      : {}'.format(self.num_train_pids))
        print('  # source images   : {}'.format(len(train_set)))
        print('  # source cameras  : {}'.format(self.num_train_cams))
        if load_train_targets:
            print('  # target images   : {} (unlabeled)'.format(
                len(target_train_set)))
        print('  target            : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')
Example #3
0
    def __init__(self, root='', sources=None, targets=None, height=256,
                 width=128, transforms='random_flip', norm_mean=None,
                 norm_std=None, use_gpu=True, split_id=0, combineall=False,
                 batch_size_train=32, batch_size_test=32, workers=4,
                 num_instances=4, train_sampler='', cuhk03_labeled=False,
                 cuhk03_classic_split=False, market1501_500k=False):
        """Assemble the training loader and per-target query/gallery loaders
        for image re-id datasets, then print a dataset summary.
        """
        super(ImageDataManager, self).__init__(
            sources=sources, targets=targets, height=height, width=width,
            transforms=transforms, norm_mean=norm_mean, norm_std=norm_std,
            use_gpu=use_gpu)

        print('=> Loading train (source) dataset')
        # Merge every source dataset into one training set; sum() relies on
        # the dataset type defining addition as concatenation.
        train_set = sum([
            init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            for name in self.sources
        ])

        self._num_train_pids = train_set.num_train_pids
        self._num_train_cams = train_set.num_train_cams

        # Identity-aware sampler used instead of DataLoader shuffling.
        sampler = build_train_sampler(
            train_set.train, train_sampler,
            batch_size=batch_size_train,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            train_set,
            sampler=sampler,
            batch_size=batch_size_train,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.testloader = {}
        self.testdataset = {}
        for name in self.targets:
            self.testloader[name] = {'query': None, 'gallery': None}
            self.testdataset[name] = {'query': None, 'gallery': None}

        for name in self.targets:
            # Query split (evaluation transform).
            query_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                query_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # Gallery split; verbose=False to avoid a duplicate summary.
            gallery_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                verbose=False,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                gallery_set,
                batch_size=batch_size_test,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # Raw sample lists are kept for evaluation code.
            self.testdataset[name]['query'] = query_set.query
            self.testdataset[name]['gallery'] = gallery_set.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train            : {}'.format(self.sources))
        print('  # train datasets : {}'.format(len(self.sources)))
        print('  # train ids      : {}'.format(self.num_train_pids))
        print('  # train images   : {}'.format(len(train_set)))
        print('  # train cameras  : {}'.format(self.num_train_cams))
        print('  test             : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')
Example #4
0
    def __init__(self, root='', sources=None, targets=None, height=256,
                 width=128, random_erase=False, color_jitter=False,
                 color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler='',
                 cuhk03_labeled=False, cuhk03_classic_split=False,
                 market1501_500k=False, val_split=0.15):
        """Build train, validation, and per-target query/gallery loaders for
        image re-id datasets, then print a dataset summary.

        The validation set is a deep copy of the training set so both share
        the same train/val split.
        """
        super(ImageDataManager, self).__init__(
            sources=sources, targets=targets, height=height, width=width,
            random_erase=random_erase, color_jitter=color_jitter,
            color_aug=color_aug, use_cpu=use_cpu)

        ### Training data
        print('=> Loading train (source) dataset')
        # Merge every source dataset into one training set; sum() relies on
        # the dataset type defining addition as concatenation.
        train_set = sum([
            init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                val_split=val_split,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            for name in self.sources
        ])

        self._num_train_pids = train_set.num_train_pids
        self._num_train_cams = train_set.num_train_cams

        # Identity-aware sampler used instead of DataLoader shuffling.
        sampler = build_train_sampler(
            train_set.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            train_set,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        ### Validation data
        print('=> Loading validation dataset')

        # Deep-copy the training set so the validation set reuses the exact
        # same train/val split, then repoint it at the validation samples.
        val_set = copy.deepcopy(train_set)
        val_set.transform = self.transform_te
        val_set.mode = 'validation'
        val_set.data = val_set.validation

        self.validationloader = torch.utils.data.DataLoader(
            val_set,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

        ### Test data
        print('=> Loading test (target) dataset')
        self.testloader = {}
        self.testdataset = {}
        for name in self.targets:
            self.testloader[name] = {'query': None, 'gallery': None}
            self.testdataset[name] = {'query': None, 'gallery': None}

        # The train/val split is irrelevant for the test-only datasets below,
        # so verbose=False keeps the printed information from being confusing.
        for name in self.targets:
            # Query split (evaluation transform).
            query_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                val_split=val_split,
                verbose=False,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                query_set,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # Gallery split.
            gallery_set = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                val_split=val_split,
                verbose=False,
                root=root,
                split_id=split_id,
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                gallery_set,
                batch_size=batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=self.use_gpu,
                drop_last=False
            )

            # Raw sample lists are kept for evaluation code.
            self.testdataset[name]['query'] = query_set.query
            self.testdataset[name]['gallery'] = gallery_set.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train            : {}'.format(self.sources))
        print('  # train datasets : {}'.format(len(self.sources)))
        print('  # train ids      : {}'.format(self.num_train_pids))
        print('  # train images   : {}'.format(len(train_set)))
        print('  # train cameras  : {}'.format(self.num_train_cams))
        print('  test             : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')
    def __init__(
        self,
        root='',
        sources=None,
        targets=None,
        height=256,
        width=128,
        enable_masks=False,
        transforms='random_flip',
        norm_mean=None,
        norm_std=None,
        use_gpu=True,
        split_id=0,
        combineall=False,
        batch_size_train=32,
        batch_size_test=32,
        correct_batch_size=False,
        workers=4,
        train_sampler='RandomSampler',
        batch_num_instances=4,
        epoch_num_instances=-1,
        fill_instances=False,
        cuhk03_labeled=False,
        cuhk03_classic_split=False,
        market1501_500k=False,
        custom_dataset_names=None,
        custom_dataset_roots=None,
        custom_dataset_types=None,
        apply_masks_to_test=False,
        min_samples_per_id=0,
        num_sampled_packages=1,
        filter_classes=None,
    ):
        """Build the training loader and per-target test loaders for image
        re-id datasets, with support for masks, custom datasets, and the
        special pair-based 'lfw' target.

        Fix: ``custom_dataset_names``/``custom_dataset_roots``/
        ``custom_dataset_types`` previously used the mutable default ``['']``,
        a single list object shared across all calls. They now default to
        ``None`` and are normalized to a fresh ``['']`` per call, preserving
        the original effective behavior.
        """
        # Normalize mutable defaults (backward compatible with the old ['']).
        if custom_dataset_names is None:
            custom_dataset_names = ['']
        if custom_dataset_roots is None:
            custom_dataset_roots = ['']
        if custom_dataset_types is None:
            custom_dataset_types = ['']

        super(ImageDataManager, self).__init__(
            sources=sources,
            targets=targets,
            height=height,
            width=width,
            transforms=transforms,
            norm_mean=norm_mean,
            norm_std=norm_std,
            use_gpu=use_gpu,
            apply_masks_to_test=apply_masks_to_test
        )

        print('=> Loading train (source) dataset')
        # Map each source name to a dataset id so merged samples stay traceable.
        train_dataset_ids_map = self.build_dataset_map(self.source_groups)

        train_dataset = []
        for name in self.sources:
            train_dataset.append(init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                root=root,
                split_id=split_id,
                load_masks=enable_masks,
                dataset_id=train_dataset_ids_map[name],
                cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split,
                market1501_500k=market1501_500k,
                custom_dataset_names=custom_dataset_names,
                custom_dataset_roots=custom_dataset_roots,
                custom_dataset_types=custom_dataset_types,
                min_id_samples=min_samples_per_id,
                num_sampled_packages=num_sampled_packages,
                filter_classes=filter_classes,
            ))
        # sum() relies on the dataset type defining addition as concatenation.
        train_dataset = sum(train_dataset)

        # Per-dataset statistics, kept as ordered lists (one entry per source).
        self._data_counts = self.to_ordered_list(train_dataset.data_counts)
        self._num_train_pids = self.to_ordered_list(train_dataset.num_train_pids)
        self._num_train_cams = self.to_ordered_list(train_dataset.num_train_cams)
        assert isinstance(self._num_train_pids, list)
        assert isinstance(self._num_train_cams, list)
        assert len(self._num_train_pids) == len(self._num_train_cams)

        if correct_batch_size:
            # presumably rescales the batch to the dataset size -- see
            # calculate_batch; TODO confirm against its definition
            batch_size_train = self.calculate_batch(batch_size_train, len(train_dataset))
        # Clamp so the batch is at least 1 and never exceeds the dataset.
        batch_size_train = max(1, min(batch_size_train, len(train_dataset)))

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=build_train_sampler(
                train_dataset.train,
                train_sampler,
                batch_size=batch_size_train,
                batch_num_instances=batch_num_instances,
                epoch_num_instances=epoch_num_instances,
                fill_instances=fill_instances,
            ),
            batch_size=batch_size_train,
            shuffle=False,
            worker_init_fn=worker_init_fn,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )
        self.num_iter = len(self.train_loader)

        print('=> Loading test (target) dataset')
        self.test_loader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.test_dataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            if name == 'lfw':
                # LFW is evaluated on verification pairs rather than
                # query/gallery splits.
                lfw_data = init_image_dataset(
                    name,
                    transform=self.transform_te,
                    root=root,
                )
                self.test_loader[name]['pairs'] = torch.utils.data.DataLoader(
                    lfw_data,
                    batch_size=max(min(batch_size_test, len(lfw_data)), 1),
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=self.use_gpu,
                    worker_init_fn=worker_init_fn,
                    drop_last=False
                )
            else:
                # build query loader
                query_dataset = init_image_dataset(
                    name,
                    transform=self.transform_te,
                    mode='query',
                    combineall=combineall,
                    root=root,
                    split_id=split_id,
                    cuhk03_labeled=cuhk03_labeled,
                    cuhk03_classic_split=cuhk03_classic_split,
                    market1501_500k=market1501_500k,
                    custom_dataset_names=custom_dataset_names,
                    custom_dataset_roots=custom_dataset_roots,
                    custom_dataset_types=custom_dataset_types,
                    filter_classes=filter_classes
                )
                self.test_loader[name]['query'] = torch.utils.data.DataLoader(
                    query_dataset,
                    batch_size=max(min(batch_size_test, len(query_dataset)), 1),
                    shuffle=False,
                    num_workers=workers,
                    worker_init_fn=worker_init_fn,
                    pin_memory=self.use_gpu,
                    drop_last=False
                )

                # build gallery loader (verbose=False avoids a duplicate summary)
                gallery_dataset = init_image_dataset(
                    name,
                    transform=self.transform_te,
                    mode='gallery',
                    combineall=combineall,
                    verbose=False,
                    root=root,
                    split_id=split_id,
                    cuhk03_labeled=cuhk03_labeled,
                    cuhk03_classic_split=cuhk03_classic_split,
                    market1501_500k=market1501_500k,
                    custom_dataset_names=custom_dataset_names,
                    custom_dataset_roots=custom_dataset_roots,
                    custom_dataset_types=custom_dataset_types,
                    filter_classes=filter_classes
                )
                self.test_loader[name]['gallery'] = torch.utils.data.DataLoader(
                    gallery_dataset,
                    batch_size=max(min(batch_size_test, len(gallery_dataset)), 1),
                    worker_init_fn=worker_init_fn,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=self.use_gpu,
                    drop_last=False
                )

                # Raw sample lists are kept for evaluation code.
                self.test_dataset[name]['query'] = query_dataset.query
                self.test_dataset[name]['gallery'] = gallery_dataset.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  source            : {}'.format(self.sources))
        print('  # source datasets : {}'.format(len(self.sources)))
        print('  # source ids      : {}'.format(sum(self.num_train_pids)))
        print('  # source images   : {}'.format(len(train_dataset)))
        print('  # source cameras  : {}'.format(sum(self.num_train_cams)))
        print('  target            : {}'.format(self.targets))
        print('  *****************************************')
        print('\n')