Example #1
0
    def __init__(self, args):
        """Build the train, gallery ("test") and query datasets and loaders.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``random_erasing``, ``probability``, ``test_only``,
                ``data_train``, ``data_test``, ``batchid``, ``batchimage``,
                ``batchtest`` and ``nThread``.

        Raises:
            ValueError: if ``args.data_test`` is not one of the supported
                test dataset names.
        """
        # interpolation=3 is PIL's integer code for bicubic resampling.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        if args.random_erasing:
            # Appended last, so erasing is applied to the normalized tensor.
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # Evaluation uses the same resize/normalize steps, without the
        # random augmentations.
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            # The dataset class is named args.data_train and lives in the
            # module data.<lower-cased name>.
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        supported_test_sets = ('Market1501', 'Market1501_folder',
                               'MSMT17_folder', 'CUHK03_folder',
                               'SYSU30K_folder')
        if args.data_test not in supported_test_sets:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: {}'.format(
                    args.data_test, ', '.join(supported_test_sets)))

        # Historically the test classes were looked up in the *training*
        # dataset's module, which fails with AttributeError whenever
        # data_train and data_test name datasets from different modules.
        # Keep that lookup for compatibility and fall back to the test
        # dataset's own module when the class is not found there.
        module = import_module('data.' + args.data_train.lower())
        if not hasattr(module, args.data_test):
            module = import_module('data.' + args.data_test.lower())
        test_cls = getattr(module, args.data_test)
        self.testset = test_cls(args, test_transform, 'test')
        self.queryset = test_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
