Code example #1
0
    def __init__(
        self,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        """Build the training transform pipeline.

        Order: random-resized crop -> optional horizontal flip -> optional
        auto-augmentation -> PIL-to-tensor + float conversion + normalization
        -> optional random erasing. The composed pipeline is stored on
        ``self.transforms``.
        """
        ops = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        if hflip_prob > 0:
            ops.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            # "ra" / "ta_wide" select dedicated augmenters; any other value is
            # interpreted as an AutoAugmentPolicy name.
            if auto_augment_policy == "ra":
                ops.append(autoaugment.RandAugment(interpolation=interpolation))
            elif auto_augment_policy == "ta_wide":
                ops.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
            else:
                ops.append(
                    autoaugment.AutoAugment(
                        policy=autoaugment.AutoAugmentPolicy(auto_augment_policy),
                        interpolation=interpolation,
                    )
                )
        ops += [
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]
        if random_erase_prob > 0:
            ops.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(ops)
Code example #2
0
File: presets.py — Project: zzll1220/vision
 def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
     """Build the evaluation transform pipeline for video clips.

     Converts batched frames from BHWC to BCHW layout, casts to float32,
     resizes, normalizes, center-crops, then reorders to CBHW. The composed
     pipeline is stored on ``self.transforms``.
     """
     pipeline = [
         ConvertBHWCtoBCHW(),
         transforms.ConvertImageDtype(torch.float32),
         transforms.Resize(resize_size),
         transforms.Normalize(mean=mean, std=std),
         transforms.CenterCrop(crop_size),
         ConvertBCHWtoCBHW(),
     ]
     self.transforms = transforms.Compose(pipeline)
Code example #3
0
File: presets.py — Project: zzll1220/vision
 def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989),
              hflip_prob=0.5):
     """Build the training transform pipeline for video clips.

     Converts batched frames from BHWC to BCHW layout, casts to float32,
     resizes, optionally flips horizontally, normalizes, random-crops, then
     reorders to CBHW. The composed pipeline is stored on ``self.transforms``.
     """
     pipeline = [
         ConvertBHWCtoBCHW(),
         transforms.ConvertImageDtype(torch.float32),
         transforms.Resize(resize_size),
     ]
     if hflip_prob > 0:
         pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))
     pipeline += [
         transforms.Normalize(mean=mean, std=std),
         transforms.RandomCrop(crop_size),
         ConvertBCHWtoCBHW(),
     ]
     self.transforms = transforms.Compose(pipeline)
Code example #4
0
File: presets.py — Project: IntelAI/models
    def __init__(
            self,
            crop_size,
            resize_size=256,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            interpolation=InterpolationMode.BILINEAR,
    ):
        """Build the evaluation transform pipeline.

        Resize then center-crop the PIL image, convert to a float tensor and
        normalize. The composed pipeline is stored on ``self.transforms``.
        """
        ops = [
            transforms.Resize(resize_size, interpolation=interpolation),
            transforms.CenterCrop(crop_size),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]
        self.transforms = transforms.Compose(ops)
