def cli_main():
    """CLI entry point: train BYOL with an online linear-probe callback.

    Parses trainer + model args, builds the dataset-specific datamodule and
    SimCLR-style augmentations, then fits BYOL with an SSLOnlineEvaluator.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    seed_everything(1234)

    parser = ArgumentParser()
    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)
    # model args
    parser = BYOL.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_classes = dm.num_classes
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        args.num_classes = dm.num_classes
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
        args.num_classes = dm.num_classes
    else:
        # previously an unknown dataset fell through with dm=None and failed
        # later on `dm.num_classes` with an opaque AttributeError; fail fast
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    model = BYOL(**args.__dict__)

    def to_device(batch, device):
        # batch is ((view1, view2), label); the online probe only needs one view
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    # finetune in real-time
    online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
    trainer.fit(model, dm)
def cli_main():
    """CLI entry point: train CPC v2 with an online-evaluator callback.

    Parses trainer/model/data args, builds the dataset-specific datamodule and
    CPC transforms, then fits CPCV2 with an SSLOnlineEvaluator callback.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = CPCV2.add_model_specific_args(parser)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--data_dir', default='.', type=str)
    parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--batch_size', type=int, default=128)
    args = parser.parse_args()

    datamodule = None
    online_evaluator = SSLOnlineEvaluator()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8
    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # train on the mixed (labeled + unlabeled) split
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16

        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        def to_device(batch, device):
            # STL10 mixed batch: ((unlabeled pair), (labeled x, y)); the
            # online probe consumes only the labeled half
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device
    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32
    else:
        # previously an unknown dataset left datamodule=None and crashed
        # inside trainer.fit; fail fast with a clear message instead
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    model = CPCV2(**vars(args))
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_evaluator])
    trainer.fit(model, datamodule)