Example #2
0
def make_dataloader(args, epoch=0):
    """
    Build the training DataLoader for the configured dataset.

    Args:
        args: options object. Reads img_height, img_width, random_erasing,
            num_epochs, num_classes, loss_type, num_support, num_query,
            num_instances, dataset, dataset_root and num_workers.
        epoch: current epoch number, used to scale the random erasing
            data augmentation parameters.

    Returns:
        A DataLoader over the 'train' split, with batch size
        num_classes * (samples per class) and incomplete batches dropped.

    Raises:
        NotImplementedError: if args.dataset is neither 'market1501'
            nor 'duke'.
    """
    steps = [
        transforms.Resize((args.img_height, args.img_width), interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    if args.random_erasing:
        # Fraction of training completed, capped at 0.8; both erasing
        # parameters grow linearly with it.
        progress = min((float(epoch) / args.num_epochs), 0.8)
        eraser = RandomErasing(probability=0.3 + 0.4 * progress,
                               s_epoch=0.1 + 0.3 * progress,
                               mean=[0.0, 0.0, 0.0])
        steps.append(eraser)
    pipeline = transforms.Compose(steps)

    # Batch layout: classes_per_batch classes, samples_per_class images each.
    classes_per_batch = args.num_classes
    if 'dmml' in args.loss_type:
        samples_per_class = args.num_support + args.num_query
    elif args.loss_type == 'npair':
        samples_per_class = 2
    else:
        samples_per_class = args.num_instances

    if args.dataset == 'market1501':
        dataset_cls = Market1501
    elif args.dataset == 'duke':
        dataset_cls = DukeMTMC_reID
    else:
        raise NotImplementedError
    samples = dataset_cls(args.dataset_root, pipeline, split='train')

    return dataloader.DataLoader(
        samples,
        sampler=RandomSampler(samples, samples_per_class),
        batch_size=classes_per_batch * samples_per_class,
        num_workers=args.num_workers,
        drop_last=True,
    )
Example #3
0
    def __init__(self, args):
        """Create the train, gallery ("test") and query loaders.

        Args:
            args: options object. Reads height, width, random_erasing,
                probability, test_only, data_train, data_test, batchid,
                batchimage, batchtest and nThread.

        Raises:
            Exception: if args.data_test is not one of 'market1501', 'duke',
                'cuhk03' or 'rap2'.
        """
        def base_steps():
            # Steps shared by the train and test pipelines; a fresh list
            # (and fresh transform objects) on every call.
            return [
                transforms.Resize((args.height, args.width), interpolation=3),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ]

        # Train pipeline: Resize, RandomHorizontalFlip, ToTensor, Normalize
        # and, optionally, RandomErasing at the end.
        augment = base_steps()
        augment.insert(1, transforms.RandomHorizontalFlip())
        if args.random_erasing:
            augment.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))
        train_transform = transforms.Compose(augment)
        test_transform = transforms.Compose(base_steps())

        if args.test_only:
            self.train_loader = None
        else:
            self.trainset = GeneralDataLoader(
                args, train_transform, args.data_train.lower(), 'train')
            id_sampler = RandomSampler(
                self.trainset, args.batchid, batch_image=args.batchimage)
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=id_sampler,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)

        # NOTE(review): data_train is lower-cased before use but data_test
        # is matched and passed on as-is - confirm this is intended.
        if args.data_test not in ['market1501', 'duke', 'cuhk03', 'rap2']:
            raise Exception()
        self.testset = GeneralDataLoader(
            args, test_transform, args.data_test, 'test')
        self.queryset = GeneralDataLoader(
            args, test_transform, args.data_test, 'query')

        self.test_loader = dataloader.DataLoader(
            self.testset,
            batch_size=args.batchtest,
            num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(
            self.queryset,
            batch_size=args.batchtest,
            num_workers=args.nThread)
Example #4
0
    def _get_loader(self, indices, **kwargs):
        """
        Build a DataLoader over the subset of the dataset given by indices.

        :param indices: indices of the subset of the data to be extracted
        :param kwargs: passed in full to ``self._get_dataset``; afterwards
            "shuffle" is removed and the rest is forwarded to the DataLoader
        :return: a DataLoader
        """
        subset = Subset(self._get_dataset(**kwargs), indices)

        # Only the literal True enables shuffling; it is implemented with an
        # explicit RandomSampler rather than the loader's own shuffle flag.
        if kwargs.pop("shuffle", False) is True:
            ordering = {"sampler": RandomSampler(subset)}
        else:
            ordering = {"shuffle": False}

        return DataLoader(subset,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory,
                          **ordering,
                          **kwargs)
Example #5
0
    def __init__(self, args):
        """Build the train, gallery ("test") and query datasets and loaders.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``random_erasing``, ``probability``, ``test_only``,
                ``data_train``, ``data_test``, ``batchid``, ``batchimage``,
                ``batchtest`` and ``nThread``.

        Raises:
            ValueError: if ``args.data_test`` is not ``'Market1501'``.
        """
        # Resize the image, flip horizontally at random, convert to a tensor
        # and normalize each channel.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        # Optional random-erasing augmentation, applied after normalization.
        if args.random_erasing:
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))

        # Chain the individual transforms into one callable.
        train_transform = transforms.Compose(train_list)
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # The training set is only built when not running in test-only mode.
        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            # Each batch holds batchid * batchimage samples chosen by the
            # custom sampler.
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        supported_test_sets = ('Market1501',)
        if args.data_test not in supported_test_sets:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: {}'.format(
                    args.data_test, ', '.join(supported_test_sets)))

        # Historically the test class was looked up in the *training*
        # dataset's module, which fails with AttributeError when data_train
        # names a dataset from a different module. Keep that lookup for
        # compatibility and fall back to the test dataset's own module.
        module = import_module('data.' + args.data_train.lower())
        if not hasattr(module, args.data_test):
            module = import_module('data.' + args.data_test.lower())
        test_cls = getattr(module, args.data_test)
        self.testset = test_cls(args, test_transform, 'test')
        self.queryset = test_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)