Code example #5
0
File: train.py — Project: frgfm/Holocron
def main(args):
    """Train or evaluate an image-classification model from CLI arguments.

    Builds the dataset pipelines (Imagenette or CIFAR-10/100), a Holocron
    model, loss, optimizer and trainer, then dispatches on the flags:
    ``--show-samples``, ``--test-only``, ``--lr-finder``, ``--check-setup``,
    or a full training run (optionally logged to Weights & Biases).

    Fix: ``Mixup`` previously hard-coded ``alpha=0.2`` even though the code
    gates on ``args.mixup_alpha > 0`` and logs ``args.mixup_alpha`` to the
    W&B config — the CLI value is now actually used.
    """

    print(args)

    # Let cuDNN benchmark/select conv algorithms (input sizes are fixed).
    torch.backends.cudnn.benchmark = True

    # Data loading
    train_loader, val_loader = None, None

    # Per-dataset channel statistics. NOTE(review): the first tuple matches
    # the commonly used ImageNet values; the second pair is presumably
    # CIFAR statistics, shared by both CIFAR-10 and CIFAR-100 — confirm.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406] if args.dataset.lower()
                            == "imagenette" else [0.5071, 0.4866, 0.4409],
                            std=[0.229, 0.224, 0.225] if args.dataset.lower()
                            == "imagenette" else [0.2673, 0.2564, 0.2761])

    interpolation = InterpolationMode.BILINEAR

    num_classes = None
    if not args.test_only:
        st = time.time()
        if args.dataset.lower() == "imagenette":

            # Imagenette: ImageFolder layout with heavy train-time augmentation.
            train_set = ImageFolder(
                os.path.join(args.data_path, 'train'),
                T.Compose([
                    T.RandomResizedCrop(args.img_size,
                                        scale=(0.3, 1.0),
                                        interpolation=interpolation),
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32),
                    normalize,
                    T.RandomErasing(p=0.9, scale=(0.02, 0.2), value='random'),
                ]))
        else:
            # CIFAR-10/100: torchvision dataset with a lighter pipeline
            # (no resized crop — images are already small).
            cifar_version = CIFAR100 if args.dataset.lower(
            ) == "cifar100" else CIFAR10
            train_set = cifar_version(
                args.data_path,
                True,
                T.Compose([
                    T.RandomHorizontalFlip(),
                    A.TrivialAugmentWide(interpolation=interpolation),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float32), normalize,
                    T.RandomErasing(p=0.9, value='random')
                ]),
                download=True,
            )

        num_classes = len(train_set.classes)
        collate_fn = default_collate
        if args.mixup_alpha > 0:
            # Fix: honor the CLI value instead of a hard-coded alpha of 0.2
            # (the W&B config below already records args.mixup_alpha).
            mix = Mixup(len(train_set.classes), alpha=args.mixup_alpha)
            collate_fn = lambda batch: mix(*default_collate(batch)
                                           )  # noqa: E731
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
            collate_fn=collate_fn,
        )

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # Visualize one augmented batch, then exit.
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        if args.dataset.lower() == "imagenette":
            eval_tf = []
            # Resize so the center crop covers crop_pct of the shorter side,
            # capping the resize target at 320 px.
            crop_pct = 0.875
            scale_size = min(int(math.floor(args.img_size / crop_pct)), 320)
            if scale_size < 320:
                eval_tf.append(T.Resize(scale_size))
            eval_tf.extend([
                T.CenterCrop(args.img_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32), normalize
            ])
            val_set = ImageFolder(os.path.join(args.data_path, 'val'),
                                  T.Compose(eval_tf))
        else:
            cifar_version = CIFAR100 if args.dataset.lower(
            ) == "cifar100" else CIFAR10
            val_set = cifar_version(args.data_path,
                                    False,
                                    T.Compose([
                                        T.PILToTensor(),
                                        T.ConvertImageDtype(torch.float32),
                                        normalize
                                    ]),
                                    download=True)
        num_classes = len(val_set.classes)

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = holocron.models.__dict__[args.arch](args.pretrained,
                                                num_classes=num_classes)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    # Create the contiguous parameters.
    # NOTE(review): no fallback branch — an unrecognized args.opt would leave
    # `optimizer` unbound below; presumably constrained by the CLI parser.
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = torch.optim.RAdam(model_params,
                                      args.lr,
                                      betas=(0.95, 0.99),
                                      eps=1e-6,
                                      weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)

    # Per-epoch metrics are forwarded to W&B only when logging is enabled.
    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
    trainer = ClassificationTrainer(model,
                                    train_loader,
                                    val_loader,
                                    criterion,
                                    optimizer,
                                    args.device,
                                    args.output_file,
                                    amp=args.amp,
                                    on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} "
            f"(Acc@1: {eval_metrics['acc1']:.2%}, Acc@5: {eval_metrics['acc5']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        num_it=min(len(train_loader), 100),
                        norm_weight_decay=args.norm_weight_decay)
        trainer.plot_recorder()
        return

    if args.check_setup:
        # Sanity check: the model should be able to overfit a single batch.
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-image-classification",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "input_size": args.img_size,
                             "optimizer": args.opt,
                             "dataset": args.dataset,
                             "loss": "crossentropy",
                             "label_smoothing": args.label_smoothing,
                             "mixup_alpha": args.mixup_alpha,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
Code example #6
0
File: utils.py — Project: drivendataorg/zamba
import pandas as pd
from pytorch_lightning import LightningDataModule, LightningModule
from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score
import torch
import torch.nn.functional as F
import torch.utils.data
from torchvision.transforms import transforms

from zamba.data.video import VideoLoaderConfig
from zamba.metrics import compute_species_specific_metrics
from zamba.pytorch.dataloaders import get_datasets
from zamba.pytorch.transforms import ConvertTHWCtoCTHW

# Default clip transform: reorder video frames from (T, H, W, C) to
# (C, T, H, W) and convert to float32 (ConvertImageDtype rescales integer
# inputs to [0, 1]); no normalization is applied here.
default_transform = transforms.Compose([
    ConvertTHWCtoCTHW(),
    transforms.ConvertImageDtype(torch.float32),
])

# k values used when reporting top-k accuracy metrics.
DEFAULT_TOP_K = (1, 3, 5, 10)


class ZambaDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = max(cpu_count() - 1, 1),
        transform: transforms.Compose = default_transform,
        video_loader_config: Optional[VideoLoaderConfig] = None,
        prefetch_factor: int = 2,
        train_metadata: Optional[pd.DataFrame] = None,
        predict_metadata: Optional[pd.DataFrame] = None,