Example 1
def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` runs a learning-rate range test, reducing the amount of guesswork
    in picking a good initial learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch
            ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.
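
        val_dataloaders: A PyTorch ``DataLoader`` or a list of them with
            validation samples. Passed through unchanged to the internal
            ``trainer.fit`` call.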

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.

    Raises:
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, or
            if you are using more than one optimizer with the learning rate finder.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)

    """
    if trainer.fast_dev_run:
        rank_zero_warn(
            'Skipping learning rate finder since fast_dev_run is enabled.',
            UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir,
                             'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [
        _LRCallback(num_training,
                    early_stop_threshold,
                    progress_bar_refresh_rate=1)
    ]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(
        model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule)

    # Log if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({
        'lr': trainer.callbacks[0].lrs,
        'loss': trainer.callbacks[0].losses
    })
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purposes

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(
            str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
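
The ``mode`` and ``early_stop_threshold`` arguments above are only described in prose. As a point of reference, here is a minimal sketch, independent of the ``_LRFinder`` internals, of the learning-rate sequence each mode sweeps through and of the divergence check; the helper names ``candidate_lrs`` and ``should_stop`` are illustrative, not part of the library.

def candidate_lrs(min_lr: float, max_lr: float, num_training: int, mode: str) -> list:
    """Illustrative only: the sequence of learning rates a range test visits."""
    lrs = []
    for step in range(1, num_training + 1):
        r = step / num_training
        if mode == "exponential":
            # equally spaced in log-space: the lr is multiplied by a constant factor each step
            lrs.append(min_lr * (max_lr / min_lr) ** r)
        elif mode == "linear":
            # equally spaced in lr-space: a constant increment is added each step
            lrs.append(min_lr + (max_lr - min_lr) * r)
        else:
            raise ValueError(f"unknown mode: {mode}")
    return lrs


def should_stop(current_loss, best_loss, early_stop_threshold):
    """Illustrative only: abort the sweep once the loss diverges past the threshold."""
    return early_stop_threshold is not None and current_loss > early_stop_threshold * best_loss
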
Example 2
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
        return None

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    # Save the initial model; it is restored after the learning rate search finishes
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    params = __lr_finder_dump_params(trainer)

    # Set to values that are required by the algorithm
    __lr_finder_reset_params(trainer, num_training, early_stop_threshold)

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Configure optimizer and scheduler
    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore[assignment]

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Log if we stopped early
    if trainer.global_step != num_training:
        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

    # Transfer results from callback to lr finder object
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purposes

    __lr_finder_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f"Learning rate set to {lr}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)

    return lr_finder
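
Since this variant only points at the Tuner documentation, here is a minimal, self-contained usage sketch. ``TinyRegressor`` is a placeholder LightningModule invented for the example; only ``Trainer``, ``trainer.tuner.lr_find`` and ``lr_finder.suggestion()`` come from the library.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyRegressor(pl.LightningModule):
    """Placeholder module: exposes ``self.lr`` so the finder has an attribute
    to read (and to write when ``update_attr=True``)."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        x, y = torch.randn(256, 16), torch.randn(256, 1)
        return DataLoader(TensorDataset(x, y), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    model = TinyRegressor()
    trainer = pl.Trainer(max_epochs=1)

    # Run the range test, then apply the suggestion manually
    lr_finder = trainer.tuner.lr_find(model)
    model.lr = lr_finder.suggestion()
    trainer.fit(model)

Passing ``update_attr=True`` instead would write the suggestion back to ``model.lr`` via ``lightning_setattr``, which is exactly what the last branch of the function above does.
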
Example 3
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn(
            "Skipping learning rate finder since fast_dev_run is enabled.",
            UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir,
                             "lr_find_temp_model.ckpt")

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [
        _LRCallback(num_training,
                    early_stop_threshold,
                    progress_bar_refresh_rate=1)
    ]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.fit_loop.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.lr_schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(
        model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Log if we stopped early
    if trainer.global_step != num_training:
        log.info(
            f"LR finder stopped early after {trainer.global_step} steps due to diverging loss."
        )

    # Transfer results from callback to lr finder object
    lr_finder.results.update({
        "lr": trainer.callbacks[0].lrs,
        "loss": trainer.callbacks[0].losses
    })
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purposes

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path))
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f"Learning rate set to {lr}")

    return lr_finder
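
Both of the older variants set ``trainer.auto_lr_find = False`` to keep the finder from re-triggering itself, which points at the automatic entry path. A rough sketch of that flow, reusing the placeholder ``TinyRegressor`` from the previous example:

import pytorch_lightning as pl

model = TinyRegressor()  # placeholder module defined in the sketch above

# With auto_lr_find=True, trainer.tune() runs the finder and writes the suggestion
# to the module's ``lr``/``learning_rate`` attribute; if neither attribute exists,
# a MisconfigurationException is raised (see the Raises section of Example 1).
trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
trainer.tune(model)

print(model.lr)  # the suggested learning rate
trainer.fit(model)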