Example #6
0
    def _get_loader(self, indices, **kwargs):
        """
        Builds a DataLoader over a subset of the data; a non-empty "extra" list found in kwargs
        is zipped together with that subset into a ZipDataset.
        Only "extra" and "shuffle" are read from kwargs; other entries are not forwarded to the DataLoader.
        :param indices: indices of the subset of the data to be extracted
        :param kwargs: an arbitrary dictionary
        :return: a DataLoader
        :raises NotImplementedError: if "extra" is neither None nor a list
        """
        subset = Subset(self._get_dataset(), indices)

        extra = kwargs.pop("extra", None)
        if extra is not None:
            if not isinstance(extra, list):
                raise NotImplementedError(
                    "Check that extra is None, an empty list or a non-empty list")
            # An empty list is treated like None: nothing to zip.
            if len(extra) > 0:
                subset = ZipDataset(subset, extra)

        # Only the literal True enables shuffling, via an explicit sampler.
        if kwargs.pop("shuffle", False) is True:
            return DataLoader(subset,
                              sampler=RandomSampler(subset),
                              num_workers=self.num_workers,
                              pin_memory=self.pin_memory)

        return DataLoader(subset,
                          shuffle=False,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory)
Example #7
0
    def __init__(self, args):
        """Build the train, gallery ("test") and query datasets and loaders.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``random_crop``, ``color_jitter``,
                ``random_erasing``, ``probability``, ``test_only``,
                ``data_train``, ``data_test``, ``batchid``, ``batchimage``,
                ``batchtest`` and ``nThread``.

        Raises:
            ValueError: if ``args.data_test`` is neither ``'Market1501'``
                nor ``'Boxes'``.
        """
        # 1. Training transforms
        train_list = []

        # Either a random translation/crop to the target size or a plain
        # resize (interpolation=3 is PIL's code for bicubic).
        if args.random_crop:
            train_list.append(Random2DTranslation(args.height, args.width,
                                                  0.5))
        else:
            train_list.append(
                transforms.Resize((args.height, args.width), interpolation=3))

        train_list.append(transforms.RandomHorizontalFlip())

        if args.color_jitter:
            # Only brightness and contrast are jittered.
            train_list.append(
                transforms.ColorJitter(brightness=0.2,
                                       contrast=0.15,
                                       saturation=0,
                                       hue=0))

        train_list.append(transforms.ToTensor())
        train_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))

        if args.random_erasing:
            # Appended last, so erasing is applied to the normalized tensor.
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # 2. Test transforms: deterministic resize + normalize only.
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # 3. Training set and loader (skipped in test-only mode).
        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        # 4. Gallery and query sets.
        supported_test_sets = ('Market1501', 'Boxes')
        if args.data_test not in supported_test_sets:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: {}'.format(
                    args.data_test, ', '.join(supported_test_sets)))

        # Historically the test classes were looked up in the *training*
        # dataset's module, which fails with AttributeError whenever that
        # module does not define the requested test class. Keep that lookup
        # for compatibility and fall back to the test dataset's own module.
        module = import_module('data.' + args.data_train.lower())
        if not hasattr(module, args.data_test):
            module = import_module('data.' + args.data_test.lower())
        test_cls = getattr(module, args.data_test)
        self.testset = test_cls(args, test_transform, 'test')
        self.queryset = test_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
