Example #1
0
def cli_main():
    """Train CPC v2 from the command line with an ``SSLOnlineEvaluator`` callback.

    Parses trainer, model and data options, builds the datamodule selected by
    ``--dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``), installs the
    matching CPC train/eval transforms, sets ``args.patch_size`` for that
    dataset, and then fits the model.

    An unrecognised ``--dataset`` value is rejected through ``parser.error``
    (usage message on stderr, exit status 2) instead of letting the run
    continue with ``datamodule=None`` and no ``patch_size``.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = CPCV2.add_model_specific_args(parser)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--data_dir', default='.', type=str)
    parser.add_argument('--meta_dir',
                        default='.',
                        type=str,
                        help='path to meta.bin for imagenet')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--batch_size', type=int, default=128)

    args = parser.parse_args()

    datamodule = None

    online_evaluator = SSLOnlineEvaluator()
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders; each batch they yield is unpacked
        # below as two (inputs, targets) pairs.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # 16 GB RAM - 64
        # 32 GB RAM - 144
        # NOTE(review): the datamodule above was built from ``args`` before
        # this override -- confirm it still picks up the new batch size.
        args.batch_size = 144

        def to_device(batch, device):
            # Only the second pair of the mixed batch is handed to the
            # evaluator (presumably the labeled samples -- TODO confirm).
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    else:
        # Fail fast: without a datamodule and patch size the run would only
        # break later with a far less helpful error.
        parser.error(
            f'unsupported dataset {args.dataset!r}; '
            'expected one of: cifar10, stl10, imagenet2012'
        )

    model = CPCV2(**vars(args))
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule)
Example #2
0
def cli_main():
    """Train CPC v2 from the command line with online fine-tuning enabled.

    Parses trainer and model options, forces ``args.online_ft = True``,
    builds the datamodule selected by ``args.dataset`` (``cifar10``,
    ``stl10`` or ``imagenet2012``) with the matching CPC transforms, and
    passes that datamodule to the model before fitting.

    An unrecognised dataset is rejected through ``parser.error`` (usage
    message on stderr, exit status 2) instead of constructing the model
    with ``datamodule=None``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = CPCV2.add_model_specific_args(parser)

    args = parser.parse_args()
    args.online_ft = True

    datamodule = None

    online_evaluator = SSLOnlineEvaluator()
    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders; each batch they yield is unpacked
        # below as two (inputs, targets) pairs.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        def to_device(batch, device):
            # Only the second pair of the mixed batch is handed to the
            # evaluator (presumably the labeled samples -- TODO confirm).
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    else:
        # Fail fast: the model would otherwise be built without a datamodule
        # and without a patch size.
        parser.error(
            f'unsupported dataset {args.dataset!r}; '
            'expected one of: cifar10, stl10, imagenet2012'
        )

    model = CPCV2(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model)
Example #3
0
def cli_main():
    """Train ``CPC_v2`` from the command line with an online evaluator callback.

    Parses trainer, model and data options, builds the datamodule selected by
    ``--dataset`` (``cifar10``, ``stl10`` or ``imagenet2012``) with the
    matching CPC transforms, configures an ``SSLOnlineEvaluator`` from the
    datamodule's ``num_classes`` and ``--hidden_mlp``, and fits the model.

    An unrecognised ``--dataset`` value is rejected through ``parser.error``
    (usage message on stderr, exit status 2); previously it surfaced as an
    ``AttributeError`` when ``num_classes`` was read from a ``None``
    datamodule.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    seed_everything(1234)
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = CPC_v2.add_model_specific_args(parser)
    parser.add_argument("--dataset", default="cifar10", type=str)
    parser.add_argument("--data_dir", default=".", type=str)
    parser.add_argument("--meta_dir",
                        default=".",
                        type=str,
                        help="path to meta.bin for imagenet")
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--hidden_mlp",
                        default=2048,
                        type=int,
                        help="hidden layer dimension in projection head")
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()

    datamodule = None
    if args.dataset == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # Swap in the "mixed" loaders; each batch they yield is unpacked
        # further down as two (inputs, targets) pairs.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

    elif args.dataset == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    else:
        # Fail fast: the evaluator below needs datamodule.num_classes, which
        # does not exist when no datamodule was selected.
        parser.error(
            f"unsupported dataset {args.dataset!r}; "
            "expected one of: cifar10, stl10, imagenet2012"
        )

    online_evaluator = SSLOnlineEvaluator(
        drop_p=0.0,
        hidden_dim=None,
        z_dim=args.hidden_mlp,
        num_classes=datamodule.num_classes,
        dataset=args.dataset,
    )
    if args.dataset == "stl10":
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        # NOTE(review): the datamodule was built from ``args`` before this
        # override -- confirm it still picks up the new batch size.
        args.batch_size = 144

        def to_device(batch, device):
            # Only the second pair of the mixed batch is handed to the
            # evaluator (presumably the labeled samples -- TODO confirm).
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    model = CPC_v2(**vars(args))
    trainer = Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule=datamodule)
Example #4
0
    parser = CPCV2.add_model_specific_args(parser)

    args = parser.parse_args()
    args.online_ft = True

    datamodule = None

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32

    model = CPCV2(**vars(args), datamodule=datamodule)
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)