コード例 #1
0
def test_moco(tmpdir, datadir):
    """Smoke-test ``Moco_v2`` end to end on the CIFAR-10 datamodule.

    Builds a CIFAR-10 datamodule carrying the MoCo v2 train/eval transforms,
    then fits a ``Moco_v2`` model (``online_ft=True``) with a ``pl.Trainer``
    in ``fast_dev_run`` mode and a ``MocoLRScheduler`` callback. The test
    passes as long as ``fit`` completes without raising; no metric is checked.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``
            (presumably pytest's built-in ``tmpdir`` fixture).
        datadir: directory holding the CIFAR-10 data, shared by the
            datamodule and the model (presumably a project fixture -- confirm).
    """
    # Small batch and num_workers=0, presumably to keep the smoke test cheap.
    dm = CIFAR10DataModule(data_dir=datadir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    net = Moco_v2(data_dir=datadir, batch_size=2, online_ft=True)

    lr_schedule = MocoLRScheduler()
    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        callbacks=[lr_schedule],
    )
    runner.fit(net, datamodule=dm)
コード例 #2
0
def test_moco(tmpdir):
    """Train ``MocoV2`` briefly on CIFAR-10 and check the reported loss.

    Calls ``seed_everything()`` first (presumably for reproducibility), fits
    ``MocoV2`` (``online_ft=True``) with ``fast_dev_run=True`` and a
    ``MocoLRScheduler`` callback, then reads the ``'loss'`` entry of
    ``trainer.progress_bar_dict`` and asserts that, converted to ``float``,
    it is strictly positive.

    Args:
        tmpdir: scratch directory used both as the CIFAR-10 data directory
            and as the trainer's ``default_root_dir``.
    """
    seed_everything()

    cifar = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    cifar.train_transforms = Moco2TrainCIFAR10Transforms()
    cifar.val_transforms = Moco2EvalCIFAR10Transforms()

    moco = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True)

    trainer_options = dict(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[MocoLRScheduler()],
    )
    trainer = pl.Trainer(**trainer_options)
    trainer.fit(moco, datamodule=cifar)

    # The float() conversion suggests the progress-bar value may not already
    # be numeric (NOTE(review): confirm its type for the Lightning version).
    final_loss = float(trainer.progress_bar_dict['loss'])
    assert final_loss > 0
コード例 #3
0
def test_moco(tmpdir):
    """Fit ``MocoV2`` with its datamodule attached and check the logged loss.

    Calls ``reset_seed()`` first, builds a CIFAR-10 datamodule with the MoCo v2
    train/eval transforms, and passes that datamodule to ``MocoV2`` directly
    (so ``trainer.fit`` receives only the model). After a ``fast_dev_run``
    with a ``MocoLRScheduler`` callback, asserts that
    ``trainer.callback_metrics['loss']`` is greater than zero.

    Args:
        tmpdir: scratch directory used both as the CIFAR-10 data directory
            and as the trainer's ``default_root_dir``.
    """
    reset_seed()

    dm = CIFAR10DataModule(tmpdir, num_workers=0, batch_size=2)
    dm.train_transforms = Moco2TrainCIFAR10Transforms()
    dm.val_transforms = Moco2EvalCIFAR10Transforms()

    # In this variant the model owns the datamodule, so fit() below is
    # called without a separate datamodule argument.
    model_kwargs = dict(data_dir=tmpdir, batch_size=2, datamodule=dm, online_ft=True)
    moco = MocoV2(**model_kwargs)

    schedule_cb = MocoLRScheduler()
    trainer = pl.Trainer(
        fast_dev_run=True,
        max_epochs=1,
        default_root_dir=tmpdir,
        callbacks=[schedule_cb],
    )
    trainer.fit(moco)

    assert trainer.callback_metrics['loss'] > 0
コード例 #4
0
            num_workers=args.num_workers)
    else:
        raise ValueError()

    model = MocoV2(**vars(args), emb_spaces=datamodule.num_keys)

    if args.debug:
        logger = False
        checkpoint_callback = False
    else:
        logger = TensorBoardLogger(save_dir=os.path.join(
            os.getcwd(), 'logs', 'pretrain'),
                                   name=get_experiment_name(args))
        checkpoint_callback = ModelCheckpoint(filename='{epoch}')
    scheduler = MocoLRScheduler(initial_lr=args.learning_rate,
                                schedule=args.schedule,
                                max_epochs=args.max_epochs)
    online_evaluator = SSLOnlineEvaluator(
        data_dir=args.online_data_dir,
        z_dim=model.mlp_dim,
        max_epochs=args.online_max_epochs,
        check_val_every_n_epoch=args.online_val_every_n_epoch)

    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[scheduler, online_evaluator],
        max_epochs=args.max_epochs,
        weights_summary='full')
    trainer.fit(model, datamodule=datamodule)