Example #8
0
    def __init__(self, args):
        """Build training and evaluation datasets/loaders for Market1501 or SYSU.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``random_erasing``, ``probability``,
                ``data_train``, ``data_test``, ``datadir``, ``mode``,
                ``batchid``, ``batchimage``, ``batchtest`` and ``nThread``.
                It is also stored on the instance as ``self.args``.

        Raises:
            ValueError: if ``args.data_test`` is neither ``'Market1501'``
                nor ``'SYSU'``.
        """
        # interpolation=3 is PIL's integer code for bicubic resampling.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        if args.random_erasing:
            # Appended last, so erasing is applied to the normalized tensor.
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)
        self.train_transform = train_transform

        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.test_transform = test_transform

        # ---- training set -------------------------------------------------
        if args.data_train in ['Market1501']:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        elif args.data_train in ['SYSU']:
            module_train = import_module('data.' + args.data_train.lower())

            self.trainset = getattr(module_train,
                                    args.data_train)(args.datadir,
                                                     train_transform)
            color_pos, thermal_pos = module_train.GenIdx(
                self.trainset.train_color_label,
                self.trainset.train_thermal_label)

            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=IdentitySampler(self.trainset.train_color_label,
                                        self.trainset.train_thermal_label,
                                        color_pos, thermal_pos,
                                        args.batchimage, args.batchid),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
            # Hand the sampler's index arrays to the dataset.
            self.trainset.cIndex = self.train_loader.sampler.index1  # RGB index
            self.trainset.tIndex = self.train_loader.sampler.index2  # IR index
        else:
            self.train_loader = None

        # The statistics tables below also describe the training set, but it
        # is absent when data_train is unsupported and has different
        # attributes when it belongs to the other dataset. Train rows are
        # therefore only printed when the needed attribute exists, instead
        # of crashing with AttributeError.
        trainset = getattr(self, 'trainset', None)

        # ---- evaluation sets ----------------------------------------------
        if args.data_test in ['Market1501']:
            module = import_module('data.' + args.data_test.lower())
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')

            print("  Dataset statistics: {}".format(args.data_test))
            print("  ------------------------------")
            print("  subset   | # ids | # images")
            print("  ------------------------------")
            if hasattr(trainset, 'unique_ids'):
                print('  train    | {:5d} | {:8d}'.format(
                    len(trainset.unique_ids), len(trainset)))
                print('  ------------------------------')
            print("  query    | {:5d} | {:8d}".format(
                len(self.queryset.unique_ids), len(self.queryset)))
            print("  gallery  | {:5d} | {:8d}".format(
                len(self.testset.unique_ids), len(self.testset)))
            print("  ------------------------------")

        elif args.data_test in ['SYSU']:
            module = import_module('data.' + args.data_test.lower())
            data_path = args.datadir

            # rgb --> ir: only needed for the statistics table. The counts
            # are captured right away so that the ir --> rgb pass below does
            # not overwrite them. No TestData is built for this direction
            # because it would be replaced immediately afterwards.
            _, query_label, _ = module.process_query_sysu(
                data_path, mode=args.mode, img_mode="rgb")
            _, gall_label, _ = module.process_gallery_sysu(
                data_path, mode=args.mode, img_mode="ir")
            nid_query_rgb2ir = len(np.unique(query_label))
            nquery_rgb2ir = len(query_label)
            nid_gall_rgb2ir = len(np.unique(gall_label))
            ngall_rgb2ir = len(gall_label)

            # ir --> rgb: the direction actually served by the loaders.
            query_img, query_label, _ = module.process_query_sysu(
                data_path, mode=args.mode, img_mode="ir")
            gall_img, gall_label, _ = module.process_gallery_sysu(
                data_path, mode=args.mode, img_mode="rgb")
            nid_query_ir2rgb = len(np.unique(query_label))
            nquery_ir2rgb = len(query_label)
            nid_gall_ir2rgb = len(np.unique(gall_label))
            ngall_ir2rgb = len(gall_label)

            self.queryset = module.TestData(query_img,
                                            query_label,
                                            transform=test_transform)
            self.testset = module.TestData(gall_img,
                                           gall_label,
                                           transform=test_transform)

            print("  Dataset statistics: {}".format(args.data_test))
            print("  ------------------------------")
            print("  subset       | # ids | # images")
            print("  ------------------------------")
            if hasattr(trainset, 'train_color_label'):
                n_class = len(np.unique(trainset.train_color_label))
                print('  rgb_train    | {:5d} | {:8d}'.format(
                    n_class, len(trainset.train_color_label)))
                print('  ir_train     | {:5d} | {:8d}'.format(
                    n_class, len(trainset.train_thermal_label)))
                print('  ------------------------------')
            print("  rgb_query    | {:5d} | {:8d}".format(
                nid_query_rgb2ir, nquery_rgb2ir))
            print("  ir_gallery   | {:5d} | {:8d}".format(
                nid_gall_rgb2ir, ngall_rgb2ir))
            print("  ------------------------------")
            print("  ir_query     | {:5d} | {:8d}".format(
                nid_query_ir2rgb, nquery_ir2rgb))
            print("  rgb_gallery  | {:5d} | {:8d}".format(
                nid_gall_ir2rgb, ngall_ir2rgb))
            print("  ------------------------------")

        else:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: '
                'Market1501, SYSU'.format(args.data_test))

        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)

        self.args = args
