def test_simsiam(tmpdir, datadir):
    """Smoke-test that SimSiam can be fit on CIFAR-10 with a tiny batch.

    Uses ``fast_dev_run`` on CPU and writes Trainer output under ``tmpdir``.

    NOTE(review): another ``test_simsiam`` is defined later in this module;
    that later definition rebinds the name, so this one is shadowed.
    """
    # Tiny CPU-only data pipeline with SimCLR-style augmentations at 32px.
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = SimCLRTrainDataTransform(32)
    dm.val_transforms = SimCLREvalDataTransform(32)

    simsiam = SimSiam(
        batch_size=2,
        num_samples=dm.num_samples,
        gpus=0,
        nodes=1,
        dataset='cifar10',
    )

    pl.Trainer(
        gpus=0,
        fast_dev_run=True,
        default_root_dir=tmpdir,
    ).fit(simsiam, datamodule=dm)
def test_simsiam(tmpdir, datadir):
    """Run one ``fast_dev_run`` SimSiam step on CIFAR-10 and check the loss sign.

    Asserts that the ``'loss'`` entry reported in the Trainer's progress-bar
    dict is negative after the run.
    """
    # An explicit seed is required for reproducibility: ``seed_everything()``
    # with no argument picks a random seed (unless PL_GLOBAL_SEED is set),
    # which would make the sign assertion below non-deterministic.
    seed_everything(1234)

    datamodule = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    datamodule.train_transforms = SimCLRTrainDataTransform(32)
    datamodule.val_transforms = SimCLREvalDataTransform(32)

    model = SimSiam(batch_size=2, num_samples=datamodule.num_samples, gpus=0, nodes=1, dataset='cifar10')
    trainer = pl.Trainer(gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir)
    # Pass the datamodule by keyword; positionally it lands in the
    # ``train_dataloader`` slot and only works via Lightning's fix-up.
    trainer.fit(model, datamodule=datamodule)

    # The progress-bar value is converted with float() before comparing.
    # NOTE(review): a negative value presumably reflects SimSiam's
    # negative-cosine-similarity loss — confirm against the model.
    loss = trainer.progress_bar_dict['loss']
    assert float(loss) < 0
def simsiam_example():
    """Train SimSiam on CIFAR-10 on two GPUs with the ``"ddp"`` accelerator.

    Self-contained usage example: every name it uses is imported locally, so
    it can be copied out of this module without relying on module-level imports.
    """
    # Previously ``pl`` was the only name taken from module scope.
    import pytorch_lightning as pl

    from pl_bolts.models.self_supervised import SimSiam
    from pl_bolts.datamodules import CIFAR10DataModule
    from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

    # Data module: CIFAR-10 with SimCLR-style augmentations at 32px.
    dm = CIFAR10DataModule(num_workers=12, batch_size=32)
    dm.train_transforms = SimCLRTrainDataTransform(input_height=32)
    dm.val_transforms = SimCLREvalDataTransform(input_height=32)

    # Model: sized from the datamodule so sample count and batch size agree.
    model = SimSiam(gpus=2, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset="cifar10")

    # Fit.
    trainer = pl.Trainer(gpus=2, accelerator="ddp")
    trainer.fit(model, datamodule=dm)