Exemplo n.º 1
0
    def __init__(self, args):
        """Build train/test transforms, datasets and DataLoaders from *args*."""

        # Training augmentation: resize -> random flip -> tensor -> ImageNet norm.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        # Optional random-erasing occlusion augmentation (train only).
        if args.random_erasing:
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # Deterministic test-time transform (no flip / erasing).
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            # Dataset class named args.data_train is looked up inside the
            # module data.<data_train lowercased>.
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                # Identity-balanced sampler: batchid ids x batchimage images.
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                #shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        # Gallery ('test') and probe ('query') splits of the test dataset.
        if args.data_test in [
                'Market1501', 'Market1501_folder', 'MSMT17_folder',
                'CUHK03_folder', 'SYSU30K_folder'
        ]:
            # NOTE(review): the module is resolved from args.data_train, not
            # args.data_test -- presumably both classes live in the same
            # module; verify against the data package.
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')

        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
 def __call__(self, x):
     """Apply the dataset-specific train-time augmentation pipeline to *x*."""
     # Dataset-specific geometric preprocessing: the four 224-crop datasets
     # share an identical branch, so they are handled by one membership test.
     if self.data == 'person':
         x = T.Resize((384, 128))(x)
     elif self.data in ('car', 'cub', 'clothes', 'product'):
         x = pad_shorter(x)
         x = T.Resize((256, 256))(x)
         x = T.RandomCrop((224, 224))(x)
     elif self.data == 'cifar':
         x = T.Resize((40, 40))(x)
         x = T.RandomCrop((32, 32))(x)
     # Shared tail: flip, tensorize, ImageNet normalization.
     x = T.RandomHorizontalFlip()(x)
     x = T.ToTensor()(x)
     x = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x)
     # Occlusion augmentation: Cutout for person, RandomErasing otherwise.
     if self.data == 'person':
         occlude = Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0])
     else:
         occlude = RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0])
     return occlude(x)
Exemplo n.º 3
0
 def __call__(self, x):
     """Train-time transform: resize, random flips, tensorize, normalize, erase."""
     pipeline = (
         T.Resize((cfg.TRAIN.IMG_HEIGHT, cfg.TRAIN.IMG_WIDTH),
                  interpolation=Image.BICUBIC),
         T.RandomHorizontalFlip(),
         T.RandomVerticalFlip(),
         T.ToTensor(),
         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         RandomErasing(probability=0.5, mean=[0., 0., 0.]),
     )
     # Apply each stage in order, exactly as the chained form would.
     for stage in pipeline:
         x = stage(x)
     return x
Exemplo n.º 4
0
 def __call__(self, x):
     """Re-ID train transform with optional pad/crop jitter and erase/cutout."""
     augmenting = self.data_aug is not None
     x = T.Resize((256, 128), interpolation=3)(x)
     x = T.RandomHorizontalFlip(0.5)(x)
     if augmenting:
         # Pad by 10 px, then crop back to the target size (translation jitter).
         x = T.RandomCrop((256, 128))(T.Pad(10)(x))
     x = T.ToTensor()(x)
     x = T.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])(x)
     if augmenting:
         if self.data_aug == 'RandomErase':
             x = RandomErasing(mean=[0.485, 0.456, 0.406])(x)
         else:
             x = Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0])(x)
     return x
