Example #1
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val,
                                         tmpdir):
    """Ensure that clip gradients is only called if the value is greater than 0.

    TODO: Fix (test fails with parametrize)
    """
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        accelerator="tpu",
        devices=1,
        precision=16,
        limit_train_batches=4,
        limit_val_batches=4,
        gradient_clip_val=clip_val,
    )
    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False)

    if clip_val > 0:
        mock_clip_grad_norm.assert_called()
    else:
        mock_clip_grad_norm.assert_not_called()
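
The `mock_clip_grad_norm` and `clip_val` arguments come from decorators that this listing omits. A plausible stack, assuming the test patches `torch.nn.utils.clip_grad_norm_` and parametrizes `clip_val` (both the patch target and the values are assumptions, not taken from the original):

import pytest
from unittest import mock

@pytest.mark.parametrize("clip_val", [10, 0])
@mock.patch("torch.nn.utils.clip_grad_norm_")  # innermost: injected as the first argument
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
    ...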
Example #2
def test_early_stopping_cpu_model(tmpdir):
    class ModelTrainVal(BoringModel):
        def validation_step(self, *args, **kwargs):
            output = super().validation_step(*args, **kwargs)
            self.log("val_loss", output["x"])
            return output

    tutils.reset_seed()
    stopping = EarlyStopping(monitor="val_loss", min_delta=0.1)
    trainer_options = dict(
        callbacks=[stopping],
        default_root_dir=tmpdir,
        gradient_clip_val=1.0,
        track_grad_norm=2,
        enable_progress_bar=False,
        accumulate_grad_batches=2,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
    )

    model = ModelTrainVal()
    tpipes.run_model_test(trainer_options, model)

    # test freeze on cpu
    model.freeze()
    model.unfreeze()
Example #3
def test_trainer_reset_correctly(tmpdir):
    """Check that all trainer parameters are reset correctly after scaling batch size."""
    tutils.reset_seed()

    model = BatchSizeModel(batch_size=2)

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    changed_attributes = [
        "callbacks",
        "checkpoint_callback",
        "current_epoch",
        "limit_train_batches",
        "logger",
        "max_steps",
        "global_step",
    ]
    expected = {ca: getattr(trainer, ca) for ca in changed_attributes}

    with no_warning_call(UserWarning,
                         match="Please add the following callbacks"):
        trainer.tuner.scale_batch_size(model, max_trials=5)

    actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
    assert actual == expected
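
`BatchSizeModel` is not defined in this listing. A minimal sketch consistent with how it is used here (a `BoringModel` that exposes a `batch_size` attribute for the tuner to scale) might look like the following; the exact implementation is an assumption:

from torch.utils.data import DataLoader

class BatchSizeModel(BoringModel):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size  # the tuner reads and rewrites this attribute

    def train_dataloader(self):
        # the dataloader is rebuilt with the current (possibly scaled) batch size
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)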
Example #4
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
    # rtol=5e-3 accounts for the 3-decimal rounding applied in `.grad_norms` and in the metrics logged above
    reset_seed()

    class TestModel(ModelWithManualGradTracker):
        logged_metrics = []

        def on_train_batch_end(self, *_) -> None:
            # copy so they don't get reduced
            self.logged_metrics.append(self.trainer.logged_metrics.copy())

    model = TestModel(norm_type)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        track_grad_norm=norm_type,
        log_every_n_steps=1,  # request grad_norms every batch
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    assert len(model.logged_metrics) == len(model.stored_grad_norms)
    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(model.stored_grad_norms, model.logged_metrics):
        for k in mod.keys() & log.keys():
            assert np.allclose(mod[k], log[k], rtol=rtol), k
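
`ModelWithManualGradTracker` is also not shown. Judging from the assertions above, it recomputes the gradient p-norms by hand on each backward pass and stores them for comparison with the logged metrics. A rough sketch (the key names and 3-decimal rounding are assumptions inferred from the `rtol` comment):

class ModelWithManualGradTracker(BoringModel):
    def __init__(self, norm_type):
        super().__init__()
        self.norm_type = norm_type
        self.stored_grad_norms = []

    def on_after_backward(self):
        # manually compute per-parameter and total gradient norms,
        # rounded to 3 decimals to mirror Lightning's `grad_norm` utility
        norms, all_norms = {}, []
        for name, p in self.named_parameters():
            if p.grad is None:
                continue
            norm = float(p.grad.data.norm(self.norm_type))
            norms[f"grad_{self.norm_type}_norm/{name}"] = round(norm, 3)
            all_norms.append(norm)
        total = float(torch.tensor(all_norms).norm(self.norm_type))
        norms[f"grad_{self.norm_type}_norm_total"] = round(total, 3)
        self.stored_grad_norms.append(norms)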
Example #5
def test_dataloaders_passed_to_fit(tmpdir):
    """Test if dataloaders passed to trainer works on TPU."""
    tutils.reset_seed()
    model = BoringModel()

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      accelerator="tpu",
                      devices=8)
    trainer.fit(model,
                train_dataloaders=model.train_dataloader(),
                val_dataloaders=model.val_dataloader())
    assert trainer.state.finished, f"Training failed with {trainer.state}"
Example #6
def test_dm_checkpoint_save_and_load(tmpdir):
    class CustomBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            self.log("early_stop_on", out["x"])
            return out

    class CustomBoringDataModule(BoringDataModule):
        def state_dict(self) -> Dict[str, Any]:
            return {"my": "state_dict"}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.my_state_dict = state_dict

        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
            checkpoint[self.__class__.__qualname__].pop("on_save")

    reset_seed()
    dm = CustomBoringDataModule()
    model = CustomBoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        enable_model_summary=False,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on")],
    )

    # fit model
    with pytest.deprecated_call(
        match="`LightningDataModule.on_save_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
    ):
        trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
    checkpoint = torch.load(checkpoint_path)
    assert dm.__class__.__qualname__ in checkpoint
    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}

    for trainer_fn in TrainerFn:
        trainer.state.fn = trainer_fn
        trainer._restore_modules_and_callbacks(checkpoint_path)
        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
        assert dm.my_state_dict == {"my": "state_dict"}
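
For reference, the datamodule state is nested in the checkpoint under the class qualname. Schematically, the relevant slice of the loaded dict looks like this (a sketch; the surrounding keys are the usual Trainer checkpoint keys):

checkpoint = torch.load(checkpoint_path)
# {
#     "state_dict": {...},                 # model weights
#     ...
#     dm.__class__.__qualname__: {         # the datamodule's own section
#         "my": "state_dict",              # returned by DataModule.state_dict()
#         "on_save": "update",             # added by on_save_checkpoint
#     },
# }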
Example #7
def test_result_reduce_horovod(tmpdir):
    """Make sure result logging works with Horovod.

    This test mirrors tests/core/test_results.py::_ddp_test_fn
    """
    tutils.reset_seed()
    tutils.set_random_main_port()

    def hvd_test_fn():
        path_here = os.path.abspath(os.path.dirname(__file__))
        path_root = os.path.abspath(os.path.join(path_here, "..", ".."))
        sys.path.insert(0, os.path.abspath(path_root))

        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                self.training_step_called = True

                tensor = torch.tensor([1.0])
                self.log("test_tensor",
                         tensor,
                         sync_dist=True,
                         reduce_fx="sum",
                         on_step=True,
                         on_epoch=True)

                res = self._results

                # Check that `tensor` is summed across all ranks automatically
                assert (
                    res["test_tensor"].item() == hvd.size()
                ), "Result-Log does not work properly with Horovod and Tensors"

            def training_epoch_end(self, outputs) -> None:
                assert len(outputs) == 0

        model = TestModel()
        model.val_dataloader = None

        trainer = Trainer(
            default_root_dir=tmpdir,
            limit_train_batches=2,
            limit_val_batches=2,
            max_epochs=1,
            log_every_n_steps=1,
            enable_model_summary=False,
            logger=False,
        )

        trainer.fit(model)

    horovod.run(hvd_test_fn, np=2)
Example #8
def test_lr_monitor_no_logger(tmpdir):
    tutils.reset_seed()

    model = BoringModel()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      callbacks=[lr_monitor],
                      logger=False)

    with pytest.raises(MisconfigurationException,
                       match="`Trainer` that has no logger"):
        trainer.fit(model)
Example #9
def test_model_tpu_devices_1(tmpdir):
    """Make sure model trains on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        accelerator="tpu",
        devices=1,
        limit_train_batches=4,
        limit_val_batches=4,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False)
Example #10
def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
    """Test possible values for 'batch size auto scaling' Trainer argument."""
    tutils.reset_seed()
    before_batch_size = 2
    model = BatchSizeModel(batch_size=before_batch_size)
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      auto_scale_batch_size=scale_arg,
                      accelerator="gpu",
                      devices=1)
    trainer.tune(model)
    after_batch_size = model.batch_size
    assert before_batch_size != after_batch_size, "Batch size was not altered after running auto scaling of batch size"

    assert not os.path.exists(tmpdir / "scale_batch_size_temp_model.ckpt")
Example #11
def test_train_val_loop_only(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    model.validation_step = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False)

    # fit model
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert trainer.callback_metrics["train_loss"] < 1.0
Example #12
def test_model_tpu_devices_8(tmpdir):
    """Make sure model trains on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        accelerator="tpu",
        devices=8,
        limit_train_batches=4,
        limit_val_batches=4,
    )

    # 8 cores needs a big dataset
    model = SerialLoaderBoringModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False, min_acc=0.05)
Example #13
def test_tpu_grad_norm(tmpdir):
    """Test if grad_norm works on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=4,
        accelerator="tpu",
        devices=1,
        limit_train_batches=0.4,
        limit_val_batches=0.4,
        gradient_clip_val=0.5,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False)
Example #14
def test_model_tpu_index(tmpdir, tpu_core):
    """Make sure model trains on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        accelerator="tpu",
        devices=[tpu_core],
        limit_train_batches=4,
        limit_val_batches=4,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False)
    assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
Example #15
def test_lr_monitor_duplicate_custom_pg_names(tmpdir):
    tutils.reset_seed()

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.linear_a = torch.nn.Linear(32, 16)
            self.linear_b = torch.nn.Linear(16, 2)

        def forward(self, x):
            x = self.linear_a(x)
            x = self.linear_b(x)
            return x

        def configure_optimizers(self):
            param_groups = [
                {
                    "params": list(self.linear_a.parameters()),
                    "name": "linear"
                },
                {
                    "params": list(self.linear_b.parameters()),
                    "name": "linear"
                },
            ]
            optimizer = torch.optim.SGD(param_groups, lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1)
            return [optimizer], [lr_scheduler]

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=2,
        limit_train_batches=2,
        callbacks=[lr_monitor],
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    with pytest.raises(
            MisconfigurationException,
            match="A single `Optimizer` cannot have multiple parameter groups with identical",
    ):
        trainer.fit(TestModel())
Example #16
def test_logger_reset_correctly(tmpdir, extra_params):
    """Test that the tuners do not alter the logger reference."""
    class CustomModel(BoringModel):
        def __init__(self, lr=0.1, batch_size=1):
            super().__init__()
            self.save_hyperparameters()

    tutils.reset_seed()
    model = CustomModel()
    trainer = Trainer(default_root_dir=tmpdir, **extra_params)
    logger1 = trainer.logger
    trainer.tune(model)
    logger2 = trainer.logger
    logger3 = model.logger

    assert logger1 == logger2, "Finder altered the logger of trainer"
    assert logger2 == logger3, "Finder altered the logger of model"
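
The `extra_params` argument is supplied by a parametrize decorator not shown here; it plausibly toggles the two tuners, e.g. (the concrete dicts are an assumption):

@pytest.mark.parametrize(
    "extra_params",
    [dict(max_epochs=1, auto_scale_batch_size=True), dict(max_epochs=3, auto_lr_find=True)],
)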
Example #17
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model is actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.03, f"the model did not change significantly (change: {change_ratio})"

    # test model loading
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer._checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        checkpoint_path = trainer._checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer._checkpoint_connector.restore(checkpoint_path)
Example #18
def test_call_to_trainer_method(tmpdir, scale_method):
    """Test that calling the trainer method itself works."""
    tutils.reset_seed()

    before_batch_size = 2
    model = BatchSizeModel(batch_size=before_batch_size)

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    after_batch_size = trainer.tuner.scale_batch_size(model,
                                                      mode=scale_method,
                                                      max_trials=5)
    model.batch_size = after_batch_size
    trainer.fit(model)

    assert before_batch_size != after_batch_size, "Batch size was not altered after running auto scaling of batch size"
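
`scale_method` is likewise parametrized outside this listing; since the tuner's `mode` argument accepts "power" and "binsearch", the decorator is presumably:

@pytest.mark.parametrize("scale_method", ["power", "binsearch"])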
Example #19
def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str):
    """Test that learning rates are extracted and logged for multi lr schedulers."""
    tutils.reset_seed()

    class CustomBoringModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer1 = optim.Adam(self.parameters(), lr=1e-2)
            optimizer2 = optim.Adam(self.parameters(), lr=1e-2)
            lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
            lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
            return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

    model = CustomBoringModel()
    model.training_epoch_end = None

    lr_monitor = LearningRateMonitor(logging_interval=logging_interval)
    log_every_n_steps = 2

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        log_every_n_steps=log_every_n_steps,
        limit_train_batches=7,
        limit_val_batches=0.1,
        callbacks=[lr_monitor],
    )
    trainer.fit(model)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == [
        "lr-Adam", "lr-Adam-1"
    ], "Names of learning rates not set correctly"

    if logging_interval == "step":
        # divide by 2 because we have 2 optimizers
        expected_number_logged = trainer.global_step // 2 // log_every_n_steps
    if logging_interval == "epoch":
        expected_number_logged = trainer.max_epochs

    assert all(
        len(lr) == expected_number_logged for lr in lr_monitor.lrs.values())
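
As a sanity check of the `step` branch above, assuming `global_step` increments once per optimizer step (which the "divide by 2" comment implies): 2 epochs * 7 train batches * 2 optimizers = 28 steps, 28 // 2 = 14 steps per optimizer, and 14 // 2 = 7 logged values per learning-rate series.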
Example #20
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
    """Test that new batch size gets written to the correct hyperparameter attribute."""
    tutils.reset_seed()

    hparams = {"batch_size": 2}
    before_batch_size = hparams.get("batch_size")

    class HparamsBatchSizeModel(BatchSizeModel):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.save_hyperparameters()

        def dataloader(self, *args, **kwargs):
            # artificially set batch_size so we can get a dataloader
            # remove it immediately after, because we want only self.hparams.batch_size
            setattr(self, "batch_size", before_batch_size)
            dataloader = super().dataloader(*args, **kwargs)
            del self.batch_size
            return dataloader

    class HparamsBatchSizeDataModule(BoringDataModule):
        def __init__(self, data_dir, batch_size):
            super().__init__(data_dir)
            self.batch_size = batch_size

        def train_dataloader(self):
            return DataLoader(self.random_train, batch_size=self.batch_size)

    datamodule_fit = HparamsBatchSizeDataModule(data_dir=tmpdir,
                                                batch_size=before_batch_size)
    model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
    model = model_class(**hparams)

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      auto_scale_batch_size=True,
                      accelerator="gpu",
                      devices=1)
    trainer.tune(model, datamodule_fit)
    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
    assert trainer.datamodule == datamodule_fit
    assert before_batch_size != after_batch_size
    assert after_batch_size <= len(trainer.train_dataloader.dataset)
    assert datamodule_fit.batch_size == after_batch_size
Example #21
def test_load_model_from_checkpoint(tmpdir, model_template):
    """Verify test() on pretrained model."""
    tutils.reset_seed()
    model = model_template()

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_top_k=-1)],
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test(model)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # load last checkpoint
    last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]

    # Since `BoringModel` has `_save_hparams = True` by default, check that ckpt has hparams
    ckpt = torch.load(last_checkpoint)
    assert model_template.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), "hyper_parameters missing from checkpoints"

    # Ensure that model can be correctly restored from checkpoint
    pretrained_model = model_template.load_from_checkpoint(last_checkpoint)

    # test that hparams loaded correctly
    for k, v in model.hparams.items():
        assert getattr(pretrained_model.hparams, k) == v

    # assert weights are the same
    for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()):
        assert torch.all(torch.eq(old_p, new_p)), "loaded weights are not the same as the saved weights"

    # Check `test` on pretrained model:
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)
Example #22
def test_lr_monitor_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    model = BoringModel()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])
    trainer.fit(model)

    assert lr_monitor.lrs, "No learning rates logged"
    assert all(v is None for v in lr_monitor.last_momentum_values.values()
               ), "Momentum should not be logged by default"
    assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == ["lr-SGD"]
Example #23
def test_amp_cpus(tmpdir, strategy, precision, devices):
    """Make sure combinations of AMP and strategies work if supported."""
    tutils.reset_seed()

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="cpu",
        devices=devices,
        max_epochs=1,
        strategy=strategy,
        precision=precision,
    )

    model = AMPTestModel()
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model, DataLoader(RandomDataset(32, 64)))

    assert trainer.state.finished, f"Training failed with {trainer.state}"
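
`strategy`, `precision`, and `devices` arrive via stacked parametrize decorators omitted from this listing; a plausible stack (the concrete values are assumptions):

@pytest.mark.parametrize("strategy", [None, "ddp_spawn"])
@pytest.mark.parametrize("precision", [16, "bf16"])
@pytest.mark.parametrize("devices", [1, 2])
def test_amp_cpus(tmpdir, strategy, precision, devices):
    ...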
Example #24
def run_model_test_without_loggers(
    trainer_options: dict, model: LightningModule, data: LightningDataModule = None, min_acc: float = 0.50
):
    reset_seed()

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=data)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model2, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model2, dataloader, min_acc=min_acc)
Example #25
def test_tpu_host_world_size(tmpdir):
    """Test Host World size env setup on TPU."""
    class DebugModel(BoringModel):
        def on_train_start(self):
            assert os.environ.get("XRT_HOST_WORLD_SIZE") == str(1)

        def teardown(self, stage):
            assert "XRT_HOST_WORLD_SIZE" not in os.environ

    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=4,
        accelerator="tpu",
        devices=8,
        limit_train_batches=0.4,
        limit_val_batches=0.4,
    )

    model = DebugModel()
    tpipes.run_model_test(trainer_options, model, with_hpc=False)
Example #26
def test_full_loop(tmpdir):
    reset_seed()

    dm = ClassifDataModule()
    model = ClassificationModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False, deterministic=True)

    # fit model
    trainer.fit(model, dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # validate
    result = trainer.validate(model, dm)
    assert dm.trainer is not None
    assert result[0]["val_acc"] > 0.7

    # test
    result = trainer.test(model, dm)
    assert dm.trainer is not None
    assert result[0]["test_acc"] > 0.6
Example #27
def test_model_reset_correctly(tmpdir):
    """Check that model weights are correctly reset after scaling batch size."""
    tutils.reset_seed()

    model = BatchSizeModel(batch_size=2)

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    before_state_dict = deepcopy(model.state_dict())

    trainer.tuner.scale_batch_size(model, max_trials=5)

    after_state_dict = model.state_dict()

    for key in before_state_dict.keys():
        assert torch.all(
            torch.eq(before_state_dict[key], after_state_dict[key])
        ), "Model was not reset correctly after scaling batch size"

    assert not any(
        f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size"))
Example #28
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works."""
    class CustomBoringModel(BoringModel):
        def validation_step(self, *args, **kwargs):
            out = super().validation_step(*args, **kwargs)
            self.log("val_loss", out["x"])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor="val_loss")],
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        accelerator="tpu",
        devices=8,
    )
    trainer.fit(model)
    trainer.test(
        dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
Example #29
def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for no lr scheduler."""
    tutils.reset_seed()

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=0.1)
            return optimizer

    model = CustomBoringModel()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])

    trainer.fit(model)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == len(trainer.optimizers)
    assert list(lr_monitor.lrs) == ["lr-SGD"]
Example #30
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):
        def configure_optimizers(self):
            param_groups = [
                {
                    "params": list(self.parameters())[:2],
                    "lr": self.lr * 0.1
                },
                {
                    "params": list(self.parameters())[2:],
                    "lr": self.lr
                },
            ]

            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])
    trainer.fit(model, datamodule=dm)

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == [
        "lr-Adam/pg1", "lr-Adam/pg2"
    ], "Names of learning rates not set correctly"