def test_moco(tmpdir, datadir):
    """Smoke-test a ``Moco_v2`` fit on ``CIFAR10DataModule``.

    The test has no explicit assertions; it passes as long as
    ``trainer.fit`` completes without raising.

    Args:
        tmpdir: directory handed to the trainer as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
        datadir: data location shared by the datamodule and the model
            (presumably a project-level fixture; confirm in conftest).
    """
    # Tiny batches and in-process loading (num_workers=0).
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = Moco_v2(data_dir=datadir, batch_size=2, online_ft=True)

    lr_callbacks = [MocoLRScheduler()]
    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        callbacks=lr_callbacks,
    )
    runner.fit(moco, datamodule=dm)
def test_moco(tmpdir):
    """Smoke-test ``MocoV2`` training and check that a positive loss is reported.

    Args:
        tmpdir: directory used as the datamodule's data location, the model's
            ``data_dir`` and the trainer's ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture).
    """
    # Pin the seed explicitly: a bare ``seed_everything()`` falls back to the
    # PL_GLOBAL_SEED environment variable or a randomly selected seed, which
    # makes a failing run impossible to reproduce.
    seed_everything(1234)

    datamodule = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
    datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

    model = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True)

    trainer = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    trainer.fit(model, datamodule=datamodule)

    # The progress-bar entry is converted with float() before the comparison.
    loss = trainer.progress_bar_dict['loss']
    assert float(loss) > 0
def test_moco(tmpdir):
    """Fit ``MocoV2`` with its datamodule passed to the model and assert a positive loss.

    Unlike calling ``trainer.fit(model, datamodule=...)``, this variant hands
    the datamodule to the ``MocoV2`` constructor and fits the model alone.

    Args:
        tmpdir: directory used for the dataset, the model's ``data_dir`` and
            the trainer's ``default_root_dir`` (presumably pytest's
            ``tmpdir`` fixture).
    """
    reset_seed()

    cifar = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    cifar.train_transforms = Moco2TrainCIFAR10Transforms()
    cifar.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = MocoV2(
        data_dir=tmpdir,
        batch_size=2,
        datamodule=cifar,
        online_ft=True,
    )

    lr_callbacks = [MocoLRScheduler()]
    runner = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=lr_callbacks,
    )
    runner.fit(moco)

    # The logged training loss is read from the trainer's callback metrics.
    assert runner.callback_metrics['loss'] > 0
num_workers=args.num_workers) else: raise ValueError() model = MocoV2(**vars(args), emb_spaces=datamodule.num_keys) if args.debug: logger = False checkpoint_callback = False else: logger = TensorBoardLogger(save_dir=os.path.join( os.getcwd(), 'logs', 'pretrain'), name=get_experiment_name(args)) checkpoint_callback = ModelCheckpoint(filename='{epoch}') scheduler = MocoLRScheduler(initial_lr=args.learning_rate, schedule=args.schedule, max_epochs=args.max_epochs) online_evaluator = SSLOnlineEvaluator( data_dir=args.online_data_dir, z_dim=model.mlp_dim, max_epochs=args.online_max_epochs, check_val_every_n_epoch=args.online_val_every_n_epoch) trainer = Trainer.from_argparse_args( args, logger=logger, checkpoint_callback=checkpoint_callback, callbacks=[scheduler, online_evaluator], max_epochs=args.max_epochs, weights_summary='full') trainer.fit(model, datamodule=datamodule)