Esempio n. 1
0
def test_transform(args, image):
    """Build the deterministic upscaled-DCT evaluation pipeline.

    Steps:
    - Resize to 512 and center-crop to 448.
    - Upscale by a factor of 2.
    - Convert to the DCT domain and tensorize.
    - Aggregate and normalize with the static training statistics.

    When ``int(args.subset)`` is neither 0 nor 192, a ``SubsetDCT`` step is
    inserted before aggregation and normalization is restricted to the same
    channel subset.

    Args:
        args: namespace exposing a ``subset`` attribute accepted by ``int()``.
        image: not used by this function; kept for call-site compatibility.

    Returns:
        The ``transforms.Compose`` pipeline (the image is not transformed here).
    """
    resize_to, crop_to = 512, 448
    full_spectrum = int(args.subset) in (0, 192)

    # Stages shared by both configurations.
    head = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
    ]

    if full_spectrum:
        tail = [
            transforms.Aggregate(),
            transforms.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
            ),
        ]
    else:
        tail = [
            transforms.SubsetDCT(channels=args.subset),
            transforms.Aggregate(),
            transforms.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
                channels=args.subset,
            ),
        ]

    return transforms.Compose(head + tail)
Esempio n. 2
0
def trainloader_upscaled_static(args, model='mobilenet'):
    """Create the shuffled training DataLoader for the upscaled-DCT pipeline.

    Images are read from ``<args.data>/train``. Processing order:
    1. The ``enhance`` augmentations: random crop, horizontal/vertical flip,
       rotation, then cv2 conversion.
    2. Resize and center-crop; the sizes depend on ``model``.
    3. Upscale by a factor of 2 and convert to the DCT domain.
    4. Aggregate and normalize. If ``int(args.subset)`` is neither 0 nor 192,
       the coefficients are first restricted with ``SubsetDCT``.

    Args:
        args: namespace with ``data``, ``subset``, ``train_batch`` and
            ``workers`` attributes.
        model: ``'mobilenet'`` (resize 1024 / crop 896) or ``'resnet'``
            (resize 512 / crop 448).

    Returns:
        ``(loader, number_of_samples, dataset.get_clsnum())``.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    train_root = os.path.join(args.data, 'train')

    if model == 'mobilenet':
        resize_to, crop_to = 1024, 896
    elif model == 'resnet':
        resize_to, crop_to = 512, 448
    else:
        raise NotImplementedError

    full_spectrum = int(args.subset) in (0, 192)

    steps = [
        enhance.random_crop(),
        enhance.horizontal_flip(),
        enhance.vertical_flip(),
        enhance.random_rotation(),
        enhance.tocv2(),
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
    ]
    if full_spectrum:
        steps += [
            transforms.Aggregate(),
            transforms.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
            ),
        ]
    else:
        steps += [
            transforms.SubsetDCT(channels=args.subset),
            transforms.Aggregate(),
            transforms.NormalizeDCT(
                train_upscaled_static_mean,
                train_upscaled_static_std,
                channels=args.subset,
            ),
        ]

    dataset = ImageFolderDCT(train_root, transforms.Compose(steps), backend='pil')
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    return loader, len(dataset), dataset.get_clsnum()
Esempio n. 3
0
def valloader_upscaled_static(args, model='mobilenet'):
    """Create the (unshuffled) validation DataLoader for upscaled-DCT inputs.

    Images from ``<args.data>/val`` are processed as follows:
    1. Resize and center-crop.
    2. Upscale by a factor of 2 and convert to the DCT domain.
    3. Subset the DCT channels using ``args.subset`` / ``args.pattern``.
    4. Aggregate and normalize with the static training statistics.

    Args:
        args: namespace with ``data``, ``subset``, ``pattern``,
            ``test_batch`` and ``workers`` attributes.
        model: ``'mobilenet'`` (resize 1024 / crop 896) or ``'resnet'``
            (resize 512 / crop 448).

    Returns:
        The ``torch.utils.data.DataLoader`` over the validation set.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    val_root = os.path.join(args.data, 'val')

    if model == 'mobilenet':
        resize_to, crop_to = 1024, 896
    elif model == 'resnet':
        resize_to, crop_to = 512, 448
    else:
        raise NotImplementedError

    pipeline = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset, pattern=args.pattern),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=args.subset,
            pattern=args.pattern,
        ),
    ])

    dataset = ImageFolderDCT(val_root, pipeline)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
