Example #1
def test_swa_raises():
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=1.5, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
        StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
    with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
        StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
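For contrast with the misconfiguration cases above, here is a minimal sketch of arguments that pass validation (the values are illustrative, not taken from the source):

from pytorch_lightning.callbacks import StochasticWeightAveraging

# swa_epoch_start may be an int > 0 (the epoch at which SWA starts) or a float
# strictly between 0 and 1 (a fraction of max_epochs); swa_lrs may be a positive
# float or a list of positive floats, one per optimizer parameter group.
StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=0.05)
StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.05, 0.01])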
Example #2
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
    """Test to ensure the SWA callback is injected when `stochastic_weight_avg` is provided to the Trainer."""
    class TestModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            return optimizer

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=StochasticWeightAveraging(
            swa_lrs=1e-3) if use_callbacks else None,
        stochastic_weight_avg=stochastic_weight_avg,
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=2,
    )
    trainer.fit(model)
    if use_callbacks or stochastic_weight_avg:
        assert len([
            cb for cb in trainer.callbacks
            if isinstance(cb, StochasticWeightAveraging)
        ]) == 1
        assert trainer.callbacks[0]._swa_lrs == (1e-3
                                                 if use_callbacks else 0.1)
    else:
        assert all(not isinstance(cb, StochasticWeightAveraging)
                   for cb in trainer.callbacks)
Example #3
def test_swa_multiple_lrs(tmpdir):
    swa_lrs = [0.123, 0.321]

    class TestModel(BoringModel):
        def __init__(self):
            # call LightningModule.__init__ directly so BoringModel's default
            # `self.layer` is not created; this model defines its own layers below
            super(BoringModel, self).__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)

        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            return x

        def configure_optimizers(self):
            params = [{"params": self.layer1.parameters(), "lr": 0.1}, {"params": self.layer2.parameters(), "lr": 0.2}]
            return torch.optim.Adam(params)

        def on_train_epoch_start(self):
            # `trainer` is resolved via closure over the enclosing test function
            optimizer = trainer.optimizers[0]
            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
            self.on_train_epoch_start_called = True

    model = TestModel()
    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=swa_callback,
        fast_dev_run=1,
    )
    trainer.fit(model)
    assert model.on_train_epoch_start_called
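Outside the test harness, the same per-parameter-group configuration looks like the following sketch (the model, learning rates, and epoch counts are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# one SWA learning rate per optimizer parameter group, in the same order as the
# groups returned by configure_optimizers()
swa = StochasticWeightAveraging(swa_lrs=[0.123, 0.321], swa_epoch_start=2)
trainer = Trainer(callbacks=[swa], max_epochs=5)
trainer.fit(model)  # `model` must expose two parameter groups, like TestModel above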
Example #4
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool,
                                           stochastic_weight_avg: bool):
    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer."""
    class TestModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            return optimizer

    model = TestModel()
    kwargs = {
        "default_root_dir": tmpdir,
        "callbacks":
        StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
        "stochastic_weight_avg": stochastic_weight_avg,
        "limit_train_batches": 4,
        "limit_val_batches": 4,
        "max_epochs": 2,
    }
    if stochastic_weight_avg:
        with pytest.deprecated_call(
                match=r"stochastic_weight_avg=True\)` is deprecated in v1.5"):
            trainer = Trainer(**kwargs)
    else:
        trainer = Trainer(**kwargs)
    trainer.fit(model)
    if use_callbacks or stochastic_weight_avg:
        assert sum(1 for cb in trainer.callbacks
                   if isinstance(cb, StochasticWeightAveraging)) == 1
        assert trainer.callbacks[0]._swa_lrs == [
            1e-3 if use_callbacks else 0.1
        ]
    else:
        assert all(not isinstance(cb, StochasticWeightAveraging)
                   for cb in trainer.callbacks)
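The deprecation this test exercises has a direct callback-based replacement; a short migration sketch (the learning rate is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# deprecated in v1.5 and removed in later releases:
# trainer = Trainer(stochastic_weight_avg=True)

# equivalent configuration via the callback:
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=0.1)])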
Example #5
def test_advanced_profiler_cprofile_deepcopy(tmpdir):
    """Checks for pickle issue reported in #6522."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, profiler="advanced", callbacks=StochasticWeightAveraging()
    )
    trainer.fit(model)
Example #6
def test_swa_warns(tmpdir, caplog):
    model = SwaTestModel(interval="step")
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=True,
                      callbacks=StochasticWeightAveraging())
    with caplog.at_level(level=logging.INFO), pytest.warns(
            UserWarning, match="SWA is currently only supported"):
        trainer.fit(model)
    assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
Example #7
def main(args, dm_setup='fit'):
    pl.seed_everything(args.seed)

    dm = SaltDM.from_argparse_args(args)
    dm.setup(dm_setup)
    model = LitResUnet(**vars(args))
    model.hparams.update(dm.kwargs)

    # print(wandb_logger.log_dir)

    # checkpoint_dir = os.path.join(tt_logger.log_dir, 'ckpt')

    checkpoint_dir = os.path.join(
        '../params/{}/f{:02d}_{}'.format(args.model, args.val_fold_idx,
                                         args.version), 'ckpt')

    checkpoint_file_path = os.path.join(checkpoint_dir, 'last.ckpt')
    if os.path.isfile(checkpoint_file_path):
        args.resume_from_checkpoint = checkpoint_file_path
        print('Detect checkpoint:', args.resume_from_checkpoint)

    # Callbacks
    callbacks = []
    if args.loss_func == 'lovasz_hinge':
        callbacks.append(ToLovaszHingeLossCB())
    if args.logger_type != 'none':
        callbacks.append(LearningRateMonitor())
    if args.swa_epoch_start > 0:
        callbacks.append(
            StochasticWeightAveraging(swa_epoch_start=args.swa_epoch_start))
    if args.snapshot_size > 0:
        callbacks.append(ResetSnapshotCB(checkpoint_dir))

    trainer = pl.Trainer.from_argparse_args(
        args,
        # logger=tt_logger,
        logger=get_logger(args),
        callbacks=callbacks,
        #                                             callbacks=[ResetSnapshotCB(checkpoint_dir)],
        checkpoint_callback=checkpointcb(args, checkpoint_dir))
    # trainer = pl.Trainer.from_argparse_args(args, logger=tt_logger, callbacks=[ResetSnapshotCB(), LearningRateMonitor()], checkpoint_callback=False)
    return dm, model, trainer
Example #8
    def transfer_weights(self, *args, **kwargs):
        self.transfer_weights_calls += 1
        return StochasticWeightAveraging.transfer_weights(*args, **kwargs)
Example #9
    def update_parameters(self, *args, **kwargs):
        self.update_parameters_calls += 1
        return StochasticWeightAveraging.update_parameters(*args, **kwargs)
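The two one-line overrides in Examples #8 and #9 read like fragments of a call-counting test double; here is a sketch of how such a subclass might be assembled and used (the class name, counters, and trainer settings are assumptions, not from the source):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

class SwaCallCounter(StochasticWeightAveraging):
    # hypothetical wrapper that records how often the SWA hooks fire
    transfer_weights_calls = 0
    update_parameters_calls = 0

    def transfer_weights(self, *args, **kwargs):
        self.transfer_weights_calls += 1
        return StochasticWeightAveraging.transfer_weights(*args, **kwargs)

    def update_parameters(self, *args, **kwargs):
        self.update_parameters_calls += 1
        return StochasticWeightAveraging.update_parameters(*args, **kwargs)

swa = SwaCallCounter(swa_lrs=1e-2)
trainer = Trainer(callbacks=[swa], max_epochs=3, limit_train_batches=4, limit_val_batches=4)
trainer.fit(BoringModel())  # BoringModel: the same test helper used throughout this listing
assert swa.update_parameters_calls > 0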
Example #10
    checkpoint_callback1 = ModelCheckpoint(  # variable name assumed
        monitor=f'val_loss_fold_{f}',
        dirpath='model_dir',
        filename=f"{model_name}_loss_fold_{f}",
        save_top_k=1,
        mode='min',
    )

    checkpoint_callback2 = ModelCheckpoint(
        monitor=f'val_roc_auc_fold_{f}',
        dirpath='model_dir',
        filename=f"{model_name}_roc_auc_fold_{f}",
        save_top_k=1,
        mode='max',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    swa_callback = StochasticWeightAveraging()

    trainer = pl.Trainer(
        max_epochs=n_epochs,
        precision=16,
        # auto_lr_find=True,  # Usually the auto is pretty bad. You should instead plot and pick manually.
        gradient_clip_val=100,
        num_sanity_val_steps=10,
        profiler="simple",
        weights_summary='top',
        accumulate_grad_batches=accum_step,
        logger=[wandb_logger],
        checkpoint_callback=True,
        gpus=gpu_ids,
        num_processes=4 * len(gpu_ids),
        stochastic_weight_avg=True,
Example #11
def run_lightning(argv=None):
    '''Run training with PyTorch Lightning'''
    global RANK
    from pytorch_lightning.loggers import WandbLogger
    import numpy as np
    import traceback
    import os
    import pprint

    pformat = pprint.PrettyPrinter(sort_dicts=False, width=100,
                                   indent=2).pformat

    model, args, addl_targs, data_mod = process_args(parse_args(argv=argv))

    # if 'OMPI_COMM_WORLD_RANK' in os.environ or 'SLURMD_NODENAME' in os.environ:
    #     from mpi4py import MPI
    #     comm = MPI.COMM_WORLD
    #     RANK = comm.Get_rank()
    # else:
    #     RANK = 0
    #     print('OMPI_COMM_WORLD_RANK or SLURMD_NODENAME not set in environment -- not using MPI')

    # output is a wrapper function for os.path.join(outdir, <FILE>)
    outdir, output = process_output(args)
    check_directory(outdir)
    if not args.quiet:
        print0(' '.join(sys.argv), file=sys.stderr)
        print0("Processed Args:\n", pformat(vars(args)), file=sys.stderr)

    # save arguments
    with open(output('args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    checkpoint = None
    if args.init is not None:
        checkpoint = args.init
        link_dest = 'init.ckpt'
    elif args.checkpoint is not None:
        checkpoint = args.checkpoint
        link_dest = 'resumed_from.ckpt'

    if checkpoint is not None:
        if RANK == 0:
            print0(f'symlinking to {checkpoint} from {outdir}')
            dest = output(link_dest)
            src = os.path.relpath(checkpoint, start=outdir)
            if os.path.exists(dest):
                existing_src = os.readlink(dest)
                if existing_src != src:
                    msg = f'Cannot create symlink to checkpoint -- {dest} already exists, but points to {existing_src}'
                    raise RuntimeError(msg)
            else:
                os.symlink(src, dest)

    seed_everything(args.seed)

    if args.csv:
        logger = CSVLogger(save_dir=output('logs'))
    else:
        logger = WandbLogger(project="deep-taxon",
                             entity='deep-taxon',
                             name=args.experiment)

    # get dataset so we can set model parameters that are
    # dependent on the dataset, such as final number of outputs

    monitor, mode = (AbstractLit.val_loss,
                     'min') if args.manifold else (AbstractLit.val_acc, 'max')
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        TQDMProgressBar(refresh_rate=50)
    ]
    if not args.disable_checkpoint:
        callbacks.append(
            ModelCheckpoint(dirpath=outdir,
                            save_weights_only=False,
                            save_last=True,
                            save_top_k=3,
                            mode=mode,
                            monitor=monitor))

    if args.early_stop:
        callbacks.append(
            EarlyStopping(monitor=monitor,
                          min_delta=0.001,
                          patience=10,
                          verbose=False,
                          mode=mode))

    if args.swa:
        callbacks.append(
            StochasticWeightAveraging(swa_epoch_start=args.swa_start,
                                      annealing_epochs=args.swa_anneal))

    targs = dict(
        enable_checkpointing=True,
        callbacks=callbacks,
        logger=logger,
        num_sanity_val_steps=0,
    )
    targs.update(addl_targs)

    if args.debug:
        targs['log_every_n_steps'] = 1
        targs['fast_dev_run'] = 10

    if not args.quiet:
        print0('Trainer args:\n', pformat(targs), file=sys.stderr)
        print0('DataLoader args:\n',
               pformat(data_mod._loader_kwargs),
               file=sys.stderr)
        print0('Model:\n', model, file=sys.stderr)

    trainer = Trainer(**targs)

    if args.debug:
        #print_dataloader(data_mod.test_dataloader())
        print_dataloader(data_mod.train_dataloader())
        print_dataloader(data_mod.val_dataloader())

    s = datetime.now()
    print0('START_TIME', time())
    trainer.fit(model, data_mod)
    e = datetime.now()
    td = e - s
    hours, seconds = divmod(td.seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    print0("Took %02d:%02d:%02d.%d" %
           (hours, minutes, seconds, td.microseconds),
           file=sys.stderr)
    print0("Total seconds:", td.total_seconds(), file=sys.stderr)