Example 1
def load_cifar(data_dir, cifar100=True):
    # Data loading code
    normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409],
                                     std=[0.2673, 0.2564, 0.2761])

    print("Loading training data")
    st = time.time()
    cifar_version = CIFAR100 if cifar100 else CIFAR10
    train_set = cifar_version(
        data_dir, True,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.1,
                                   hue=0.02),
            transforms.ToTensor(), normalize,
            transforms.RandomErasing(p=0.9, value='random')
        ]))
    print("Took", time.time() - st)

    print("Loading validation data")
    val_set = cifar_version(data_dir, False,
                            transforms.Compose([transforms.ToTensor(), normalize]))

    return train_set, val_set
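
A minimal usage sketch for this loader. The DataLoader settings below are illustrative assumptions rather than part of the snippet; it only assumes torch plus the imports used above, and that the CIFAR data already exists under the given directory (the function does not pass download=True).

import torch

# CIFAR-100 must already be extracted under ./data for this to work.
train_set, val_set = load_cifar("./data", cifar100=True)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                           shuffle=True, num_workers=4,
                                           pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128,
                                         shuffle=False, num_workers=4,
                                         pin_memory=True)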
Example 2
def load_data(traindir, valdir, img_size=224, crop_pct=0.875):
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale_size = min(int(math.floor(img_size / crop_pct)), 320)

    print("Loading training data")
    st = time.time()
    train_set = ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.3, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.1,
                                   hue=0.02),
            transforms.ToTensor(), normalize,
            transforms.RandomErasing(p=0.9, value='random')
        ]))
    print("Took", time.time() - st)

    print("Loading validation data")
    eval_tf = []
    if scale_size < 320:
        eval_tf.append(transforms.Resize(scale_size))
    eval_tf.extend(
        [transforms.CenterCrop(img_size),
         transforms.ToTensor(), normalize])
    val_set = ImageFolder(valdir, transforms.Compose(eval_tf))

    return train_set, val_set
Example 3
    def __init__(self,
                 crop_size,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225),
                 hflip_prob=0.5,
                 auto_augment_policy=None,
                 random_erase_prob=0.0):
        trans = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment())
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide())
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)
Example 4
    def __init__(
        self,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment(interpolation=interpolation))
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
        trans.extend(
            [
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)
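
Examples 3 and 4 only show the constructor; the composed pipeline ends up in self.transforms. A minimal sketch of how such a preset is typically completed and applied follows; the class name, the __call__ method, and the ImageFolder wiring are assumptions for illustration, not taken from the snippets above.

class TrainPreset:
    # The __init__ from Example 4 goes here and stores the pipeline
    # in self.transforms.

    def __call__(self, img):
        # Forward to the composed torchvision pipeline so the preset can be
        # passed anywhere a regular transform is expected.
        return self.transforms(img)

# e.g. ImageFolder(traindir, transform=TrainPreset(crop_size=176,
#                                                  auto_augment_policy="ta_wide",
#                                                  random_erase_prob=0.1))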
Example 5
    def __init__(self,
                 size=(256, 128),
                 random_horizontal_flip=0,
                 pad=0,
                 normalize=True,
                 random_erase=0):
        """
        :param size:
        :param random_horizontal_flip: strong baseline = 0.5
        :param pad: strong baseline = 10
        :param normalize:
        :param random_erase: strong baseline = 0.5
        """
        transforms_list = list()
        transforms_list.append(transforms.Resize(size))

        if random_horizontal_flip:
            transforms_list.append(
                transforms.RandomHorizontalFlip(random_horizontal_flip))

        if pad:
            transforms_list.append(transforms.Pad(pad))
            transforms_list.append(transforms.RandomCrop(size))

        transforms_list.append(transforms.ToTensor())

        if normalize:
            transforms_list.append(self.normalization)

        if random_erase:
            transforms_list.append(
                transforms.RandomErasing(random_erase, self.random_erase_scale,
                                         self.random_erase_ratio,
                                         self.random_erase_value))

        super().__init__(transforms_list)
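
This constructor relies on attributes defined elsewhere on the class (self.normalization, self.random_erase_scale, self.random_erase_ratio, self.random_erase_value) and passes the list to super().__init__, which suggests the class subclasses transforms.Compose. A hedged sketch of what those attributes could look like; the class name and concrete values are assumptions (the scale/ratio pairs are torchvision's RandomErasing defaults, the normalization is the ImageNet statistics used in the other examples).

class ReIDTransforms(transforms.Compose):  # hypothetical class name
    normalization = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
    random_erase_scale = (0.02, 0.33)  # default range of the erased area
    random_erase_ratio = (0.3, 3.3)    # default aspect-ratio range
    random_erase_value = 0             # erase with zeros

    # The __init__ shown in Example 5 goes here.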
Example 6
    def RandomErasing(self, **args):
        return self._add(transforms.RandomErasing(**args))
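
The two-line method above belongs to a builder-style helper that registers transforms through self._add. A minimal sketch of such a builder, assuming torchvision.transforms is imported as transforms and that _add appends the transform and returns self so calls can be chained; the class name and the build method are illustrative, not from the source.

class TransformBuilder:
    def __init__(self):
        self._transforms = []

    def _add(self, t):
        # Append the transform and return self to allow chaining.
        self._transforms.append(t)
        return self

    def RandomErasing(self, **args):
        return self._add(transforms.RandomErasing(**args))

    def build(self):
        # Compose everything registered so far into a single callable.
        return transforms.Compose(self._transforms)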
Example 7
def get_loader_with_idx(dataset,
                        batch_size,
                        image_size,
                        rand_crop,
                        mean=None,
                        std=None,
                        num_workers=6,
                        augment=False,
                        shuffle=True,
                        offset_idx=0,
                        offset_label=0,
                        sampler=None,
                        eval=False,
                        autoaugment=False,
                        drop_last=False,
                        cutout=False,
                        random_erase=False):
    '''
    Note: do not use normalize in NagTrainer.
    '''
    if std is None:
        std = [0.267, 0.256, 0.276]
    if mean is None:
        mean = [0.507, 0.487, 0.441]
    if sampler is not None:
        shuffle = False
    normalize = transforms.Normalize(mean=mean, std=std)  # CIFAR100
    transform_list = []

    is_cifar = image_size == 32
    is_stl = image_size == 96

    if isinstance(augment, bool):
        if augment:
            print(f"augment:{augment}")
            if is_cifar:
                transform_list.append(
                    transforms.ToPILImage())  # Comment this line for full data
                transform_list.append(
                    transforms.RandomCrop(rand_crop, padding=4))
                transform_list.append(transforms.RandomHorizontalFlip(p=0.5))

            elif is_stl:
                transform_list.append(
                    transforms.RandomCrop(image_size, padding=12))
                transform_list.append(transforms.RandomHorizontalFlip(p=0.5))
            else:  # CUB, Places365,imagenet
                transform_list.append(
                    transforms.RandomResizedCrop(rand_crop, scale=(0.875, 1.)))
                transform_list.append(transforms.RandomHorizontalFlip(p=0.5))
            if autoaugment:
                transform_list.append(AutoAugment())
            if cutout:
                transform_list.append(Cutout_())
            transform_list.append(transforms.ToTensor())

        if augment:
            transform_list.append(normalize)
            if random_erase:
                transform_list.append(transforms.RandomErasing())
        elif eval:
            print(f"eval:{eval}")
            shuffle = False
            if not is_cifar:
                transform_list.append(transforms.Resize(image_size))
                transform_list.append(transforms.CenterCrop(rand_crop))
            transform_list.append(transforms.ToTensor())
            transform_list.append(normalize)
        else:  # neither augment nor eval
            print("neither augment nor eval")
            if not is_cifar:
                transform_list.append(
                    transforms.Resize((image_size, image_size)))
            # transform_list.append(transforms.CenterCrop(rand_crop))
            transform_list.append(transforms.ToTensor())

        transform = transforms.Compose(transform_list)
    else:
        transform = None

    loader = torch.utils.data.DataLoader(IndexToImageDataset(
        dataset,
        transform=transform,
        offset_idx=offset_idx,
        offset_label=offset_label),
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         pin_memory=False,
                                         sampler=sampler,
                                         drop_last=drop_last)

    print(
        f"=>Generated data loader, res={image_size}, workers={num_workers} transform={transform} sampler={sampler}"
    )
    return loader
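
A hedged usage sketch of this loader factory. IndexToImageDataset and the other helpers are defined elsewhere in the source repository; the CIFAR-100 dataset and the parameter values below are illustrative assumptions.

from torchvision.datasets import CIFAR100

# Evaluation-mode loader: with augment=False and eval=True the pipeline is
# just ToTensor() + normalize, and shuffling is disabled.
val_set = CIFAR100("./data", train=False, download=True)
val_loader = get_loader_with_idx(val_set,
                                 batch_size=128,
                                 image_size=32,
                                 rand_crop=32,
                                 augment=False,
                                 eval=True)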
Example 8
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = None, None

    is_imagenette = args.dataset.lower() == "imagenette"
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406] if is_imagenette else [0.5071, 0.4866, 0.4409],
        std=[0.229, 0.224, 0.225] if is_imagenette else [0.2673, 0.2564, 0.2761])

    if not args.test_only:
        st = time.time()
        if args.dataset.lower() == "imagenette":

            train_set = ImageFolder(
                os.path.join(args.data_path, 'train'),
                T.Compose([
                    T.RandomResizedCrop(args.img_size, scale=(0.3, 1.0)),
                    T.RandomHorizontalFlip(),
                    T.ColorJitter(brightness=0.3,
                                  contrast=0.3,
                                  saturation=0.1,
                                  hue=0.02),
                    T.ToTensor(), normalize,
                    T.RandomErasing(p=0.9, value='random')
                ]))
        else:
            cifar_version = (CIFAR100 if args.dataset.lower() == "cifar100"
                             else CIFAR10)
            train_set = cifar_version(
                args.data_path, True,
                T.Compose([
                    T.RandomHorizontalFlip(),
                    T.ColorJitter(brightness=0.3,
                                  contrast=0.3,
                                  saturation=0.1,
                                  hue=0.02),
                    T.ToTensor(), normalize,
                    T.RandomErasing(p=0.9, value='random')
                ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        if args.dataset.lower() == "imagenette":
            eval_tf = []
            crop_pct = 0.875
            scale_size = min(int(math.floor(args.img_size / crop_pct)), 320)
            if scale_size < 320:
                eval_tf.append(T.Resize(scale_size))
            eval_tf.extend(
                [T.CenterCrop(args.img_size),
                 T.ToTensor(), normalize])
            val_set = ImageFolder(os.path.join(args.data_path, 'val'),
                                  T.Compose(eval_tf))
        else:
            cifar_version = (CIFAR100 if args.dataset.lower() == "cifar100"
                             else CIFAR10)
            val_set = cifar_version(args.data_path, False,
                                    T.Compose([T.ToTensor(), normalize]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = holocron.models.__dict__[args.model](args.pretrained,
                                                 num_classes=len(
                                                     train_set.classes))

    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss()
    elif args.loss == 'label_smoothing':
        criterion = holocron.nn.LabelSmoothingCrossEntropy()

    # Create the contiguous parameters.
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'ranger':
        optimizer = Lookahead(
            holocron.optim.RAdam(model_params,
                                 args.lr,
                                 betas=(0.95, 0.99),
                                 eps=1e-6,
                                 weight_decay=args.weight_decay))
    elif args.opt == 'tadam':
        optimizer = holocron.optim.TAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)

    trainer = ClassificationTrainer(model, train_loader, val_loader, criterion,
                                    optimizer, args.device, args.output_file)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Example 9
import torch
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms
from torchvision.transforms import ToPILImage

transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(value=(0.4914, 0.4822, 0.4465)),
    # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if __name__ == "__main__":
    img = Image.open(
        "/home/xiaoyang/Dev/AI6103-assignment/data/cifar-10-batches-py/train/1_10000.jpg"
    )
    img = transform_train(img)
    show = ToPILImage()
    img = show(img)
    plt.imshow(img)
    plt.show()
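
    # Hedged variant (illustrative addition, not part of the original script):
    # apply the same pipeline to a synthetic 32x32 RGB image so the example does
    # not depend on the machine-specific file path above. RandomErasing keeps its
    # default p=0.5, so the erased patch appears only on some runs.
    arr = (torch.rand(32, 32, 3) * 255).to(torch.uint8).numpy()
    synthetic = Image.fromarray(arr, mode="RGB")
    erased = transform_train(synthetic)
    plt.imshow(ToPILImage()(erased))
    plt.show()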
Example 10
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = None, None

    is_imagenette = args.dataset.lower() == "imagenette"
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406] if is_imagenette else [0.5071, 0.4866, 0.4409],
        std=[0.229, 0.224, 0.225] if is_imagenette else [0.2673, 0.2564, 0.2761])

    interpolation = InterpolationMode.BILINEAR

    num_classes = None
    if not args.test_only:
        st = time.time()
        if args.dataset.lower() == "imagenette":

            train_set = ImageFolder(
                os.path.join(args.data_path, 'train'),
                T.Compose([
                    T.RandomResizedCrop(args.img_size,
                                        scale=(0.3, 1.0),
                                        interpolation=interpolation),
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, scale=(0.02, 0.2), value='random'),
                ]))
        else:
            cifar_version = (CIFAR100 if args.dataset.lower() == "cifar100"
                             else CIFAR10)
            train_set = cifar_version(
                args.data_path,
                True,
                T.Compose([
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32), normalize,
                    T.RandomErasing(p=0.9, value='random')
                ]),
                download=True,
            )

        num_classes = len(train_set.classes)
        collate_fn = default_collate
        if args.mixup_alpha > 0:
            mix = Mixup(len(train_set.classes), alpha=args.mixup_alpha)
            collate_fn = lambda batch: mix(*default_collate(batch))  # noqa: E731
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
            collate_fn=collate_fn,
        )

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        if args.dataset.lower() == "imagenette":
            eval_tf = []
            crop_pct = 0.875
            scale_size = min(int(math.floor(args.img_size / crop_pct)), 320)
            if scale_size < 320:
                eval_tf.append(T.Resize(scale_size))
            eval_tf.extend([
                T.CenterCrop(args.img_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32), normalize
            ])
            val_set = ImageFolder(os.path.join(args.data_path, 'val'),
                                  T.Compose(eval_tf))
        else:
            cifar_version = (CIFAR100 if args.dataset.lower() == "cifar100"
                             else CIFAR10)
            val_set = cifar_version(args.data_path,
                                    False,
                                    T.Compose([
                                        T.PILToTensor(),
                                        T.ConvertImageDtype(torch.float32),
                                        normalize
                                    ]),
                                    download=True)
        num_classes = len(val_set.classes)

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = holocron.models.__dict__[args.arch](args.pretrained,
                                                num_classes=num_classes)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Create the contiguous parameters.
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = torch.optim.RAdam(model_params,
                                      args.lr,
                                      betas=(0.95, 0.99),
                                      eps=1e-6,
                                      weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)

    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
    trainer = ClassificationTrainer(model,
                                    train_loader,
                                    val_loader,
                                    criterion,
                                    optimizer,
                                    args.device,
                                    args.output_file,
                                    amp=args.amp,
                                    on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        num_it=min(len(train_loader), 100),
                        norm_weight_decay=args.norm_weight_decay)
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-image-classification",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "input_size": args.img_size,
                             "optimizer": args.opt,
                             "dataset": args.dataset,
                             "loss": "crossentropy",
                             "label_smoothing": args.label_smoothing,
                             "mixup_alpha": args.mixup_alpha,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()