def cli_main():
    """Parse CLI arguments, build the selected datamodule, and train MocoV2.

    Trainer flags are registered by ``pl.Trainer.add_argparse_args`` and model
    flags by ``MocoV2.add_model_specific_args``. ``args.dataset`` (presumably a
    flag registered by the latter — confirm) selects one of ``'cifar10'``,
    ``'stl10'`` or ``'imagenet2012'``. For any other value the model receives
    ``datamodule=None`` instead of crashing on an unbound local.
    """
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = MocoV2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Replace the default loader methods with the *_mixed variants
        # (presumably mixing labeled/unlabeled splits — confirm in STL10DataModule).
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # Previously this path left `datamodule` unbound (UnboundLocalError below).
        # NOTE(review): MocoV2 is assumed to fall back to its own default data
        # when given None — confirm against the MocoV2 version in use.
        datamodule = None

    # This MocoV2 version takes the datamodule in its constructor, so
    # trainer.fit() below is called with the model only.
    model = MocoV2(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
Пример #2
0
def cli_main():
    """Parse CLI arguments, build the selected datamodule, and train MocoV2 with W&B logging.

    Trainer flags are registered by ``pl.Trainer.add_argparse_args`` and model
    flags by ``MocoV2.add_model_specific_args``. ``args.dataset`` selects one of
    ``'cifar10'``, ``'stl10'`` or ``'imagenet2012'``. Any other value passes
    ``datamodule=None`` to ``trainer.fit``.

    ``wandb.finish()`` is always called, even when training raises or is
    interrupted, so the W&B run is never left open.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = MocoV2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # Replace the default loader methods with the *_mixed variants
        # (presumably mixing labeled/unlabeled splits — confirm in STL10DataModule).
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
        datamodule = None

    model = MocoV2(**args.__dict__)
    try:
        wandb_logger = WandbLogger(name='Baseline', project='MocoV2')
        trainer = pl.Trainer.from_argparse_args(args, logger=wandb_logger)
        trainer.fit(model, datamodule=datamodule)
    finally:
        # Close the W&B run on success AND on failure/interrupt; previously an
        # exception in fit() skipped this call and left the run dangling.
        wandb.finish()
Пример #3
0
def cli_main():
    """Parse CLI arguments, build the selected datamodule, and train ``Moco_v2`` once.

    Trainer flags are registered by ``Trainer.add_argparse_args`` and model
    flags by ``Moco_v2.add_model_specific_args``. ``args.dataset`` selects one of
    ``"cifar10"``, ``"stl10"`` or ``"imagenet2012"``. Any other value passes
    ``datamodule=None`` to ``trainer.fit``.

    The previous body contained a pasted duplicate of itself after
    ``trainer.fit``. That duplicate re-parsed argv, rebuilt the datamodule with a
    different model class (``MocoV2``) and trainer (``pl.Trainer``), and trained
    a second time. It has been removed so exactly one training run happens.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    parser = ArgumentParser()

    # trainer args
    parser = Trainer.add_argparse_args(parser)

    # model args
    parser = Moco_v2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # Replace the default loader methods with the *_mixed variants
        # (presumably mixing labeled/unlabeled splits — confirm in STL10DataModule).
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
        datamodule = None

    model = Moco_v2(**args.__dict__)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=datamodule)
 def train_transform():
     """Build Moco2TrainImagenetTransforms(height=224) and return its train_transform attribute."""
     transforms = Moco2TrainImagenetTransforms(height=224)
     return transforms.train_transform
 def train_transform():
     """Return the height-224 Moco2 ImageNet train_transform wrapped in ApplyN with n=2.

     NOTE(review): this redefines the `train_transform` defined just above, so
     only this version is visible at module level.
     """
     single_view = Moco2TrainImagenetTransforms(height=224).train_transform
     # Wrap the single-view transform with n=2, presumably yielding two views per sample — confirm in ApplyN.
     return ApplyN(n=2, transform=single_view)