Example #9
0
    def __init__(self, args):
        """Build the train, gallery ("test") and query datasets and loaders.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``random_erasing``, ``probability``, ``test_only``,
                ``data_train``, ``data_test``, ``batchid``, ``batchimage``,
                ``batchtest`` and ``nThread``.

        Raises:
            ValueError: if ``args.data_test`` is neither ``'Market1501'``
                nor ``'DukeMTMCreID'``.
        """
        print('[INFO] Making Data...')
        train_list = [
            # Resize the picture; interpolation=3 selects the interpolation
            # method (3 is PIL's code for bicubic).
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.ToTensor(),  # convert to a tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])  # normalize
        ]
        if args.random_erasing:  # optional random erasing
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)  # chain the steps

        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            # Load the dataset module, e.g. data/market1501.py, cuhk03.py or
            # dukemtmcreid.py; lower() maps the class name to the file name.
            module_train = import_module('data.' + args.data_train.lower())
            # getattr fetches the dataset class from the module; it is then
            # called as cls(args, transform, dtype).
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                # Custom sampling strategy; when a sampler is given, shuffle
                # must be False.
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                shuffle=False,
                batch_size=args.batchid * args.batchimage,
                # Number of worker processes used for loading; 0 loads
                # everything in the main process.
                num_workers=args.nThread)
        else:
            self.train_loader = None

        # Both supported test datasets are built the same way, so the two
        # formerly duplicated branches are handled by one code path.
        supported_test_sets = ('Market1501', 'DukeMTMCreID')
        if args.data_test not in supported_test_sets:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: {}'.format(
                    args.data_test, ', '.join(supported_test_sets)))

        # Historically the test classes were looked up in the *training*
        # dataset's module, which fails with AttributeError whenever
        # data_train and data_test name datasets from different modules.
        # Keep that lookup for compatibility and fall back to the test
        # dataset's own module when the class is not found there.
        module = import_module('data.' + args.data_train.lower())
        if not hasattr(module, args.data_test):
            module = import_module('data.' + args.data_test.lower())
        test_cls = getattr(module, args.data_test)
        self.testset = test_cls(args, test_transform, 'test')
        self.queryset = test_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
Example #10
0
    def __init__(self, args):
        """Build the train, gallery ("test") and query datasets and loaders.

        Args:
            args: parsed options object. This method reads ``height``,
                ``width``, ``test_only``, ``data_train``, ``data_test``,
                ``batchid``, ``batchimage``, ``batchtest`` and ``nThread``.

        Raises:
            ValueError: if ``args.data_test`` is not ``'Market1501'``.
        """
        # Random erasing is disabled in this variant: args.random_erasing is
        # not consulted. A RandomResizedCrop (scale 0.97-1.0) was tried in
        # place of the plain resize and is likewise disabled.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(5),  # random rotation within +-5
            transforms.RandomAffine(5),
            transforms.ColorJitter(brightness=0.5,
                                   contrast=0.5,
                                   saturation=0.5),  # random colour changes
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]

        train_transform = transforms.Compose(train_list)

        # Evaluation uses only the deterministic resize + normalize steps.
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            # Look up the dataset class by name and instantiate it with
            # (args, transform, split).
            self.trainset = getattr(module_train, args.data_train)(
                args, train_transform, 'train')
            # NOTE: an alternative prototypical batch sampler (passed as
            # batch_sampler=) was experimented with here; the MGN-style
            # sampler below is the one in use.
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(
                    self.trainset, args.batchid,
                    batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        supported_test_sets = ('Market1501',)
        if args.data_test not in supported_test_sets:
            raise ValueError(
                'unsupported data_test {!r}; expected one of: {}'.format(
                    args.data_test, ', '.join(supported_test_sets)))

        # Historically the test class was looked up in the *training*
        # dataset's module, which fails with AttributeError when data_train
        # names a dataset from a different module. Keep that lookup for
        # compatibility and fall back to the test dataset's own module.
        module = import_module('data.' + args.data_train.lower())
        if not hasattr(module, args.data_test):
            module = import_module('data.' + args.data_test.lower())
        test_cls = getattr(module, args.data_test)
        self.testset = test_cls(args, test_transform, 'test')
        self.queryset = test_cls(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)