Exemplo n.º 5
0
    def __init__(self, args):
        """Build transforms, datasets and DataLoaders via GeneralDataLoader."""
        # Training augmentation: resize -> random flip -> tensor -> ImageNet norm.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], \
                std=[0.229, 0.224, 0.225])
        ]
        # Optional random-erasing occlusion augmentation (train only).
        if args.random_erasing:
            train_list.append(RandomErasing(probability=args.probability, \
                mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # Deterministic test-time transform (no flip / erasing).
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], \
                std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            self.trainset = GeneralDataLoader(args, train_transform, \
                args.data_train.lower(), 'train')
            self.train_loader = dataloader.DataLoader(self.trainset,
                # Identity-balanced sampler: batchid ids x batchimage images.
                sampler=RandomSampler(self.trainset,args.batchid,batch_image=args.batchimage),
                #shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        # Gallery ('test') and probe ('query') splits of the test dataset.
        if args.data_test in ['market1501', 'duke', 'cuhk03', 'rap2']:
            self.testset = GeneralDataLoader(args, test_transform, \
                args.data_test, 'test')
            self.queryset= GeneralDataLoader(args, test_transform, \
                args.data_test, 'query')
        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset, \
            batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, \
            batch_size=args.batchtest, num_workers=args.nThread)
Exemplo n.º 6
0
######################################################################
# Load Data
# --------
#

# ImageNet channel statistics used for input normalization.
normalize_transform = T.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

# Train-time pipeline: resize, flip, pad+crop jitter, normalize, random erase.
train_transforms = T.Compose([
    T.Resize([384, 128]),
    T.RandomHorizontalFlip(p=0.5),
    T.Pad(10),
    T.RandomCrop([384, 128]),
    T.ToTensor(), normalize_transform,
    RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])
])

# val_transforms = T.Compose([
#     T.Resize([384, 128]),
#     T.ToTensor(),
#     normalize_transform
# ])

# 'mars' dataset wrapped as a video dataset: sequences of seq_len frames
# sampled per item according to opt.sample_method.
dataset = init_dataset('mars', root='../')
dataset_sizes = {}
dataset_sizes['train'] = dataset.num_train_imgs
train_set = VideoDataset(dataset.train, opt.seq_len, opt.sample_method,
                         train_transforms)
dataloaders = {}
dataloaders['train'] = DataLoader(train_set,
Exemplo n.º 7
0
# Train pipeline: resize, random flip, tensorize, ImageNet normalization.
transform_train_list = [
    transforms.Resize((384, 128), interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
# Deterministic validation pipeline (no augmentation).
transform_val_list = [
    transforms.Resize(size=(384, 128), interpolation=3),  # Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

# Optional random erasing goes last (operates on the tensor).
if opt.erasing_p > 0:
    transform_train_list.append(
        RandomErasing(probability=opt.erasing_p, mean=[0.0, 0.0, 0.0]))

# Optional color jitter goes first (operates on the input image).
if opt.color_jitter:
    transform_train_list.insert(
        0,
        transforms.ColorJitter(brightness=0.1, contrast=0.1,
                               saturation=0.1, hue=0))

print(transform_train_list)

data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

train_all = ''
Exemplo n.º 8
0
    gid = int(str_id)
    if gid >=0:
        gpu_ids.append(gid)

# Use the first requested GPU as the default CUDA device.
if len(gpu_ids)>0:
    torch.cuda.set_device(gpu_ids[0])

######################################################################
# prepossessing
# Train pipeline: resize, random flip, tensorize, ImageNet normalization.
transform_train_list = [transforms.Resize((384,128)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

# Optional extras: random erasing appended, color jitter prepended.
if opt.erasing_p>0:
    transform_train_list = transform_train_list + [RandomErasing(opt.erasing_p)]
if opt.color_jitter:
    transform_train_list = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train_list

data_transforms = {'train': transforms.Compose(transform_train_list)}

# '_all' suffix switches to the 'train_all' image folder.
train_all = ''
if opt.train_all:
     train_all = '_all'

# One batch = n_classes identities x n_images samples each; shuffle must be
# False because the sampler fixes the ordering.
image_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all), data_transforms['train'])
Sampler = BalancedBatchSampler(image_datasets, n_classes=opt.n_classes, n_samples=opt.n_images)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=opt.n_classes*opt.n_images, shuffle=False, sampler=Sampler)

dataset_sizes = len(image_datasets)
class_names = image_datasets.classes
Exemplo n.º 9
0
    def __init__(self, args):
        """Build train/test transforms, datasets and DataLoaders from *args*."""
        # 1. Training transforms
        train_list = []

        # Either random translation-crop or a plain resize to (height, width).
        if args.random_crop:
            train_list.append(Random2DTranslation(args.height, args.width,
                                                  0.5))
        else:
            train_list.append(
                transforms.Resize((args.height, args.width), interpolation=3))

        train_list.append(transforms.RandomHorizontalFlip())

        # Optional photometric jitter (brightness/contrast only).
        if args.color_jitter:
            train_list.append(
                transforms.ColorJitter(brightness=0.2,
                                       contrast=0.15,
                                       saturation=0,
                                       hue=0))

        train_list.append(transforms.ToTensor())
        train_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))

        # Optional random-erasing occlusion augmentation (after ToTensor).
        if args.random_erasing:
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # 2. Test transforms
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                # Identity-balanced sampler: batchid ids x batchimage images.
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        if args.data_test in ['Market1501', 'Boxes']:
            # NOTE(review): the module is resolved from args.data_train, not
            # args.data_test -- presumably both classes live in the same
            # module; verify against the data package.
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')
        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
Exemplo n.º 10
0
    transforms.Resize((288, 144), interpolation=3),
    transforms.RandomCrop((256, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

# Deterministic validation pipeline (no random crop / flip).
transform_val_list = [
    transforms.Resize(size=(256, 128), interpolation=3),  #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

# Optionally append random erasing to the training pipeline.
if opt.erasing_p > 0:
    transform_train_list = transform_train_list + [
        RandomErasing(probability=opt.erasing_p, mean=[0.0, 0.0, 0.0])
    ]

# Optionally prepend color jitter so it is applied first.
if opt.color_jitter:
    transform_train_list = [
        transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0)
    ] + transform_train_list

# print(transform_train_list)
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

train_all = ''
Exemplo n.º 11
0
]
# Deterministic resize-only pipeline (used as the 'none' transform below).
transform_train_list2 = [
    # transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
    transforms.Resize((384, 128), interpolation=3),  # resize
    # transforms.RandomGrayscale(p=0.2),
    # transforms.RandomCrop((256,128)),
    # transforms.RandomHorizontalFlip(),#randomly horizon flip image
    transforms.ToTensor(),  # convert PIL image or numpy.ndarray to tensor
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225])  # [m1,m2...mn][s1,s2...sn] for n channels
]

