def cli_main():
    """Command-line entry point: parse arguments, build a datamodule and train MocoV2.

    Trainer flags and MocoV2 model flags are registered on one shared parser.
    ``--dataset`` selects the datamodule together with the matching Moco2
    train/eval transforms. For any other value the datamodule is ``None``
    and is handed to ``MocoV2`` unchanged.
    """
    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = MocoV2.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.dataset == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    elif args.dataset == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # STL10 uses the "mixed" loaders in place of the default ones.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()

    elif args.dataset == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    else:
        # Without this branch `datamodule` stayed unbound for any other
        # dataset name and the MocoV2(...) call below raised
        # UnboundLocalError. Passing None defers the choice of data to the
        # model (presumably its own default datamodule -- confirm against
        # MocoV2.__init__).
        datamodule = None

    model = MocoV2(**args.__dict__, datamodule=datamodule)

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)
def test_moco(tmpdir, datadir):
    """Smoke-test ``Moco_v2`` on CIFAR-10 using ``fast_dev_run=True``.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        datadir: directory handed to both the datamodule and the model as ``data_dir``.
    """
    # Batch size 2 and no worker processes keep the run small.
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = Moco_v2(data_dir=datadir, batch_size=2, online_ft=True)

    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco, datamodule=dm)
def test_moco(tmpdir):
    """Fit ``MocoV2`` on CIFAR-10 with ``fast_dev_run=True`` and check the loss is positive.

    Args:
        tmpdir: directory used for the dataset, the model's ``data_dir`` and
            the trainer's ``default_root_dir``.
    """
    seed_everything()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True)

    runner = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco, datamodule=dm)

    # The value is read from the progress-bar dictionary and cast to float
    # before the comparison.
    reported_loss = float(runner.progress_bar_dict['loss'])
    assert reported_loss > 0