Esempio n. 4
0
def test(model):
    """Predict a class for every image in ``./data/test/`` and write ``./csv.csv``.

    File names are sorted by ``int(name[:-4])``, which assumes names such as
    ``12.jpg`` (an integer stem plus a 4-character extension). For the
    ``i``-th image in that order, ``[i, predicted_class]`` is written to the
    CSV and also printed to stdout.

    NOTE(review): the transform reads ``args.subset`` from a module-level
    ``args`` that is not a parameter of this function; confirm it is defined
    before calling.

    Args:
        model: callable network; it is put in eval mode and applied to one
            DCT-transformed image (with a leading batch dimension) at a time.

    Raises:
        OSError: if ``cv2.imread`` cannot decode one of the test images.
    """
    model.eval()

    test_root = './data/test/'
    img_names = sorted(os.listdir(test_root), key=lambda name: int(name[:-4]))

    input_size1 = 512
    input_size2 = 448

    transform = transforms.Compose([
        transforms.Resize(input_size1),
        transforms.CenterCrop(input_size2),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset),
        transforms.Aggregate(),
        transforms.NormalizeDCT(train_upscaled_static_mean,
                                train_upscaled_static_std,
                                channels=args.subset)
    ])

    # The with-statement guarantees the CSV is flushed and closed (the original
    # leaked the handle). newline='' is required by the csv module to avoid
    # doubled line terminators on Windows.
    with open('./csv.csv', 'w', newline='') as csvfile, torch.no_grad():
        writer = csv.writer(csvfile)
        for i, name in enumerate(img_names):
            path = os.path.join(test_root, name)
            image = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise OSError('could not read image: {}'.format(path))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # transform(...) is indexed with [0]; its first element is the
            # tensor fed to the model, and unsqueeze adds the batch axis.
            output = model(transform(image)[0].unsqueeze(dim=0))
            _, pred = torch.max(output, 1)
            label = pred.tolist()[0]

            print(i, label)
            writer.writerow([i, label])
Esempio n. 5
0
 def cvt_transform(self, img):
     """Apply the random-resized-crop upscaled-DCT pipeline to ``img``.

     Steps:
     - Random resized crop to ``self.img_size``; no horizontal flip is applied.
     - Upscale by a factor of 2 and convert to the DCT domain.
     - Keep 192 DCT channels.
     - Aggregate and normalize with the static training statistics.

     The pipeline is rebuilt on every call, so ``self.img_size`` is read at
     call time.

     Returns:
         Whatever the composed ``cvtransforms`` pipeline returns for ``img``.
     """
     channel_count = 192
     pipeline = cvtransforms.Compose([
         cvtransforms.RandomResizedCrop(self.img_size),
         cvtransforms.Upscale(upscale_factor=2),
         cvtransforms.TransformUpscaledDCT(),
         cvtransforms.ToTensorDCT(),
         cvtransforms.SubsetDCT(channels=channel_count),
         cvtransforms.Aggregate(),
         cvtransforms.NormalizeDCT(
             train_upscaled_static_mean,
             train_upscaled_static_std,
             channels=channel_count,
         ),
     ])
     return pipeline(img)
Esempio n. 6
0
def trainloader_upscaled_static(args, model='mobilenet'):
    """Create the training DataLoader (optionally distributed) for upscaled DCT.

    Images from ``<args.data>/train`` are processed as follows:
    1. Random resized crop and random horizontal flip.
    2. Upscale by a factor of 2 and convert to the DCT domain.
    3. Subset the DCT channels using ``args.subset`` / ``args.pattern``.
    4. Aggregate and normalize with the static training statistics.

    When ``args.distributed`` is truthy, a ``DistributedSampler`` is used and
    loader-level shuffling is disabled. Otherwise the loader shuffles itself.

    Args:
        args: namespace with ``data``, ``subset``, ``pattern``,
            ``distributed``, ``train_batch`` and ``workers`` attributes.
        model: ``'mobilenet'`` (crop 896) or ``'resnet'`` (crop 448).

    Returns:
        ``(loader, sampler_or_None, len(loader))``.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    train_root = os.path.join(args.data, 'train')

    if model not in ('mobilenet', 'resnet'):
        raise NotImplementedError
    crop_size = 896 if model == 'mobilenet' else 448

    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels=args.subset, pattern=args.pattern),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels=args.subset,
            pattern=args.pattern,
        ),
    ])

    dataset = ImageFolderDCT(train_root, pipeline)

    sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset)
        if args.distributed
        else None
    )

    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch,
        shuffle=sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
    )

    return loader, sampler, len(loader)
Esempio n. 7
0
    #         train_cr_mean_resized, train_cr_std_resized),
    # ])

    # transform3 =transforms.Compose([
    #     transforms.RandomResizedCrop(224),
    #     transforms.RandomHorizontalFlip(),
    #     transforms.ResizedTransformDCT(),
    #     transforms.ToTensorDCT(),
    #     transforms.SubsetDCT(32),
    # ])

    transform4 = transforms.Compose([
        transforms.RandomResizedCrop(896),
        transforms.RandomHorizontalFlip(),
        transforms.Upscale(upscale_factor=2),
        transforms.TransformUpscaledDCT(),
        transforms.ToTensorDCT(),
        transforms.SubsetDCT(channels='24'),
        transforms.Aggregate(),
        transforms.NormalizeDCT(
            train_upscaled_static_mean,
            train_upscaled_static_std,
            channels='24'
        )
        ])

    transform5 = transforms.Compose([
        transforms.DCTFlatten2D(),
        transforms.UpsampleDCT(size_threshold=112 * 8, T=112 * 8, debug=False),
        transforms.SubsetDCT2(channels='32'),
        transforms.Aggregate2(),