def cli_main():
    """CLI entry point: train SimCLR with an online linear probe.

    Parses trainer + model args, builds the dataset-specific datamodule and
    transforms, then fits SimCLR with an SSLOnlineEvaluator callback.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    args = parser.parse_args()

    # init default datamodule
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
        args.num_samples = dm.num_samples
    elif args.dataset == 'stl10':
        dm = STL10DataModule.from_argparse_args(args)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
    elif args.dataset == 'imagenet2012':
        dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
        (c, h, w) = dm.size()
        dm.train_transforms = SimCLRTrainDataTransform(h)
        dm.val_transforms = SimCLREvalDataTransform(h)
    else:
        # previously an unknown dataset never bound `dm`, producing an
        # opaque NameError at `dm.num_classes` below; fail fast instead
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    model = SimCLR(**args.__dict__)

    # finetune in real-time
    def to_device(batch, device):
        # batch is ((view1, view2), label); the online probe only needs one view
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device)
        return x1, y

    online_eval = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=dm.num_classes)
    online_eval.to_device = to_device

    trainer = pl.Trainer.from_argparse_args(args, callbacks=[online_eval])
    trainer.fit(model, dm)
def cli_main():
    """CLI entry point: train SwAV, wiring dataset-specific transforms and hyperparameters.

    Supports stl10, cifar10 and imagenet; the imagenet branch overrides most
    CLI values with a fixed large-scale training recipe.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # NOTE(review): batch_size=2 / num_workers=0 look like debug
        # settings hard-coded for this branch — confirm intent
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        # shrink the resnet stem for 32x32 inputs
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        # multi-crop settings: 2 global + 6 local crops
        args.size_crops = [224, 96]
        args.nmb_crops = [2, 6]
        args.min_scale_crops = [0.14, 0.05]
        args.max_scale_crops = [1., 0.14]
        args.gaussian_blur = True
        args.jitter_strength = 1.

        # fixed large-scale recipe; silently overrides any CLI values
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'lars'
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3

        args.nmb_prototypes = 3000
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # multi-crop augmentation pipelines shared by train and val
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval: linear probe trained alongside the SSL objective
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
    callbacks.append(lr_monitor)

    trainer = Trainer(
        max_epochs=args.max_epochs,
        # -1 is the CLI sentinel for "no step cap"
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.num_nodes,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=callbacks,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, datamodule=dm)
def cli_main():
    """CLI entry point: train CPC v2 with an online linear-probe callback.

    Parses trainer/model/data args, builds the dataset-specific datamodule and
    CPC transforms, then fits CPC_v2 with an SSLOnlineEvaluator callback.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    # STL10DataModule was missing here even though the stl10 branch below uses
    # it (latent NameError); importing it is harmless if it is also available
    # at module scope
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
    from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = CPC_v2.add_model_specific_args(parser)
    parser.add_argument("--dataset", default="cifar10", type=str)
    parser.add_argument("--data_dir", default=".", type=str)
    parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet")
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head")
    parser.add_argument("--batch_size", type=int, default=128)
    args = parser.parse_args()

    datamodule = None

    if args.dataset == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
        args.patch_size = 8
    elif args.dataset == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # train on the mixed (labeled + unlabeled) split
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = CPCTrainTransformsSTL10()
        datamodule.val_transforms = CPCEvalTransformsSTL10()
        args.patch_size = 16
    elif args.dataset == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = CPCTrainTransformsImageNet128()
        datamodule.val_transforms = CPCEvalTransformsImageNet128()
        args.patch_size = 32
    else:
        # previously an unknown dataset left datamodule=None and crashed
        # below on `datamodule.num_classes`; fail fast with a clear message
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    online_evaluator = SSLOnlineEvaluator(
        drop_p=0.0,
        hidden_dim=None,
        z_dim=args.hidden_mlp,
        num_classes=datamodule.num_classes,
        dataset=args.dataset,
    )

    if args.dataset == "stl10":
        # 16 GB RAM - 64
        # 32 GB RAM - 144
        args.batch_size = 144

        def to_device(batch, device):
            # STL10 mixed batch: ((unlabeled pair), (labeled x, y)); the
            # online probe consumes only the labeled half
            (_, _), (x2, y2) = batch
            x2 = x2.to(device)
            y2 = y2.to(device)
            return x2, y2

        online_evaluator.to_device = to_device

    model = CPC_v2(**vars(args))

    trainer = Trainer.from_argparse_args(args, callbacks=[online_evaluator])

    trainer.fit(model, datamodule=datamodule)
def cli_main():
    """CLI entry point: train SimCLR with dataset-specific transforms and an optional online probe.

    Supports stl10, cifar10 and imagenet; the imagenet branch overrides most
    CLI values with a fixed large-scale training recipe.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

    parser = ArgumentParser()

    # model args
    parser = SimCLR.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.
    elif args.dataset == 'cifar10':
        # grow the validation split so it can fill at least one global batch
        val_split = 5000
        if args.num_nodes * args.gpus * args.batch_size > val_split:
            val_split = args.num_nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(
            data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, val_split=val_split)

        args.num_samples = dm.num_samples
        # shrink the resnet stem for 32x32 inputs
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == 'imagenet':
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.

        # fixed large-scale recipe; silently overrides any CLI values
        args.batch_size = 64
        args.num_nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = 'sgd'
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimCLR(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval: linear probe trained alongside the SSL objective
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0., hidden_dim=None, z_dim=args.hidden_mlp, num_classes=dm.num_classes, dataset=args.dataset)

    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]

    trainer = pl.Trainer.from_argparse_args(
        args,
        sync_batchnorm=True if args.gpus > 1 else False,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule=dm)
