def test_swa_raises():
    """SWA must reject invalid `swa_epoch_start` / `swa_lrs` constructor values."""
    epoch_match = ">0 integer or a float between 0 and 1"
    # invalid epoch starts: zero, a float above 1, and a negative integer
    for bad_start in (0, 1.5, -1):
        with pytest.raises(MisconfigurationException, match=epoch_match):
            StochasticWeightAveraging(swa_epoch_start=bad_start, swa_lrs=0.1)
    # a list of learning rates must contain positive floats only
    with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
        StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            # plain SGD so the lr picked up by the injected SWA callback is predictable
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    model = TestModel()
    swa_cb = StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=swa_cb,
        stochastic_weight_avg=stochastic_weight_avg,
        limit_train_batches=4,
        limit_val_batches=4,
        max_epochs=2,
    )
    trainer.fit(model)

    if use_callbacks or stochastic_weight_avg:
        # exactly one SWA callback, with the lr from the explicit callback or the optimizer
        swa_instances = [cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
        assert len(swa_instances) == 1
        expected_lr = 1e-3 if use_callbacks else 0.1
        assert trainer.callbacks[0]._swa_lrs == expected_lr
    else:
        assert not any(isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
def test_swa_multiple_lrs(tmpdir):
    """SWA must apply one learning rate per optimizer parameter group."""
    swa_lrs = [0.123, 0.321]

    class TestModel(BoringModel):
        def __init__(self):
            # NOTE: deliberately bypasses BoringModel.__init__ so only the two
            # layers below exist -> exactly two optimizer param groups
            super(BoringModel, self).__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer2(self.layer1(x))

        def configure_optimizers(self):
            # two param groups with distinct base learning rates
            groups = [
                {"params": self.layer1.parameters(), "lr": 0.1},
                {"params": self.layer2.parameters(), "lr": 0.2},
            ]
            return torch.optim.Adam(groups)

        def on_train_epoch_start(self):
            optimizer = trainer.optimizers[0]  # closes over the Trainer built below
            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
            self.on_train_epoch_start_called = True

    model = TestModel()
    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=swa_callback, fast_dev_run=1)
    trainer.fit(model)
    assert model.on_train_epoch_start_called
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool):
    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            # plain SGD so the lr picked up by the injected SWA callback is predictable
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    model = TestModel()
    kwargs = {
        "default_root_dir": tmpdir,
        "callbacks": StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
        "stochastic_weight_avg": stochastic_weight_avg,
        "limit_train_batches": 4,
        "limit_val_batches": 4,
        "max_epochs": 2,
    }
    if stochastic_weight_avg:
        # the Trainer flag is deprecated in favour of passing the callback directly
        with pytest.deprecated_call(match=r"stochastic_weight_avg=True\)` is deprecated in v1.5"):
            trainer = Trainer(**kwargs)
    else:
        trainer = Trainer(**kwargs)
    trainer.fit(model)

    if use_callbacks or stochastic_weight_avg:
        assert sum(isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks) == 1
        assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1]
    else:
        assert not any(isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
def test_advanced_profiler_cprofile_deepcopy(tmpdir):
    """Checks for pickle issue reported in #6522."""
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        profiler="advanced",  # the cProfile-backed profiler triggered the deepcopy failure
        callbacks=StochasticWeightAveraging(),
    )
    trainer.fit(BoringModel())
def test_swa_warns(tmpdir, caplog):
    """A step-interval scheduler should warn and be swapped for SWALR."""
    model = SwaTestModel(interval="step")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging())
    with caplog.at_level(level=logging.INFO):
        with pytest.warns(UserWarning, match="SWA is currently only supported"):
            trainer.fit(model)
    assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
def main(args, dm_setup='fit'):
    """Build the datamodule, model and trainer from parsed CLI arguments."""
    pl.seed_everything(args.seed)

    dm = SaltDM.from_argparse_args(args)
    dm.setup(dm_setup)

    model = LitResUnet(**vars(args))
    model.hparams.update(dm.kwargs)

    # checkpoints live under ../params/<model>/f<fold>_<version>/ckpt
    checkpoint_dir = os.path.join(
        f'../params/{args.model}/f{args.val_fold_idx:02d}_{args.version}', 'ckpt')
    checkpoint_file_path = os.path.join(checkpoint_dir, 'last.ckpt')
    if os.path.isfile(checkpoint_file_path):
        # resume automatically when a previous run left a last.ckpt behind
        args.resume_from_checkpoint = checkpoint_file_path
        print('Detect checkpoint:', args.resume_from_checkpoint)

    # assemble callbacks according to the requested configuration
    callbacks = []
    if args.loss_func == 'lovasz_hinge':
        callbacks.append(ToLovaszHingeLossCB())
    if args.logger_type != 'none':
        callbacks.append(LearningRateMonitor())
    if args.swa_epoch_start > 0:
        callbacks.append(StochasticWeightAveraging(swa_epoch_start=args.swa_epoch_start))
    if args.snapshot_size > 0:
        callbacks.append(ResetSnapshotCB(checkpoint_dir))

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=get_logger(args),
        callbacks=callbacks,
        checkpoint_callback=checkpointcb(args, checkpoint_dir))
    return dm, model, trainer
def transfer_weights(self, *args, **kwargs):
    """Spy wrapper: count invocations, then delegate to the real implementation."""
    self.transfer_weights_calls = self.transfer_weights_calls + 1
    return StochasticWeightAveraging.transfer_weights(*args, **kwargs)