# Optionally append random erasing to the training pipeline.
if opt.erasing_p > 0:
    transform_train_list = transform_train_list + [
        RandomErasing(opt.erasing_p)
    ]  # randomly select a rectangle region in a image and erase its pixels with random values
# Optionally prepend color jitter so it is applied first.
if opt.color_jitter:
    transform_train_list = [
        transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0)
    ] + transform_train_list
# randomly change the brightness,contrast,saturation
print(transform_train_list)
# 'train' = augmented pipeline, 'none' = deterministic resize-only pipeline.
data_transforms = {
    'train': transforms.Compose(
        transform_train_list),  # compose several transforms together
    'none': transforms.Compose(transform_train_list2)
}

image_datasets = datasets.ImageFolder(
Exemplo n.º 12
0
    transforms.RandomHorizontalFlip(),  # randomly horizon flip image
    transforms.ToTensor(),  # convert PIL image or numpy.ndarray to tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # [m1,m2...mn][s1,s2...sn] for n channels
]
# Deterministic resize-only pipeline (used as the 'none' transform below).
transform_train_list2 = [
    # transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
    transforms.Resize((384, 128), interpolation=3),  # resize
    # transforms.RandomGrayscale(p=0.2),
    # transforms.RandomCrop((256,128)),
    # transforms.RandomHorizontalFlip(),#randomly horizon flip image
    transforms.ToTensor(),  # convert PIL image or numpy.ndarray to tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # [m1,m2...mn][s1,s2...sn] for n channels
]

# Optional extras: random erasing appended, color jitter prepended.
if opt.erasing_p > 0:
    transform_train_list = transform_train_list + [RandomErasing(
        opt.erasing_p)]  # randomly select a rectangle region in a image and erase its pixels with random values
if opt.color_jitter:
    transform_train_list = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1,
                                                   hue=0)] + transform_train_list