def cli_main():
    """CLI entry point: train SwAV on STL-10 or CIFAR-10.

    Builds the dataset-specific datamodule, per-dataset normalization and
    SwAV multi-crop transforms, then fits with an optional online probe.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule

    parser = ArgumentParser()

    # model args
    parser = SwAV.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples
        args.maxpool1 = False
        normalization = stl10_normalization()
    elif args.dataset == 'cifar10':
        # NOTE(review): batch_size=2 / num_workers=0 look like debug
        # settings hard-coded for this branch — confirm intent
        args.batch_size = 2
        args.num_workers = 0

        dm = CIFAR10DataModule(data_dir=args.data_path, batch_size=args.batch_size, num_workers=args.num_workers)
        args.num_samples = dm.num_samples
        # shrink the resnet stem for 32x32 inputs
        args.maxpool1 = False
        args.first_conv = False
        normalization = cifar10_normalization()

        # cifar10 specific params
        args.size_crops = [32, 16]
        args.nmb_crops = [2, 1]
        args.gaussian_blur = False
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # BUG FIX: both transforms previously hard-coded stl10_normalization(),
    # ignoring the per-dataset `normalization` computed above — CIFAR-10 was
    # being normalized with STL-10 statistics.
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=normalization,
        size_crops=args.size_crops,
        nmb_crops=args.nmb_crops,
        min_scale_crops=args.min_scale_crops,
        max_scale_crops=args.max_scale_crops,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength)

    # swav model init
    model = SwAV(**args.__dict__)

    online_evaluator = None
    if args.online_ft:
        # online eval: linear probe trained alongside the SSL objective
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0., hidden_dim=None, z_dim=args.hidden_mlp, num_classes=dm.num_classes, dataset=args.dataset)

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        # -1 is the CLI sentinel for "no step cap"
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run)

    trainer.fit(model, dm)
def cli_main():
    """CLI entry point: train SimSiam with dataset-specific transforms and an optional online probe.

    Supports stl10, cifar10, cifar100 and imagenet; the imagenet branch
    overrides most CLI values with a fixed large-scale training recipe.
    """
    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

    seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SimSiam.add_model_specific_args(parser)
    args = parser.parse_args()

    # pick data
    dm = None

    # init datamodule
    if args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
        # train on the mixed (labeled + unlabeled) split
        dm.train_dataloader = dm.train_dataloader_mixed
        dm.val_dataloader = dm.val_dataloader_mixed
        args.num_samples = dm.num_unlabeled_samples

        args.maxpool1 = False
        args.first_conv = True
        args.input_height = dm.size()[-1]

        normalization = stl10_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0
    elif args.dataset == "cifar10":
        # grow the validation split so it can fill at least one global batch
        val_split = 5000
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        dm = CIFAR10DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples
        # shrink the resnet stem for 32x32 inputs
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        normalization = cifar10_normalization()

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "cifar100":
        # grow the validation split so it can fill at least one global batch
        val_split = 5000
        if args.nodes * args.gpus * args.batch_size > val_split:
            val_split = args.nodes * args.gpus * args.batch_size

        # NOTE(review): CIFAR100DataModule is not imported in this function —
        # presumably available at module scope; verify
        dm = CIFAR100DataModule(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            val_split=val_split,
        )

        args.num_samples = dm.num_samples
        args.maxpool1 = False
        args.first_conv = False
        args.input_height = dm.size()[-1]
        args.temperature = 0.5

        # CIFAR-100 channel statistics:
        # ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
        normalization = transforms.Normalize(
            mean=(0.5071, 0.4866, 0.4409),
            std=(0.2009, 0.1984, 0.2023),
        )

        args.gaussian_blur = False
        args.jitter_strength = 0.5
    elif args.dataset == "imagenet":
        args.maxpool1 = True
        args.first_conv = True
        normalization = imagenet_normalization()

        args.gaussian_blur = True
        args.jitter_strength = 1.0

        # fixed large-scale recipe; silently overrides any CLI values
        args.batch_size = 64
        args.nodes = 8
        args.gpus = 8  # per-node
        args.max_epochs = 800

        args.optimizer = "sgd"
        args.lars_wrapper = True
        args.learning_rate = 4.8
        args.final_lr = 0.0048
        args.start_lr = 0.3
        args.online_ft = True

        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        args.num_samples = dm.num_samples
        args.input_height = dm.size()[-1]
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # SimSiam reuses the SimCLR augmentation pipelines
    dm.train_transforms = SimCLRTrainDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    dm.val_transforms = SimCLREvalDataTransform(
        input_height=args.input_height,
        gaussian_blur=args.gaussian_blur,
        jitter_strength=args.jitter_strength,
        normalize=normalization,
    )

    model = SimSiam(**args.__dict__)

    # finetune in real-time
    online_evaluator = None
    if args.online_ft:
        # online eval: linear probe trained alongside the SSL objective
        online_evaluator = SSLOnlineEvaluator(
            drop_p=0.0,
            hidden_dim=None,
            z_dim=args.hidden_mlp,
            num_classes=dm.num_classes,
            dataset=args.dataset,
        )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        # -1 is the CLI sentinel for "no step cap"
        max_steps=None if args.max_steps == -1 else args.max_steps,
        gpus=args.gpus,
        num_nodes=args.nodes,
        distributed_backend="ddp" if args.gpus > 1 else None,
        sync_batchnorm=True if args.gpus > 1 else False,
        precision=32 if args.fp32 else 16,
        callbacks=[online_evaluator] if args.online_ft else None,
        fast_dev_run=args.fast_dev_run,
    )

    trainer.fit(model, dm)
def cli_main():
    """CLI entry point: self-supervised SimCLR training on an image-folder dataset.

    Optionally resumes from a SIMCLR checkpoint; otherwise initializes just the
    encoder. Logs to wandb and saves the final checkpoint under ./models/SSL.
    """
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH", type=str, help="path to folders with images")
    parser.add_argument(
        "--encoder",
        default=None,
        type=str,
        help=
        "encoder to initialize. Can accept SimCLR model checkpoint or just encoder name in from encoders_dali")
    parser.add_argument("--batch_size", default=128, type=int, help="batch size for SSL")
    parser.add_argument("--num_workers", default=1, type=int, help="number of workers to use to fetch data")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning")
    parser.add_argument("--epochs", default=400, type=int, help="number of epochs to train model")
    parser.add_argument("--lr", default=1e-3, type=float, help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping.")
    parser.add_argument("--val_split", default=0.2, type=float, help="percent in validation data")
    parser.add_argument(
        "--withhold_split",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold from either training or validation. Used for experimenting with labels neeeded")
    parser.add_argument("--gpus", default=1, type=int, help="number of gpus to use for training")
    parser.add_argument("--log_name", type=str, help="name of model to log on wandb and locally")
    parser.add_argument(
        "--online_eval",
        default=False,
        type=bool,
        help="Do finetuning on model if labels are provided as a sanity check")
    args = parser.parse_args()

    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    num_workers = args.num_workers
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withhold = args.withhold_split
    gpus = args.gpus
    encoder = args.encoder
    log_name = 'SIMCLR_SSL_' + args.log_name + '.ckpt'
    online_eval = args.online_eval

    wandb_logger = WandbLogger(name=log_name, project='SpaceForce')

    checkpointed = '.ckpt' in encoder
    if checkpointed:
        print('Resuming SSL Training from Model Checkpoint')
        try:
            model = SIMCLR.load_from_checkpoint(checkpoint_path=encoder)
            embedding_size = model.embedding_size
        except Exception as e:
            print(e)
            print(
                'invalid checkpoint to initialize SIMCLR. This checkpoint needs to include the encoder and projection and is of the SIMCLR class from this library. Will try to initialize just the encoder')
            checkpointed = False

    # BUG FIX: this was `elif not checkpointed`, chained to the `if` above.
    # When the checkpoint load failed, the flag was reset to False but the
    # elif branch was never re-entered (the `if` had already been taken), so
    # `model` stayed undefined and the promised encoder fallback never ran.
    # A plain `if` makes the fallback actually happen.
    if not checkpointed:
        encoder, embedding_size = load_encoder(encoder)
        model = SIMCLR(
            encoder=encoder,
            embedding_size=embedding_size,
            gpus=gpus,
            epochs=epochs,
            DATA_PATH=DATA_PATH,
            withhold=withhold,
            batch_size=batch_size,
            val_split=val_split,
            hidden_dims=hidden_dims,
            train_transform=SimCLRTrainDataTransform,
            # NOTE(review): val_transform reuses the *train* augmentation —
            # possibly a copy-paste slip (SimCLREvalDataTransform expected?);
            # preserved as-is, confirm intent
            val_transform=SimCLRTrainDataTransform,
            num_workers=num_workers,
            lr=lr)

    # online linear probe trained alongside the SSL objective
    online_evaluator = SSLOnlineEvaluator(
        drop_p=0., hidden_dim=None, z_dim=embedding_size, num_classes=model.num_classes, dataset='None')

    cbs = []
    backend = 'dp'

    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        cbs.append(cb)

    if online_eval:
        cbs.append(online_evaluator)
        backend = 'ddp'

    trainer = Trainer(
        gpus=gpus,
        max_epochs=epochs,
        progress_bar_refresh_rate=5,
        callbacks=cbs,
        distributed_backend=f'{backend}' if args.gpus > 1 else None,
        logger=wandb_logger,
        enable_pl_optimizer=True)

    print('USING BACKEND______________________________ ', backend)
    trainer.fit(model)

    Path(f"./models/SSL").mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(f"./models/SSL/{log_name}")