def update_parameters(self, *args, **kwargs):
    """Spy wrapper: count invocations, then delegate to the real implementation."""
    self.update_parameters_calls = self.update_parameters_calls + 1
    return StochasticWeightAveraging.update_parameters(*args, **kwargs)
monitor=f'val_loss_fold_{f}', dirpath='model_dir', filename=f"{model_name}_loss_fold_{f}", save_top_k=1, mode='min', ) checkpoint_callback2 = ModelCheckpoint( monitor=f'val_roc_auc_fold_{f}', dirpath='model_dir', filename=f"{model_name}_roc_auc_fold_{f}", save_top_k=1, mode='max', ) lr_monitor = LearningRateMonitor(logging_interval='step') swa_callback = StochasticWeightAveraging() trainer = pl.Trainer( max_epochs=n_epochs, precision=16, # auto_lr_find=True, # Usually the auto is pretty bad. You should instead plot and pick manually. gradient_clip_val=100, num_sanity_val_steps=10, profiler="simple", weights_summary='top', accumulate_grad_batches=accum_step, logger=[wandb_logger], checkpoint_callback=True, gpus=gpu_ids, num_processes=4 * len(gpu_ids), stochastic_weight_avg=True,
def run_lightning(argv=None):
    '''Run training with PyTorch Lightning'''
    global RANK
    from pytorch_lightning.loggers import WandbLogger
    import numpy as np
    import traceback
    import os
    import pprint
    pformat = pprint.PrettyPrinter(sort_dicts=False, width=100, indent=2).pformat

    model, args, addl_targs, data_mod = process_args(parse_args(argv=argv))

    # output is a wrapper function for os.path.join(outdir, <FILE>)
    outdir, output = process_output(args)
    check_directory(outdir)

    if not args.quiet:
        print0(' '.join(sys.argv), file=sys.stderr)
        print0("Processed Args:\n", pformat(vars(args)), file=sys.stderr)

    # save arguments so the run can be reproduced later
    with open(output('args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # --init takes precedence over --checkpoint when both are given
    checkpoint = None
    if args.init is not None:
        checkpoint = args.init
        link_dest = 'init.ckpt'
    elif args.checkpoint is not None:
        checkpoint = args.checkpoint
        link_dest = 'resumed_from.ckpt'

    if checkpoint is not None:
        if RANK == 0:
            # record provenance: symlink the starting checkpoint into outdir.
            # BUG FIX: log `checkpoint` (the value actually used) instead of
            # `args.checkpoint`, which is None when resuming from --init.
            print0(f'symlinking to {checkpoint} from {outdir}')
            dest = output(link_dest)
            src = os.path.relpath(checkpoint, start=outdir)
            if os.path.exists(dest):
                existing_src = os.readlink(dest)
                if existing_src != src:
                    msg = f'Cannot create symlink to checkpoint -- {dest} already exists, but points to {existing_src}'
                    raise RuntimeError(msg)
            else:
                os.symlink(src, dest)

    seed_everything(args.seed)

    if args.csv:
        # BUG FIX: a stray trailing comma here used to turn `logger` into a
        # 1-tuple containing the CSVLogger instead of the logger itself.
        logger = CSVLogger(save_dir=output('logs'))
    else:
        logger = WandbLogger(project="deep-taxon", entity='deep-taxon', name=args.experiment)

    # monitor validation loss for manifold training, otherwise validation accuracy
    monitor, mode = (AbstractLit.val_loss, 'min') if args.manifold else (AbstractLit.val_acc, 'max')

    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        TQDMProgressBar(refresh_rate=50),
    ]
    if not args.disable_checkpoint:
        callbacks.append(
            ModelCheckpoint(dirpath=outdir, save_weights_only=False, save_last=True,
                            save_top_k=3, mode=mode, monitor=monitor))
    if args.early_stop:
        callbacks.append(
            EarlyStopping(monitor=monitor, min_delta=0.001, patience=10, verbose=False, mode=mode))
    if args.swa:
        callbacks.append(
            StochasticWeightAveraging(swa_epoch_start=args.swa_start, annealing_epochs=args.swa_anneal))

    targs = dict(
        enable_checkpointing=True,
        callbacks=callbacks,
        logger=logger,
        num_sanity_val_steps=0,
    )
    targs.update(addl_targs)

    if args.debug:
        # shorten the run for quick iteration
        targs['log_every_n_steps'] = 1
        targs['fast_dev_run'] = 10

    if not args.quiet:
        print0('Trainer args:\n', pformat(targs), file=sys.stderr)
        print0('DataLoader args\n:', pformat(data_mod._loader_kwargs), file=sys.stderr)
        print0('Model:\n', model, file=sys.stderr)

    trainer = Trainer(**targs)

    if args.debug:
        print_dataloader(data_mod.train_dataloader())
        print_dataloader(data_mod.val_dataloader())

    s = datetime.now()
    print0('START_TIME', time())
    trainer.fit(model, data_mod)
    e = datetime.now()
    td = e - s
    hours, seconds = divmod(td.seconds, 3600)
    minutes, seconds = divmod(seconds, 60)
    print0("Took %02d:%02d:%02d.%d" % (hours, minutes, seconds, td.microseconds), file=sys.stderr)
    print0("Total seconds:", td.total_seconds(), file=sys.stderr)