# randomly change the brightness,contrast,saturation
print(transform_train_list)
# 'train' = augmented pipeline, 'none' = deterministic resize-only pipeline.
data_transforms = {
    'train': transforms.Compose(transform_train_list),  # compose several transforms together
    'none': transforms.Compose(transform_train_list2)
}

# ImageFolder expects one sub-folder per class under data_dir.
image_datasets = datasets.ImageFolder(os.path.join(data_dir), data_transforms[
    'train'])  # return several image folders indicating class including images
# image_datasets=ImageDataset(filelist='./train_set/train_list.txt',source='./train_set',transform1=data_transforms['train'],transform2=data_transforms['none'])
print(len(image_datasets))
# batch sampling for dataset
Exemplo n.º 13
0
 def __init__(self, args):
     """Store *args* and pre-build the RandomErasing op it configures."""
     self.args = args
     eraser = RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0])
     self.random_erasing = eraser
Exemplo n.º 14
0
    def __init__(self, args):
        """Build transforms, datasets and DataLoaders for single-modality
        (Market1501) or cross-modality RGB/IR (SYSU) training and testing."""

        # Training augmentation: resize -> random flip -> tensor -> ImageNet norm.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        # Optional random-erasing occlusion augmentation (train only).
        if args.random_erasing:
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)
        self.train_transform = train_transform

        # Deterministic test-time transform (no flip / erasing).
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.test_transform = test_transform

        if args.data_train in ['Market1501']:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            # a = self.trainset[1]
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                # Identity-balanced sampler: batchid ids x batchimage images.
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        elif args.data_train in ['SYSU']:
            # Cross-modality training: RGB ('color') and IR ('thermal')
            # samples drawn together per identity by IdentitySampler.
            module_train = import_module('data.' + args.data_train.lower())
            data_path = args.datadir

            self.trainset = getattr(module_train,
                                    args.data_train)(data_path,
                                                     train_transform)
            # Per-identity sample positions for each modality.
            color_pos, thermal_pos = module_train.GenIdx(
                self.trainset.train_color_label,
                self.trainset.train_thermal_label)

            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=IdentitySampler(self.trainset.train_color_label,
                                        self.trainset.train_thermal_label,
                                        color_pos, thermal_pos,
                                        args.batchimage, args.batchid),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
            # Expose the sampler's per-modality index orderings to the dataset.
            self.trainset.cIndex = self.train_loader.sampler.index1  # RGB index
            self.trainset.tIndex = self.train_loader.sampler.index2  # IR index

            # a = self.trainset[10]
            # embed()

        else:
            self.train_loader = None

        # for get sysu attribute
        # args.data_test = 'SYSU'

        if args.data_test in ['Market1501']:
            module = import_module('data.' + args.data_test.lower())
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')

            print("  Dataset statistics: {}".format(args.data_test))
            print("  ------------------------------")
            print("  subset   | # ids | # images")
            print("  ------------------------------")
            print('  train    | {:5d} | {:8d}'.format(
                len(self.trainset.unique_ids), len(self.trainset)))
            print('  ------------------------------')
            print("  query    | {:5d} | {:8d}".format(
                len(self.queryset.unique_ids), len(self.queryset)))
            print("  gallery  | {:5d} | {:8d}".format(
                len(self.testset.unique_ids), len(self.testset)))
            print("  ------------------------------")

        elif args.data_test in ['SYSU']:
            module = import_module('data.' + args.data_test.lower())
            data_path = args.datadir

            n_class = len(np.unique(self.trainset.train_color_label))

            # rgb --> ir
            query_img, query_label, query_cam = module.process_query_sysu(
                data_path, mode=args.mode, img_mode="rgb")
            gall_img, gall_label, gall_cam = module.process_gallery_sysu(
                data_path, mode=args.mode, img_mode="ir")

            nquery_rgb2ir = len(query_label)
            ngall_rgb2ir = len(gall_label)

            self.queryset = module.TestData(query_img,
                                            query_label,
                                            transform=test_transform)
            self.testset = module.TestData(gall_img,
                                           gall_label,
                                           transform=test_transform)

            # ir --> rgb
            query_img, query_label, query_cam = module.process_query_sysu(
                data_path, mode=args.mode, img_mode="ir")
            gall_img, gall_label, gall_cam = module.process_gallery_sysu(
                data_path, mode=args.mode, img_mode="rgb")

            nquery_ir2rgb = len(query_label)
            ngall_ir2rgb = len(gall_label)

            # NOTE(review): these assignments overwrite the rgb->ir query and
            # gallery sets built just above, so only the ir->rgb sets reach
            # the loaders below -- confirm this is intended.
            self.queryset = module.TestData(query_img,
                                            query_label,
                                            transform=test_transform)
            self.testset = module.TestData(gall_img,
                                           gall_label,
                                           transform=test_transform)

            print("  Dataset statistics: {}".format(args.data_test))
            print("  ------------------------------")
            print("  subset       | # ids | # images")
            print("  ------------------------------")
            print('  rgb_train    | {:5d} | {:8d}'.format(
                n_class, len(self.trainset.train_color_label)))
            print('  ir_train     | {:5d} | {:8d}'.format(
                n_class, len(self.trainset.train_thermal_label)))
            print('  ------------------------------')
            print("  rgb_query    | {:5d} | {:8d}".format(
                len(np.unique(query_label)), nquery_rgb2ir))
            print("  ir_gallery   | {:5d} | {:8d}".format(
                len(np.unique(gall_label)), ngall_rgb2ir))
            print("  ------------------------------")
            print("  ir_query     | {:5d} | {:8d}".format(
                len(np.unique(query_label)), nquery_ir2rgb))
            print("  rgb_gallery  | {:5d} | {:8d}".format(
                len(np.unique(gall_label)), ngall_ir2rgb))
            print("  ------------------------------")

            # for get sysu attribute
            # data_path = "/home/zzz/pytorch/ECNU_TXD/shz/data/sysu"
            # args.mode = "all"
            # module = import_module('data.sysu')
            # self.testset = module.SYSU_INFERENCE(data_path, test_transform)
            # self.queryset = module.SYSU_INFERENCE(data_path, test_transform)

        else:
            raise Exception()

        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)

        self.args = args
