예제 #1
0
 def load_swav_imagenet(
         path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
     """Load a SwAV checkpoint and return its trunk as a feature extractor.

     Args:
         path_or_url: Local path or URL of the checkpoint; it is passed to
             ``SwAV.load_from_checkpoint`` with ``strict=True``.

     Returns:
         A ``(backbone, num_features)`` tuple. ``backbone`` is an
         ``nn.Sequential`` chaining every child module of ``swav.model``
         except the last two (in registration order); ``num_features`` is the
         hard-coded value 2048.
     """
     pretrained = SwAV.load_from_checkpoint(path_or_url, strict=True)
     # Keep all child modules of the wrapped network except the final two.
     # NOTE(review): presumably the two dropped children are the SwAV
     # projection head / prototypes and 2048 is the trunk width (ResNet-50);
     # confirm against the SwAV model definition.
     layers = list(pretrained.model.children())
     trunk = nn.Sequential(*layers[:-2])
     feature_dim = 2048
     return trunk, feature_dim
예제 #2
0
def swav_example():
    """Train SwAV on STL-10, then load and freeze a pretrained SwAV checkpoint.

    Returns:
        ``(model, pretrained)``: the SwAV model after ``trainer.fit`` on the
        STL-10 datamodule, and a frozen SwAV loaded from the ImageNet
        checkpoint at ``weight_path``. (Previously the function returned
        ``None`` and discarded the pretrained model.)
    """
    # All dependencies are imported locally so the example is self-contained.
    # ``pl`` was previously used for ``pl.Trainer`` without being imported here.
    import pytorch_lightning as pl

    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import (
        SwAVTrainDataTransform,
        SwAVEvalDataTransform,
    )
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    batch_size = 128
    # Single source of truth for the device count. The model used to be
    # built with gpus=1 while the Trainer ran gpus=2 under DDP.
    num_gpus = 2

    # Data module: replace the default loaders with the datamodule's
    # ``*_mixed`` variants and attach SwAV-specific transforms.
    dm = STL10DataModule(data_dir=".", num_workers=16, batch_size=batch_size)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed
    dm.train_transforms = SwAVTrainDataTransform(normalize=stl10_normalization())
    dm.val_transforms = SwAVEvalDataTransform(normalize=stl10_normalization())

    # Model. NOTE(review): in pl_bolts, SwAV's ``gpus`` argument feeds its
    # global-batch-size / iterations-per-epoch computation, so it should
    # match the Trainer's device count. Confirm for the installed version.
    model = SwAV(
        gpus=num_gpus,
        num_samples=dm.num_unlabeled_samples,
        dataset="stl10",
        batch_size=batch_size,
    )

    # Fit.
    trainer = pl.Trainer(gpus=num_gpus, accelerator="ddp", precision=16)
    trainer.fit(model, datamodule=dm)

    # ------------------------------------------------------------------
    # Pretrained model (ImageNet). Other published checkpoints:
    #   ImageNet (.pth.tar):
    #     https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar
    #   STL-10:
    #     https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar
    weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt"
    pretrained = SwAV.load_from_checkpoint(weight_path, strict=True)
    pretrained.freeze()

    return model, pretrained
예제 #3
0
def load_swav_imagenet(
        path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
    """Load a SwAV checkpoint and wrap it in a model-config dict.

    Args:
        path_or_url: Local path or URL of the checkpoint; it is passed to
            ``SwAV.load_from_checkpoint`` with ``strict=True``.

    Returns:
        A dict with ``'model'`` mapped to the loaded module's ``.model``
        attribute (returned whole; nothing is stripped) and
        ``'num_features'`` mapped to the hard-coded value 3000.
    """
    checkpoint_module = SwAV.load_from_checkpoint(path_or_url, strict=True)
    # NOTE(review): 3000 presumably corresponds to the number of SwAV
    # prototypes (the head output width), not a backbone feature width.
    # Confirm against the model definition.
    return dict(model=checkpoint_module.model, num_features=3000)