예제 #1
0
 def load_swav_imagenet(
         path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
     """Load a pretrained SwAV checkpoint and return its trunk as a feature extractor.

     Args:
         path_or_url: Local path or URL of the checkpoint, handed to
             ``SwAV.load_from_checkpoint`` with ``strict=True``.

     Returns:
         A ``(backbone, num_features)`` tuple. ``backbone`` is an
         ``nn.Sequential`` built from every child module of ``swav.model``
         except the final two; ``num_features`` is the fixed value ``2048``.
     """
     pretrained = SwAV.load_from_checkpoint(path_or_url, strict=True)
     child_modules = list(pretrained.model.children())
     # Drop the trailing two child modules and chain the rest sequentially.
     trunk = nn.Sequential(*child_modules[:-2])
     # NOTE(review): hard-coded; presumably the output width of the trunk
     # (ResNet-50-style) — confirm against the checkpoint's architecture.
     feature_dim = 2048
     return trunk, feature_dim
def test_swav(tmpdir, datadir, batch_size=2):
    """Smoke-test SwAV with a single ``fast_dev_run`` fit on CIFAR-10.

    The test passes as long as ``Trainer.fit`` completes without raising.

    Args:
        tmpdir: Used as the Trainer's ``default_root_dir``.
        datadir: Passed to ``CIFAR10DataModule`` as ``data_dir``.
        batch_size: Batch size for both the datamodule and the model.
    """
    def make_transform_kwargs():
        # Called once per transform so train and eval each receive their own
        # normalization object and crop lists.
        return dict(
            normalize=cifar10_normalization(),
            size_crops=[32, 16],
            nmb_crops=[2, 1],
            gaussian_blur=False,
        )

    dm = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0)
    dm.train_transforms = SwAVTrainDataTransform(**make_transform_kwargs())
    dm.val_transforms = SwAVEvalDataTransform(**make_transform_kwargs())

    # Deliberately tiny configuration so the smoke test stays fast on CPU.
    swav = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=dm.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        queue_length=0,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10',
    )

    runner = Trainer(gpus=0, fast_dev_run=True, default_root_dir=tmpdir)
    runner.fit(swav, datamodule=dm)
예제 #3
0
def swav_example():
    """Train SwAV on STL-10 with multi-GPU DDP, then load a pretrained checkpoint.

    The first part fits a fresh SwAV model on STL-10, using
    ``dm.num_unlabeled_samples`` as ``num_samples``. The second part loads a
    SwAV checkpoint from ``weight_path`` and freezes it.

    Returns:
        The SwAV model loaded from ``weight_path``, with ``freeze()`` applied.
    """
    import pytorch_lightning as pl
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.models.self_supervised.swav.transforms import (
        SwAVEvalDataTransform,
        SwAVTrainDataTransform,
    )
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    batch_size = 128
    # Single source of truth for the device count. SwAV and the Trainer must
    # agree: previously SwAV was told gpus=1 while the Trainer ran DDP on 2.
    # NOTE(review): SwAV presumably uses `gpus` to derive its global batch
    # size / iterations per epoch — confirm against pl_bolts.
    num_gpus = 2

    # Data module.
    dm = STL10DataModule(data_dir=".", num_workers=16, batch_size=batch_size)
    # Swap in the "_mixed" loaders (presumably labeled + unlabeled splits — confirm).
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
    dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

    # Model.
    model = SwAV(
        gpus=num_gpus,
        num_samples=dm.num_unlabeled_samples,
        dataset="stl10",
        batch_size=batch_size,
    )

    # Fit.
    trainer = pl.Trainer(gpus=num_gpus, accelerator="ddp", precision=16)
    trainer.fit(model, datamodule=dm)

    # Pretrained weights. Alternative checkpoints:
    #   ImageNet (.pth.tar):
    #     https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar
    #   STL-10:
    #     https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar
    # ImageNet pretrained model (active):
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"
    swav = SwAV.load_from_checkpoint(weight_path, strict=True)
    swav.freeze()
    return swav
def test_swav(tmpdir, datadir):
    """Fit SwAV briefly on CIFAR-10 and check that a positive loss is reported.

    Args:
        tmpdir: Used as the Trainer's ``default_root_dir``.
        datadir: Passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    # Use an explicit seed: seed_everything() with no argument falls back to a
    # randomly selected seed (unless PL_GLOBAL_SEED is set), which makes the
    # run non-reproducible.
    seed_everything(1234)

    batch_size = 2

    datamodule = CIFAR10DataModule(
        data_dir=datadir,
        batch_size=batch_size,
        num_workers=0
    )

    datamodule.train_transforms = SwAVTrainDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False
    )
    datamodule.val_transforms = SwAVEvalDataTransform(
        normalize=cifar10_normalization(),
        size_crops=[32, 16],
        nmb_crops=[2, 1],
        gaussian_blur=False
    )

    # Deliberately tiny configuration so the test stays fast on CPU.
    model = SwAV(
        arch='resnet18',
        hidden_mlp=512,
        gpus=0,
        nodes=1,
        num_samples=datamodule.num_samples,
        batch_size=batch_size,
        nmb_crops=[2, 1],
        sinkhorn_iterations=1,
        nmb_prototypes=2,
        maxpool1=False,
        first_conv=False,
        dataset='cifar10'
    )

    trainer = pl.Trainer(
        gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3
    )

    # Pass the datamodule by keyword rather than positionally, where it would
    # occupy the train-dataloader parameter slot of `fit`.
    trainer.fit(model, datamodule=datamodule)
    loss = trainer.progress_bar_dict['loss']

    # float() handles both tensor and string-formatted progress-bar values;
    # a NaN loss also fails this comparison.
    assert float(loss) > 0
예제 #5
0
def load_swav_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
):
    """Load a pretrained SwAV checkpoint and wrap its network in a config dict.

    Args:
        path_or_url: Local path or URL of the checkpoint, handed to
            ``SwAV.load_from_checkpoint`` with ``strict=True``.

    Returns:
        ``{'model': <loaded SwAV>.model, 'num_features': 3000}``.
    """
    checkpoint_module = SwAV.load_from_checkpoint(path_or_url, strict=True)
    # NOTE(review): 3000 is hard-coded; presumably the output width of the
    # released checkpoint's head (e.g. prototype count) — confirm.
    return dict(model=checkpoint_module.model, num_features=3000)