def _run_lr_finder_internally(trainer, model: LightningModule):
    """ Call lr finder internally during Trainer.fit() """
    lr_finder = lr_find(trainer, model)

    if lr_finder is None:
        return

    lr = lr_finder.suggestion()

    # TODO: log lr.results to self.logger
    if isinstance(trainer.auto_lr_find, str):
        # Try to find requested field, may be nested
        if lightning_hasattr(model, trainer.auto_lr_find):
            lightning_setattr(model, trainer.auto_lr_find, lr)
        else:
            raise MisconfigurationException(
                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
                ' could not find this as a field in `model` or `model.hparams`.'
            )
    else:
        if lightning_hasattr(model, 'lr'):
            lightning_setattr(model, 'lr', lr)
        elif lightning_hasattr(model, 'learning_rate'):
            lightning_setattr(model, 'learning_rate', lr)
        else:
            raise MisconfigurationException(
                'When `auto_lr_find=True`, `model` or `model.hparams` is'
                ' expected to have a field `lr` or `learning_rate` that can'
                ' be overridden.')
    log.info(f'Learning rate set to {lr}')
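
For context, a minimal sketch of the kind of module this helper expects. The attribute name `learning_rate` is one of the two defaults the code above looks for; the module body itself is illustrative, not taken from the source:

import torch
from torch import nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # A plain attribute, so `lightning_hasattr(model, 'learning_rate')` finds it
        # and the LR finder can overwrite it via `lightning_setattr`.
        self.learning_rate = learning_rate
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # Reads the (possibly updated) attribute when optimizers are built.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)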
Example #2
def test_lightning_hasattr(tmpdir):
    """ Test that the lightning_hasattr works in all cases"""
    model1, model2, model3, model4 = _get_test_cases()
    assert lightning_hasattr(model1, 'learning_rate'), \
        'lightning_hasattr failed to find namespace variable'
    assert lightning_hasattr(model2, 'learning_rate'), \
        'lightning_hasattr failed to find hparams namespace variable'
    assert lightning_hasattr(model3, 'learning_rate'), \
        'lightning_hasattr failed to find hparams dict variable'
    assert not lightning_hasattr(model4, 'learning_rate'), \
        'lightning_hasattr found variable when it should not'
Example #3
def test_lightning_hasattr(tmpdir, model_cases):
    """Test that the lightning_hasattr works in all cases."""
    model1, model2, model3, model4, model5, model6, model7 = models = model_cases
    assert lightning_hasattr(
        model1,
        "learning_rate"), "lightning_hasattr failed to find namespace variable"
    assert lightning_hasattr(
        model2, "learning_rate"
    ), "lightning_hasattr failed to find hparams namespace variable"
    assert lightning_hasattr(
        model3, "learning_rate"
    ), "lightning_hasattr failed to find hparams dict variable"
    assert not lightning_hasattr(
        model4,
        "learning_rate"), "lightning_hasattr found variable when it should not"
    assert lightning_hasattr(
        model5, "batch_size"
    ), "lightning_hasattr failed to find batch_size in datamodule"
    assert lightning_hasattr(
        model6, "batch_size"
    ), "lightning_hasattr failed to find batch_size in datamodule w/ hparams present"
    assert lightning_hasattr(
        model7, "batch_size"
    ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"

    for m in models:
        assert not lightning_hasattr(m, "this_attr_not_exist")
Example #4
def _determine_lr_attr_name(trainer, model: LightningModule) -> str:
    if isinstance(trainer.auto_lr_find, str):
        if not lightning_hasattr(model, trainer.auto_lr_find):
            raise MisconfigurationException(
                f'`auto_lr_find` was set to {trainer.auto_lr_find}, however'
                ' could not find this as a field in `model` or `model.hparams`.'
            )
        return trainer.auto_lr_find

    attr_options = ('lr', 'learning_rate')
    for attr in attr_options:
        if lightning_hasattr(model, attr):
            return attr

    raise MisconfigurationException(
        'When `auto_lr_find=True`, either `model` or `model.hparams` should'
        f' have one of these fields: {attr_options} overridden.')
Example #5
def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str:
    if isinstance(trainer.auto_lr_find, str):
        if not lightning_hasattr(model, trainer.auto_lr_find):
            raise MisconfigurationException(
                f"`auto_lr_find` was set to {trainer.auto_lr_find}, however"
                " could not find this as a field in `model` or `model.hparams`."
            )
        return trainer.auto_lr_find

    attr_options = ("lr", "learning_rate")
    for attr in attr_options:
        if lightning_hasattr(model, attr):
            return attr

    raise MisconfigurationException(
        "When `auto_lr_find=True`, either `model` or `model.hparams` should"
        f" have one of these fields: {attr_options} overridden."
    )
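
Putting the pieces together, the resolved attribute name is typically consumed as in the first example; a short sketch, assuming `lr_find`, `lightning_setattr` and the trainer/model pair from above:

attr_name = _determine_lr_attr_name(trainer, model)
lr_finder = lr_find(trainer, model)
if lr_finder is not None:
    lightning_setattr(model, attr_name, lr_finder.suggestion())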
Example #6
def test_lightning_hasattr(tmpdir):
    """Test that lightning_hasattr works in all cases."""
    model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
    assert lightning_hasattr(model1, 'learning_rate'), \
        'lightning_hasattr failed to find namespace variable'
    assert lightning_hasattr(model2, 'learning_rate'), \
        'lightning_hasattr failed to find hparams namespace variable'
    assert lightning_hasattr(model3, 'learning_rate'), \
        'lightning_hasattr failed to find hparams dict variable'
    assert not lightning_hasattr(model4, 'learning_rate'), \
        'lightning_hasattr found variable when it should not'
    assert lightning_hasattr(model5, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in datamodule'
    assert lightning_hasattr(model6, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in datamodule w/ hparams present'
    assert lightning_hasattr(model7, 'batch_size'), \
        'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
Example #7
def scale_batch_size(
    trainer: 'pl.Trainer',
    model: 'pl.LightningModule',
    mode: str = 'power',
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = 'batch_size',
) -> Optional[int]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
    if trainer.fast_dev_run:
        rank_zero_warn(
            'Skipping batch size scaler since fast_dev_run is enabled.',
            UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(
            f'Field {batch_arg_name} not found in either `model` or `model.hparams`'
        )
    if hasattr(model, batch_arg_name) and hasattr(
            model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f'Fields `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
            f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
            f' If this is not the intended behavior, please remove either one.'
        )

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`.'
            ' Please disable the feature or incorporate the dataloader into the model.'
        )

    # Arguments we adjust during the batch size finder, save for restoring
    __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, model, steps_per_trial)

    # Save initial model, that is loaded after batch size is found
    save_path = os.path.join(trainer.default_root_dir,
                             'scale_batch_size_temp_model.ckpt')
    trainer.save_checkpoint(str(save_path))

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size, _ = _adjust_batch_size(
        trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == 'power':
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name,
                                      max_trials)
    elif mode == 'binsearch':
        new_size = _run_binsearch_scaling(trainer, model, new_size,
                                          batch_arg_name, max_trials)
    else:
        raise ValueError(
            'mode in method `scale_batch_size` can be either `power` or `binsearch`'
        )

    garbage_collection_cuda()
    log.info(
        f'Finished batch size finder, will continue with full run using batch size {new_size}'
    )

    # Restore initial state of model
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path))
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __scale_batch_restore_params(trainer)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size
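
`_adjust_batch_size` is called above but not shown. The following is a hedged sketch of what it plausibly does in this variant, inferred from how its `(new_size, changed)` return value is consumed; `lightning_getattr` is the read counterpart of the setters used elsewhere in these examples, and `trainer.lightning_module` is assumed to be the attached module:

def _adjust_batch_size(trainer, batch_arg_name='batch_size', factor=1.0,
                       value=None, desc=None):
    """Sketch: set (or scale) the batch size attribute, report whether it changed."""
    model = trainer.lightning_module
    batch_size = lightning_getattr(model, batch_arg_name)
    new_size = value if value is not None else int(batch_size * factor)
    if desc:
        log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
    lightning_setattr(model, batch_arg_name, new_size)
    return new_size, new_size != batch_size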
Example #8
def scale_batch_size(trainer,
                     model: LightningModule,
                     mode: str = 'power',
                     steps_per_trial: int = 3,
                     init_val: int = 2,
                     max_trials: int = 25,
                     batch_arg_name: str = 'batch_size',
                     **fit_kwargs):
    r"""
    Will iteratively try to find the largest batch size for a given model
    that does not give an out of memory (OOM) error.

    Args:
        trainer: The Trainer
        model: Model to fit.

        mode: string setting the search mode. Either `power` or `binsearch`.
            If mode is `power`, we keep multiplying the batch size by 2 until
            we get an OOM error. If mode is `binsearch`, we initially also
            keep multiplying by 2 and, after encountering an OOM error, do a
            binary search between the last successful batch size and the
            batch size that failed.

        steps_per_trial: number of steps to run with a given batch size.
            Ideally 1 should be enough to test if an OOM error occurs,
            however in practice a few are needed.

        init_val: initial batch size to start the search with

        max_trials: max number of batch size increases performed before the
            algorithm is terminated

        batch_arg_name: name of the attribute that stores the batch size.
            It is expected that the user has provided a model or datamodule that has a hyperparameter
            with that name. We will look for this attribute name in the following places

            - `model`
            - `model.hparams`
            - `model.datamodule`
            - `trainer.datamodule` (the datamodule passed to the tune method)

        **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
            or datamodule.
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`',
                       UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(
            f'Field {batch_arg_name} not found in either `model` or `model.hparams`'
        )
    if hasattr(model, batch_arg_name) and hasattr(
            model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f'Fields `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
            f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
            f' If this is not the intended behavior, please remove either one.'
        )

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders'
            ' passed directly to `.fit()`. Please disable the feature or'
            ' incorporate the dataloader into the model.')

    # Arguments we adjust during the batch size finder, save for restoring
    __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, model, steps_per_trial)

    # Save initial model, that is loaded after batch size is found
    save_path = os.path.join(trainer.default_root_dir,
                             'scale_batch_size_temp_model.ckpt')
    trainer.save_checkpoint(str(save_path))

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size = _adjust_batch_size(trainer, batch_arg_name,
                                  value=init_val)  # initially set to init_val
    if mode == 'power':
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name,
                                      max_trials, **fit_kwargs)
    elif mode == 'binsearch':
        new_size = _run_binsearch_scaling(trainer, model, new_size,
                                          batch_arg_name, max_trials,
                                          **fit_kwargs)
    else:
        raise ValueError(
            'mode in method `scale_batch_size` can be either `power` or `binsearch`'
        )

    garbage_collection_cuda()
    log.info(
        f'Finished batch size finder, will continue with full run using batch size {new_size}'
    )

    # Restore initial state of model
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path),
                                             on_gpu=trainer.on_gpu)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __scale_batch_restore_params(trainer)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size
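
To make the two search modes concrete, here is a self-contained toy sketch of the doubling-then-bisecting strategy the docstring describes; `fits_in_memory` is a hypothetical stand-in for running `steps_per_trial` steps and catching an OOM error:

def find_max_batch_size(fits_in_memory, init_val=2, max_trials=25, mode='binsearch'):
    """Toy version of the search above; not the library implementation."""
    size, last_ok = init_val, None
    for _ in range(max_trials):
        if fits_in_memory(size):
            last_ok = size
            size *= 2  # both modes double until the first failure
        else:
            break
    else:
        return last_ok  # never failed within max_trials
    if mode == 'power' or last_ok is None:
        return last_ok
    # binsearch: bisect between the last success and the first failure
    lo, hi = last_ok, size
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits_in_memory(mid):
            lo = mid
        else:
            hi = mid
    return lo

For example, find_max_batch_size(lambda bs: bs <= 96, mode='power') returns 64, while mode='binsearch' refines that to 96.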
Example #9
def scale_batch_size(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    mode: str = "power",
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = "batch_size",
) -> Optional[int]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.")
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`")
    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f"Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!"
            f" `model.{batch_arg_name}` will be used as the initial batch size for scaling."
            " If this is not the intended behavior, please remove either one."
        )

    if not trainer._data_connector._train_dataloader_source.is_module():
        raise MisconfigurationException(
            "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
            " Please disable the feature or incorporate the dataloader into the model."
        )

    # Save initial model, that is loaded after batch size is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    params = __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, steps_per_trial)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == "power":
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    elif mode == "binsearch":
        new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    else:
        raise ValueError("mode in method `scale_batch_size` could either be `power` or `binsearch`")

    garbage_collection_cuda()
    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    __scale_batch_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size
Example #10
    def scale_batch_size(self,
                         model: LightningModule,
                         mode: str = 'power',
                         steps_per_trial: int = 3,
                         init_val: int = 2,
                         max_trials: int = 25,
                         batch_arg_name: str = 'batch_size'):
        r"""
        Will iteratively try to find the largest batch size for a given model
        that does not give an out of memory (OOM) error.

        Args:
            model: Model to fit.

            mode: string setting the search mode. Either `power` or `binsearch`.
                If mode is `power`, we keep multiplying the batch size by 2 until
                we get an OOM error. If mode is `binsearch`, we initially also
                keep multiplying by 2 and, after encountering an OOM error, do a
                binary search between the last successful batch size and the
                batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practice a few are needed.

            init_val: initial batch size to start the search with

            max_trials: max number of batch size increases performed before the
                algorithm is terminated

        """
        if not lightning_hasattr(model, batch_arg_name):
            raise MisconfigurationException(
                f'Field {batch_arg_name} not found in either `model` or `model.hparams`')
        if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
            rank_zero_warn(
                f'Fields `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
                f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
                f' If this is not the intended behavior, please remove either one.'
            )

        if hasattr(model.train_dataloader, 'patch_loader_code'):
            raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
                                            ' passed directly to `.fit()`. Please disable the feature or'
                                            ' incorporate the dataloader into the model.')

        # Arguments we adjust during the batch size finder, save for restoring
        self.__scale_batch_dump_params()

        # Set to values that are required by the algorithm
        self.__scale_batch_reset_params(model, steps_per_trial)

        # Save initial model, that is loaded after batch size is found
        save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
        self.save_checkpoint(str(save_path))

        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Initially we just double in size until an OOM is encountered
        new_size = _adjust_batch_size(self, value=init_val)  # initially set to init_val
        if mode == 'power':
            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
        elif mode == 'binsearch':
            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
        else:
            raise ValueError('mode in method `scale_batch_size` can be either `power` or `binsearch`')

        garbage_collection_cuda()
        log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

        # Restore initial state of model
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__scale_batch_restore_params()
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return new_size
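
Finally, a hedged end-to-end usage sketch for the era of these examples. It assumes the model (or its datamodule) exposes a `batch_size` attribute, and uses the `auto_scale_batch_size` flag together with `trainer.tune`, both documented entry points in PyTorch Lightning 1.x:

import pytorch_lightning as pl

model = LitModel()  # e.g. the sketch module above, extended with `self.batch_size = 2`
trainer = pl.Trainer(auto_scale_batch_size='binsearch', auto_lr_find=True)
trainer.tune(model)  # runs scale_batch_size and lr_find, updating the attributes in place
print(model.batch_size, model.learning_rate)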