def test_moco(tmpdir):
    """Fit ``MocoV2`` with ``overfit_batches=2`` for one epoch and check the loss is positive.

    The datamodule is passed to the model constructor rather than to
    ``trainer.fit``.

    Args:
        tmpdir: directory used for the dataset, the model's ``data_dir`` and
            the trainer's ``default_root_dir``.
    """
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = MocoV2(data_dir=tmpdir, batch_size=2, datamodule=dm, online_ft=True)

    runner = pl.Trainer(
        overfit_batches=2,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    runner.fit(moco)

    # The loss is read back from the trainer's callback metrics.
    final_loss = runner.callback_metrics['loss']
    assert final_loss > 0
def cli_main():
    """Command-line entry point: train ``MocoV2`` and log the run to Weights & Biases.

    Trainer flags and model flags share one parser. ``--dataset`` picks the
    datamodule and its Moco2 train/eval transforms; for any other value no
    datamodule is built and ``None`` is passed to ``trainer.fit``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    # trainer args first, then model args, both on the same parser
    arg_parser = pl.Trainer.add_argparse_args(ArgumentParser())
    arg_parser = MocoV2.add_model_specific_args(arg_parser)
    args = arg_parser.parse_args()

    dataset_name = args.dataset

    # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
    datamodule = None

    if dataset_name == 'cifar10':
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
    elif dataset_name == 'stl10':
        datamodule = STL10DataModule.from_argparse_args(args)
        # STL10 uses the "mixed" loaders in place of the default ones.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()
    elif dataset_name == 'imagenet2012':
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()

    moco = MocoV2(**vars(args))

    runner = pl.Trainer.from_argparse_args(
        args,
        logger=WandbLogger(name='Baseline', project='MocoV2'),
    )
    runner.fit(moco, datamodule=datamodule)

    # Close the wandb run once fitting has returned.
    wandb.finish()
def cli_main():
    """Command-line entry point: parse arguments, pick a datamodule and train ``Moco_v2``.

    Trainer flags and model flags share one parser. ``--dataset`` picks the
    datamodule and its Moco2 train/eval transforms; for any other value no
    datamodule is built and ``None`` is passed to ``trainer.fit``.
    """
    from pl_bolts.datamodules import CIFAR10DataModule, SSLImagenetDataModule, STL10DataModule

    # Register trainer args, then model args, on a single parser.
    cli = Moco_v2.add_model_specific_args(Trainer.add_argparse_args(ArgumentParser()))
    args = cli.parse_args()

    chosen = args.dataset

    if chosen == "cifar10":
        datamodule = CIFAR10DataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
    elif chosen == "stl10":
        datamodule = STL10DataModule.from_argparse_args(args)
        # STL10 uses the "mixed" loaders in place of the default ones.
        datamodule.train_dataloader = datamodule.train_dataloader_mixed
        datamodule.val_dataloader = datamodule.val_dataloader_mixed
        datamodule.train_transforms = Moco2TrainSTL10Transforms()
        datamodule.val_transforms = Moco2EvalSTL10Transforms()
    elif chosen == "imagenet2012":
        datamodule = SSLImagenetDataModule.from_argparse_args(args)
        datamodule.train_transforms = Moco2TrainImagenetTransforms()
        datamodule.val_transforms = Moco2EvalImagenetTransforms()
    else:
        # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
        datamodule = None

    moco = Moco_v2(**vars(args))

    runner = Trainer.from_argparse_args(args)
    runner.fit(moco, datamodule=datamodule)
def __init__(self,
             base_encoder: Union[str, torch.nn.Module] = 'resnet18',
             emb_dim: int = 128,
             num_negatives: int = 65536,
             encoder_momentum: float = 0.999,
             softmax_temperature: float = 0.07,
             learning_rate: float = 0.03,
             momentum: float = 0.9,
             weight_decay: float = 1e-4,
             datamodule: pl.LightningDataModule = None,
             data_dir: str = './',
             batch_size: int = 256,
             use_mlp: bool = False,
             num_workers: int = 8,
             *args, **kwargs):
    """
    PyTorch Lightning implementation of `Moco <https://arxiv.org/abs/2003.04297>`_

    Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

    Code adapted from `facebookresearch/moco <https://github.com/facebookresearch/moco>`_ to Lightning by:

        - `William Falcon <https://github.com/williamFalcon>`_

    Example:

        >>> from pl_bolts.models.self_supervised import MocoV2
        ...
        >>> model = MocoV2()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python moco2_module.py --gpus 1

        # imagenet
        python moco2_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    Args:
        base_encoder: torchvision model name or torch.nn.Module
        emb_dim: feature dimension (default: 128); also the first dimension of the queue buffer
        num_negatives: queue size; number of negative keys (default: 65536)
        encoder_momentum: moco momentum of updating key encoder (default: 0.999)
        softmax_temperature: softmax temperature (default: 0.07)
        learning_rate: the learning rate
        momentum: optimizer momentum
        weight_decay: optimizer weight decay
        datamodule: the DataModule (train, val, test dataloaders); when ``None`` a
            ``CIFAR10DataModule`` built from ``data_dir`` with the Moco2 CIFAR-10
            train/eval transforms is used instead
        data_dir: the directory to store data
        batch_size: batch size
        use_mlp: add an mlp to the encoders
        num_workers: workers for the loaders
    """
    super().__init__()
    # Called before `datamodule` is rebound below.
    # NOTE(review): presumably this records the constructor arguments as
    # passed in -- confirm against the Lightning docs.
    self.save_hyperparameters()

    # use CIFAR-10 by default if no datamodule passed in
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    self.datamodule = datamodule

    # create the encoders
    # num_classes is the output fc dimension
    self.encoder_q, self.encoder_k = self.init_encoders(base_encoder)

    if use_mlp:  # hack: brute-force replacement
        # Second dimension of the existing fc weight; presumably the layer's
        # input feature count (assumes `fc` is an nn.Linear -- TODO confirm).
        dim_mlp = self.encoder_q.fc.weight.shape[1]
        # Wrap each encoder's existing fc as the last stage of Linear -> ReLU -> fc.
        self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
        self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

    # Start the key encoder from the query encoder's weights and turn off
    # gradients for the key encoder's parameters.
    for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
        param_k.data.copy_(param_q.data)  # initialize
        param_k.requires_grad = False  # not update by gradient

    # create the queue: random values of shape (emb_dim, num_negatives),
    # normalized along dim 0
    self.register_buffer("queue", torch.randn(emb_dim, num_negatives))
    self.queue = nn.functional.normalize(self.queue, dim=0)

    # Single-element long buffer initialised to zero.
    self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
# todo: covert to CLI func and add test if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() # trainer args parser = pl.Trainer.add_argparse_args(parser) # model args parser = MocoV2.add_model_specific_args(parser) args = parser.parse_args() if args.dataset == 'cifar10': datamodule = CIFAR10DataModule.from_argparse_args(args) datamodule.train_transforms = Moco2TrainCIFAR10Transforms() datamodule.val_transforms = Moco2EvalCIFAR10Transforms() elif args.dataset == 'stl10': datamodule = STL10DataModule.from_argparse_args(args) datamodule.train_dataloader = datamodule.train_dataloader_mixed datamodule.val_dataloader = datamodule.val_dataloader_mixed datamodule.train_transforms = Moco2TrainSTL10Transforms() datamodule.val_transforms = Moco2EvalSTL10Transforms() elif args.dataset == 'imagenet2012': datamodule = SSLImagenetDataModule.from_argparse_args(args) datamodule.train_transforms = Moco2TrainImagenetTransforms() datamodule.val_transforms = Moco2EvalImagenetTransforms() model = MocoV2(**args.__dict__, datamodule=datamodule)