def __init__(
    self,
    crop_size,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    interpolation=InterpolationMode.BILINEAR,
    hflip_prob=0.5,
    auto_augment_policy=None,
    random_erase_prob=0.0,
):
    trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0:
        trans.append(transforms.RandomHorizontalFlip(hflip_prob))
    if auto_augment_policy is not None:
        if auto_augment_policy == "ra":
            trans.append(autoaugment.RandAugment(interpolation=interpolation))
        elif auto_augment_policy == "ta_wide":
            trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
        else:
            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
            trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
    trans.extend(
        [
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    if random_erase_prob > 0:
        trans.append(transforms.RandomErasing(p=random_erase_prob))
    self.transforms = transforms.Compose(trans)
def __init__(
    self,
    crop_size,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    hflip_prob=0.5,
    auto_augment_policy=None,
    random_erase_prob=0.0,
):
    trans = [transforms.RandomResizedCrop(crop_size)]
    if hflip_prob > 0:
        trans.append(transforms.RandomHorizontalFlip(hflip_prob))
    if auto_augment_policy is not None:
        if auto_augment_policy == "ra":
            trans.append(autoaugment.RandAugment())
        elif auto_augment_policy == "ta_wide":
            trans.append(autoaugment.TrivialAugmentWide())
        else:
            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
            trans.append(autoaugment.AutoAugment(policy=aa_policy))
    trans.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    if random_erase_prob > 0:
        trans.append(transforms.RandomErasing(p=random_erase_prob))
    self.transforms = transforms.Compose(trans)
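# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original snippets): both constructors
# above presumably belong to a training-preset class, named
# ClassificationPresetTrain in the torchvision reference scripts; that name
# is assumed here. An instance applies its composed pipeline to a PIL image
# via the `transforms` attribute built at the end of __init__.
# ---------------------------------------------------------------------------
from PIL import Image

preset = ClassificationPresetTrain(
    crop_size=176,                  # illustrative value, not prescribed above
    auto_augment_policy="ta_wide",  # routes to TrivialAugmentWide above
    random_erase_prob=0.1,
)
img = Image.open("sample.jpg").convert("RGB")
x = preset.transforms(img)  # normalized float tensor of shape (3, 176, 176)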
def main(args):
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = None, None
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406] if args.dataset.lower() == "imagenette" else [0.5071, 0.4866, 0.4409],
        std=[0.229, 0.224, 0.225] if args.dataset.lower() == "imagenette" else [0.2673, 0.2564, 0.2761],
    )
    interpolation = InterpolationMode.BILINEAR
    num_classes = None

    if not args.test_only:
        st = time.time()
        if args.dataset.lower() == "imagenette":
            train_set = ImageFolder(
                os.path.join(args.data_path, "train"),
                T.Compose([
                    T.RandomResizedCrop(args.img_size, scale=(0.3, 1.0), interpolation=interpolation),
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, scale=(0.02, 0.2), value="random"),
                ]),
            )
        else:
            cifar_version = CIFAR100 if args.dataset.lower() == "cifar100" else CIFAR10
            train_set = cifar_version(
                args.data_path,
                True,
                T.Compose([
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, value="random"),
                ]),
                download=True,
            )
        num_classes = len(train_set.classes)

        collate_fn = default_collate
        if args.mixup_alpha > 0:
            mix = Mixup(len(train_set.classes), alpha=0.2)
            collate_fn = lambda batch: mix(*default_collate(batch))  # noqa: E731

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
            collate_fn=collate_fn,
        )

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        if args.dataset.lower() == "imagenette":
            eval_tf = []
            crop_pct = 0.875
            scale_size = min(int(math.floor(args.img_size / crop_pct)), 320)
            if scale_size < 320:
                eval_tf.append(T.Resize(scale_size))
            eval_tf.extend([
                T.CenterCrop(args.img_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                normalize,
            ])
            val_set = ImageFolder(os.path.join(args.data_path, "val"), T.Compose(eval_tf))
        else:
            cifar_version = CIFAR100 if args.dataset.lower() == "cifar100" else CIFAR10
            val_set = cifar_version(
                args.data_path,
                False,
                T.Compose([
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                ]),
                download=True,
            )
        num_classes = len(val_set.classes)

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
        )

        print(f"Validation set loaded in {time.time() - st:.2f}s "
              f"({len(val_set)} samples in {len(val_loader)} batches)")

    model = holocron.models.__dict__[args.arch](args.pretrained, num_classes=num_classes)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Collect the parameters that require gradients
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == "sgd":
        optimizer = torch.optim.SGD(model_params, args.lr, momentum=0.9, weight_decay=args.weight_decay)
    elif args.opt == "radam":
        optimizer = torch.optim.RAdam(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == "adamp":
        optimizer = holocron.optim.AdamP(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == "adabelief":
        optimizer = holocron.optim.AdaBelief(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
    else:
        # Fail fast on unsupported values (otherwise `optimizer` would be undefined below)
        raise ValueError(f"unsupported optimizer: {args.opt}")

    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None  # noqa: E731
    trainer = ClassificationTrainer(
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        args.device,
        args.output_file,
        amp=args.amp,
        on_epoch_end=log_wb,
    )
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(f"Validation loss: {eval_metrics['val_loss']:.4} "
              f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100), norm_weight_decay=args.norm_weight_decay)
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(
            args.freeze_until,
            args.lr,
            norm_weight_decay=args.norm_weight_decay,
            num_it=min(len(train_loader), 100),
        )
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:
        run = wandb.init(
            name=exp_name,
            project="holocron-image-classification",
            config={
                "learning_rate": args.lr,
                "scheduler": args.sched,
                "weight_decay": args.weight_decay,
                "epochs": args.epochs,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.img_size,
                "optimizer": args.opt,
                "dataset": args.dataset,
                "loss": "crossentropy",
                "label_smoothing": args.label_smoothing,
                "mixup_alpha": args.mixup_alpha,
            },
        )

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
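# ---------------------------------------------------------------------------
# Hedged sketch (not from the original file): main() wraps default_collate
# with a `Mixup` callable invoked as `mix(*default_collate(batch))`, i.e.
# taking a stacked image tensor and an integer label tensor. A minimal
# implementation matching that call pattern could look as follows; the class
# actually shipped with holocron may differ in its details.
# ---------------------------------------------------------------------------
import torch
from torch.utils.data.dataloader import default_collate


class MixupSketch:
    """Blend random pairs of samples and return soft, interpolated targets."""

    def __init__(self, num_classes: int, alpha: float = 0.2) -> None:
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, images: torch.Tensor, targets: torch.Tensor):
        # One-hot encode the integer labels so they can be interpolated
        targets = torch.nn.functional.one_hot(targets, num_classes=self.num_classes).float()
        if self.alpha > 0:
            # Sample the mixing coefficient and a random pairing of samples
            lam = float(torch.distributions.Beta(self.alpha, self.alpha).sample())
            perm = torch.randperm(images.size(0))
            images = lam * images + (1 - lam) * images[perm]
            targets = lam * targets + (1 - lam) * targets[perm]
        return images, targets


# Usage, mirroring the collate function built in main():
# mix = MixupSketch(num_classes=10, alpha=0.2)
# collate_fn = lambda batch: mix(*default_collate(batch))
# Note: nn.CrossEntropyLoss accepts such probability targets from PyTorch 1.10 on.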