Example #1
    def __init__(
        self,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment(interpolation=interpolation))
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
        trans.extend(
            [
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)
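
For context, this constructor assembles a training-time preprocessing pipeline in the style of torchvision's classification reference presets. Below is a minimal, standalone sketch of the pipeline it builds when the "ta_wide" policy and random erasing are enabled; the crop size, erasing probability, and dummy input are illustrative assumptions, not values taken from the example above.

import torch
from PIL import Image
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode

# Same stages the constructor appends for auto_augment_policy="ta_wide"
# and random_erase_prob > 0; parameter values here are illustrative.
pipeline = transforms.Compose([
    transforms.RandomResizedCrop(176, interpolation=InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(0.5),
    autoaugment.TrivialAugmentWide(interpolation=InterpolationMode.BILINEAR),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    transforms.RandomErasing(p=0.1),
])

img = Image.new("RGB", (256, 256))  # dummy input image
x = pipeline(img)
print(x.shape, x.dtype)  # torch.Size([3, 176, 176]) torch.float32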
Example #2
    def __init__(
            self,
            crop_size,
            resize_size=256,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            interpolation=InterpolationMode.BILINEAR,
    ):

        self.transforms = transforms.Compose([
            transforms.Resize(resize_size, interpolation=interpolation),
            transforms.CenterCrop(crop_size),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ])
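
This looks like the matching evaluation-time preset: resize, center crop, tensor conversion, and normalization. A minimal standalone sketch of the same pipeline follows, with an illustrative 256-pixel resize, 224-pixel crop, and dummy input (these sizes are assumptions, not taken from the snippet above).

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

pipeline = transforms.Compose([
    transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),  # shorter side -> 256
    transforms.CenterCrop(224),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

img = Image.new("RGB", (500, 375))  # dummy input image
x = pipeline(img)
print(x.shape)  # torch.Size([3, 224, 224])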
Example #3
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = None, None

    # ImageNet statistics for Imagenette, CIFAR statistics otherwise
    if args.dataset.lower() == "imagenette":
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    else:
        normalize = T.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2761])

    interpolation = InterpolationMode.BILINEAR

    num_classes = None
    if not args.test_only:
        st = time.time()
        if args.dataset.lower() == "imagenette":

            train_set = ImageFolder(
                os.path.join(args.data_path, 'train'),
                T.Compose([
                    T.RandomResizedCrop(args.img_size,
                                        scale=(0.3, 1.0),
                                        interpolation=interpolation),
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, scale=(0.02, 0.2), value='random'),
                ]))
        else:
            cifar_version = CIFAR100 if args.dataset.lower() == "cifar100" else CIFAR10
            train_set = cifar_version(
                args.data_path,
                True,
                T.Compose([
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, value='random'),
                ]),
                download=True,
            )

        num_classes = len(train_set.classes)
        collate_fn = default_collate
        if args.mixup_alpha > 0:
            mix = Mixup(len(train_set.classes), alpha=args.mixup_alpha)
            collate_fn = lambda batch: mix(*default_collate(batch))  # noqa: E731
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
            collate_fn=collate_fn,
        )

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        if args.dataset.lower() == "imagenette":
            eval_tf = []
            crop_pct = 0.875
            # Resize so the center crop keeps ~87.5% of the image, capping the resize at 320px
            scale_size = min(int(math.floor(args.img_size / crop_pct)), 320)
            if scale_size < 320:
                eval_tf.append(T.Resize(scale_size))
            eval_tf.extend([
                T.CenterCrop(args.img_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                normalize,
            ])
            val_set = ImageFolder(os.path.join(args.data_path, 'val'),
                                  T.Compose(eval_tf))
        else:
            cifar_version = CIFAR100 if args.dataset.lower() == "cifar100" else CIFAR10
            val_set = cifar_version(
                args.data_path,
                False,
                T.Compose([
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                ]),
                download=True,
            )
        num_classes = len(val_set.classes)

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = holocron.models.__dict__[args.arch](args.pretrained,
                                                num_classes=num_classes)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Gather the trainable parameters.
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = torch.optim.RAdam(model_params,
                                      args.lr,
                                      betas=(0.95, 0.99),
                                      eps=1e-6,
                                      weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)
    else:
        raise ValueError(f"unsupported optimizer: {args.opt}")

    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
    trainer = ClassificationTrainer(model,
                                    train_loader,
                                    val_loader,
                                    criterion,
                                    optimizer,
                                    args.device,
                                    args.output_file,
                                    amp=args.amp,
                                    on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        num_it=min(len(train_loader), 100),
                        norm_weight_decay=args.norm_weight_decay)
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-image-classification",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "input_size": args.img_size,
                             "optimizer": args.opt,
                             "dataset": args.dataset,
                             "loss": "crossentropy",
                             "label_smoothing": args.label_smoothing,
                             "mixup_alpha": args.mixup_alpha,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
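
For reference, main() reads its configuration from an args namespace, typically produced by argparse in the surrounding script (not shown here). Below is a hedged sketch of invoking it directly: the field names are taken from how args is used above, but every value, the architecture name, and the scheduler name are illustrative assumptions about holocron's options.

from types import SimpleNamespace

# Field names mirror the attributes read in main(); values are illustrative only.
args = SimpleNamespace(
    dataset="cifar10", data_path="./data", img_size=32,
    batch_size=128, workers=4, mixup_alpha=0.0,
    arch="rexnet1_0x", pretrained=False,
    opt="adamp", lr=3e-3, weight_decay=0.0, norm_weight_decay=0.0,
    label_smoothing=0.1, epochs=20, sched="onecycle", freeze_until=None,
    device=None, output_file="./model.pth", amp=False, resume="",
    test_only=False, show_samples=False, lr_finder=False, check_setup=False,
    name=None, wb=False,
)
main(args)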