def load_swav_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
):
    """Load a SwAV checkpoint and return its truncated network.

    Args:
        path_or_url: Checkpoint location handed to
            ``SwAV.load_from_checkpoint``.

    Returns:
        A ``(backbone, 2048)`` tuple. ``backbone`` is an ``nn.Sequential``
        made of every child module of ``swav.model`` except the last two.
        The second element is the hard-coded integer 2048 (presumably the
        backbone's output feature size -- TODO confirm).
    """
    pretrained = SwAV.load_from_checkpoint(path_or_url, strict=True)

    # Drop the final two child modules and chain the remaining ones.
    child_modules = list(pretrained.model.children())
    trunk_layers = child_modules[:-2]
    backbone = nn.Sequential(*trunk_layers)

    feature_dim = 2048
    return backbone, feature_dim
def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke-test one ``Trainer.fit`` call of SwAV on the CIFAR-10 datamodule.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
        batch_size: Batch size given to both the datamodule and the model.
    """

    def build_transform(transform_cls):
        # Train and eval transforms take the same multi-crop settings; each
        # call builds fresh list literals and a fresh normalization.
        return transform_cls(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    # Note kept from the original author:
    # inputs, y = batch (doesn't receive y for some reason)
    dm = CIFAR10DataModule(
        data_dir=datadir,
        batch_size=batch_size,
        num_workers=0,
    )
    dm.train_transforms = build_transform(SwAVTrainDataTransform)
    dm.val_transforms = build_transform(SwAVEvalDataTransform)

    swav_options = dict(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )
    model = SwAV(**swav_options)

    runner = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(model, datamodule=dm)
def swav_example():
    """Train SwAV on STL-10, then load and freeze a pretrained checkpoint.

    The function first fits a freshly constructed ``SwAV`` model on an
    ``STL10DataModule`` and afterwards loads ImageNet-pretrained weights from
    a URL and freezes them. Nothing is returned.
    """
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import (
        SwAVTrainDataTransform,
        SwAVEvalDataTransform,
    )
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    batch_size = 128
    # Single source of truth for the device count. SwAV takes the GPU count
    # as a constructor argument, so it must agree with what the Trainer
    # launches; previously the model was built with gpus=1 while the Trainer
    # ran on 2 GPUs under DDP.
    num_gpus = 2

    # Data module.
    dm = STL10DataModule(data_dir=".", num_workers=16, batch_size=batch_size)
    # Swap in the datamodule's "mixed" loader variants for train and val.
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed

    dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
    dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

    # Model.
    model = SwAV(
        gpus=num_gpus,
        num_samples=dm.num_unlabeled_samples,
        dataset="stl10",
        batch_size=batch_size,
    )

    # Fit.
    trainer = pl.Trainer(gpus=num_gpus, accelerator="ddp", precision=16)
    trainer.fit(model, datamodule=dm)

    # --------------------
    # ImageNet pretrained model:
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"
    # Alternative ImageNet checkpoint:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar"
    # STL-10 pretrained model:
    # weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar"
    swav = SwAV.load_from_checkpoint(weight_path, strict=True)

    swav.freeze()
def test_swav(tmpdir, datadir):
    """Fit SwAV briefly on CIFAR-10 and check the reported loss is positive.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        datadir: Directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    seed_everything()

    batch_size = 2

    def crop_options():
        # Fresh kwargs per call so the two transforms never share list objects.
        return {
            'normalize': cifar10_normalization(),
            'size_crops': [32, 16],
            'nmb_crops': [2, 1],
            'gaussian_blur': False,
        }

    # Note kept from the original author:
    # inputs, y = batch (doesn't receive y for some reason)
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**crop_options())
    dm.val_transforms = SwAVEvalDataTransform(**crop_options())

    model_options = {
        'arch': 'resnet18',
        'hidden_mlp': 512,
        'gpus': 0,
        'nodes': 1,
        'num_samples': dm.num_samples,
        'batch_size': batch_size,
        'nmb_crops': [2, 1],
        'sinkhorn_iterations': 1,
        'nmb_prototypes': 2,
        'maxpool1': False,
        'first_conv': False,
        'dataset': 'cifar10',
    }
    model = SwAV(**model_options)

    trainer_options = {
        'gpus': 0,
        'fast_dev_run': True,
        'max_epochs': 1,
        'default_root_dir': tmpdir,
        'max_steps': 3,
    }
    trainer = pl.Trainer(**trainer_options)
    # The datamodule is passed positionally, exactly as before.
    trainer.fit(model, dm)

    reported_loss = float(trainer.progress_bar_dict['loss'])
    assert reported_loss > 0
def load_swav_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
):
    """Load a SwAV checkpoint and describe it as a small config dict.

    Args:
        path_or_url: Checkpoint location handed to
            ``SwAV.load_from_checkpoint``.

    Returns:
        A dict with two entries: ``'model'`` holds the ``model`` attribute of
        the loaded SwAV module, and ``'num_features'`` is the hard-coded
        integer 3000 (presumably the size of the model's output -- TODO
        confirm).
    """
    loaded = SwAV.load_from_checkpoint(path_or_url, strict=True)
    return {
        'model': loaded.model,
        'num_features': 3000,
    }