Exemplo n.º 15
0
    def __init__(self, args):
        """Build train/test transforms, datasets and DataLoaders from *args*."""
        print('[INFO] Making Data...')
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            # resize the picture (interpolation=3 selects the interpolation method)
            transforms.RandomHorizontalFlip(),  # horizontal flip with probability p
            transforms.ToTensor(),  # convert to Tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])  # ImageNet normalization
        ]
        if args.random_erasing:  # optional random-erasing augmentation
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)  # compose the steps

        # Deterministic test-time transform (no flip / erasing).
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            # Loads data/market1501.py, cuhk03.py or dukemtmcreid.py.
            module_train = import_module(
                'data.' + args.data_train.lower())  # lower() aid : A to a
            # getattr(obj, name) fetches the named attribute -- here the
            # dataset class, called as (args, transform, dtype).
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,  # the dataset to draw from
                # Custom sampling strategy; when a sampler is given,
                # shuffle must be False.
                sampler=RandomSampler(self.trainset,
                                      args.batchid,
                                      batch_image=args.batchimage),
                shuffle=False,  # whether to reshuffle at the start of each epoch
                batch_size=args.batchid * args.batchimage,  # samples per batch
                # Number of worker processes for data loading; 0 means all
                # data is loaded in the main process.
                num_workers=args.nThread)
        else:
            self.train_loader = None

        # Gallery ('test') and probe ('query') splits of the test dataset.
        if args.data_test in ['Market1501']:
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')

        elif args.data_test in ['DukeMTMCreID']:
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module,
                                   args.data_test)(args, test_transform,
                                                   'test')
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')
        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)
Exemplo n.º 16
0
    def __init__(self, args):
        """Build transforms, datasets and DataLoaders; the train sampler is
        chosen per model family (MGN vs. the listed PCB/BB/... models)."""

        # train_list = [
        #     transforms.Resize((args.height, args.width), interpolation=3),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.ToTensor(),
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
        #                          0.229, 0.224, 0.225])
        # ]

        # Training augmentation: resize, pad+crop jitter, flip, tensor, norm.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.Pad(10),
            transforms.RandomCrop((args.height, args.width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
        # Optional occlusion augmentations (applied after ToTensor).
        if args.random_erasing:
            train_list.append(
                RandomErasing(probability=args.probability,
                              mean=[0.485, 0.456, 0.406]))
            print('Using random_erasing augmentation.')
        if args.cutout:
            train_list.append(Cutout(mean=[0.485, 0.456, 0.406]))
            print('Using cutout augmentation.')

        train_transform = transforms.Compose(train_list)

        # Deterministic test-time transform (no jitter / erasing).
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        if not args.test_only and args.model == 'MGN':
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomIdentitySampler(self.trainset,
                                              args.batchid * args.batchimage,
                                              args.batchimage),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        # elif not args.test_only and args.model in ['ResNet50','PCB'] and args.loss.split('*')[1]=='CrossEntropy':
        #     module_train = import_module('data.' + args.data_train.lower())
        #     self.trainset = getattr(module_train, args.data_train)(
        #         args, train_transform, 'train')
        #     self.train_loader = dataloader.DataLoader(self.trainset,
        #                                               shuffle=True,
        #                                               batch_size=args.batchid * args.batchimage,
        #                                               num_workers=args.nThread)
        elif not args.test_only and args.model in [
                'ResNet50', 'PCB', 'PCB_v', 'PCB_conv', 'BB_2_db', 'BB',
                'MGDB', 'MGDB_v2', 'MGDB_v3', 'BB_2_v3', 'BB_2',
                'PCB_conv_modi_2', 'BB_2_conv', 'BB_2_cat', 'BB_4_cat',
                'PCB_conv_modi', 'Pyramid', 'PLR'
        ] and bool(args.sampler):

            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train,
                                    args.data_train)(args, train_transform,
                                                     'train')
            # self.train_loader = dataloader.DataLoader(self.trainset,
            #                                           sampler=RandomSampler(
            #                                               self.trainset, args.batchid, batch_image=args.batchimage),
            #                                           # shuffle=True,
            #                                           batch_size=args.batchid * args.batchimage,
            #                                           num_workers=args.nThread,
            #                                           drop_last=True)
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomIdentitySampler(self.trainset,
                                              args.batchid * args.batchimage,
                                              args.batchimage),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)

        # NOTE(review): this "not in" list omits several names from the list
        # above (e.g. 'MGDB_v2', 'MGDB_v3', 'BB_2_v3', 'BB_2_conv',
        # 'Pyramid'), so those models with a falsy sampler fall through to
        # train_loader = None instead of raising -- confirm intended.
        elif not args.test_only and args.model not in [
                'MGN', 'ResNet50', 'PCB', 'BB_2_db', 'PCB_v', 'PCB_conv',
                'MGDB', 'PCB_conv_modi_2', 'PCB_conv_modi', 'BB', 'BB_2',
                'BB_2_cat', 'BB_4_cat', 'PLR'
        ]:
            raise Exception('DataLoader for {} not designed'.format(
                args.model))
        else:
            self.train_loader = None

        # Gallery ('test') and probe ('query') splits of the test dataset.
        if args.data_test in ['Market1501', 'DukeMTMC', 'GTA']:
            module = import_module('data.' + args.data_train.lower())
            self.galleryset = getattr(module,
                                      args.data_test)(args, test_transform,
                                                      'test')
            self.queryset = getattr(module,
                                    args.data_test)(args, test_transform,
                                                    'query')

        else:
            raise Exception()
        # print(len(self.trainset))

        self.test_loader = dataloader.DataLoader(self.galleryset,
                                                 batch_size=args.batchtest,
                                                 num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset,
                                                  batch_size=args.batchtest,
                                                  num_workers=args.nThread)