Example no. 1
0
def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
    trainer.auto_scale_batch_size = None  # prevent recursion
    trainer.auto_lr_find = False  # avoid lr find being called multiple times
    trainer.fit_loop.max_steps = steps_per_trial  # take few steps
    trainer.loggers = [DummyLogger()] if trainer.loggers else []
    trainer.callbacks = []  # not needed before full run
    trainer.limit_train_batches = 1.0
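# A minimal sketch (an assumption, not the actual Lightning implementation) of
# the dump/restore pattern that typically brackets a reset helper like the one
# above: the touched trainer attributes are saved before the trial and written
# back afterwards. The helper names and the attribute list are assumed here to
# mirror exactly what the reset helper modifies.
def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
    trainer.__dumped_params = {
        "auto_scale_batch_size": trainer.auto_scale_batch_size,
        "auto_lr_find": trainer.auto_lr_find,
        "max_steps": trainer.fit_loop.max_steps,
        "loggers": trainer.loggers,
        "callbacks": trainer.callbacks,
        "limit_train_batches": trainer.limit_train_batches,
    }


def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
    saved = trainer.__dumped_params
    trainer.auto_scale_batch_size = saved["auto_scale_batch_size"]
    trainer.auto_lr_find = saved["auto_lr_find"]
    trainer.fit_loop.max_steps = saved["max_steps"]
    trainer.loggers = saved["loggers"]
    trainer.callbacks = saved["callbacks"]
    trainer.limit_train_batches = saved["limit_train_batches"]
    del trainer.__dumped_params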
Example no. 2
0
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
    # avoid lr find being called multiple times
    trainer.auto_lr_find = False
    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
    # No logging
    trainer.loggers = [DummyLogger()] if trainer.loggers else []
    # Max step set to number of iterations
    trainer.fit_loop.max_steps = num_training
Example no. 3
0
    def on_init_start(
        self,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run,
    ):
        if not isinstance(fast_dev_run, (bool, int)):
            raise MisconfigurationException(
                f'fast_dev_run={fast_dev_run} is not a valid configuration.'
                ' It should be either a bool or an int >= 0')

        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
            raise MisconfigurationException(
                f'fast_dev_run={fast_dev_run} is not a'
                ' valid configuration. It should be >= 0.')

        self.trainer.fast_dev_run = fast_dev_run
        fast_dev_run = int(fast_dev_run)

        # set fast_dev_run=True when it is 1, used while logging
        if fast_dev_run == 1:
            self.trainer.fast_dev_run = True

        if fast_dev_run:
            limit_train_batches = fast_dev_run
            limit_val_batches = fast_dev_run
            limit_test_batches = fast_dev_run
            limit_predict_batches = fast_dev_run
            self.trainer.fit_loop.max_steps = fast_dev_run
            self.trainer.num_sanity_val_steps = 0
            self.trainer.fit_loop.max_epochs = 1
            val_check_interval = 1.0
            self.trainer.check_val_every_n_epoch = 1
            self.trainer.logger = DummyLogger()

            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                f' val, test and prediction loop using {fast_dev_run} batch(es).'
            )

        self.trainer.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.trainer.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.trainer.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.trainer.limit_predict_batches = _determine_batch_limits(
            limit_predict_batches, 'limit_predict_batches')
        self.trainer.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.trainer.overfit_batches = _determine_batch_limits(
            overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.trainer.overfit_batches)
Example no. 4
0
def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
    trainer.auto_scale_batch_size = None  # prevent recursion
    trainer.auto_lr_find = False  # avoid lr find being called multiple times
    trainer.fit_loop.current_epoch = 0
    trainer.fit_loop.max_steps = steps_per_trial  # take few steps
    trainer.logger = DummyLogger() if trainer.logger is not None else None
    trainer.callbacks = []  # not needed before full run
    trainer.limit_train_batches = 1.0
    trainer.optimizers, trainer.lr_schedulers = [], []  # required for saving
    trainer.model = model  # required for saving
Example no. 5
0
def __scale_batch_reset_params(trainer, model, steps_per_trial):
    trainer.auto_scale_batch_size = None  # prevent recursion
    trainer.auto_lr_find = False  # avoid lr find being called multiple times
    trainer.current_epoch = 0
    trainer.max_steps = steps_per_trial  # take few steps
    trainer.weights_summary = None  # not needed before full run
    trainer.logger = DummyLogger()
    trainer.callbacks = []  # not needed before full run
    trainer.limit_train_batches = 1.0
    trainer.optimizers, trainer.schedulers = [], []  # required for saving
    trainer.model = model  # required for saving
Example no. 6
0
def __scale_batch_reset_params(self, model, steps_per_trial):
    self.auto_scale_batch_size = None  # prevent recursion
    self.max_steps = steps_per_trial  # take few steps
    self.weights_summary = None  # not needed before full run
    self.logger = DummyLogger()
    self.callbacks = []  # not needed before full run
    self.checkpoint_callback = False  # required for saving
    self.early_stop_callback = None
    self.limit_train_batches = 1.0
    self.optimizers, self.schedulers = [], []  # required for saving
    self.model = model  # required for saving
Example no. 7
0
def test_dummylogger_empty_iterable():
    """Test that DummyLogger represents an empty iterable."""
    logger = DummyLogger()
    for _ in logger:
        assert False
Example no. 8
0
def test_dummylogger_support_indexing():
    """Test that the DummyLogger can imitate indexing of a LoggerCollection."""
    logger = DummyLogger()
    assert logger[0] == logger
Example no. 9
0
def get_fake_logger(*args, **kwargs):
    return DummyLogger()
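# Usage sketch: DummyLogger is a no-op logger, so it can be handed to a Trainer
# as a drop-in replacement whenever logging should be disabled (Example no. 14
# below does exactly this). The import path is an assumption; DummyLogger has
# historically lived in pytorch_lightning.loggers.base, but the exact module
# can differ between Lightning versions.
import pytorch_lightning as pl
from pytorch_lightning.loggers.base import DummyLogger


def fit_without_logging(model: "pl.LightningModule") -> None:
    trainer = pl.Trainer(logger=DummyLogger(), max_epochs=1)
    trainer.fit(model)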
Example no. 10
0
def __init__(self):
    self.current_epoch = 1
    self.global_step = 1
    self.logger = DummyLogger()
Example no. 11
0
    def lr_find(self,
                model: LightningModule,
                train_dataloader: Optional[DataLoader] = None,
                val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
                min_lr: float = 1e-8,
                max_lr: float = 1,
                num_training: int = 100,
                mode: str = 'exponential',
                early_stop_threshold: float = 4.0,
                num_accumulation_steps=None):
        r"""
        lr_find enables the user to do a range test of good initial learning rates,
        to reduce the amount of guesswork in picking a good starting learning rate.

        Args:
            model: Model to do range testing for

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            min_lr: minimum learning rate to investigate

            max_lr: maximum learning rate to investigate

            num_training: number of learning rates to test

            mode: search strategy, either 'linear' or 'exponential'. If set to
                'linear' the learning rate will be searched by linearly increasing
                after each batch. If set to 'exponential', will increase learning
                rate exponentially.

            early_stop_threshold: threshold for stopping the search. If the
                loss at any point is larger than early_stop_threshold*best_loss
                then the search is stopped. To disable, set to None.

            num_accumulation_steps: deprecated, number of batches to calculate loss over.
                Set trainer argument ``accumulate_grad_batches`` instead.

        Example::

            # Setup model and trainer
            model = MyModelClass(hparams)
            trainer = pl.Trainer()

            # Run lr finder
            lr_finder = trainer.lr_find(model, ...)

            # Inspect results
            fig = lr_finder.plot(); fig.show()
            suggested_lr = lr_finder.suggestion()

            # Overwrite lr and create new model
            hparams.lr = suggested_lr
            model = MyModelClass(hparams)

            # Ready to train with new learning rate
            trainer.fit(model)

        """
        if num_accumulation_steps is not None:
            rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
                           " since v0.7.6 and will be removed in 0.9. Please"
                           " set trainer argument `accumulate_grad_batches` instead.",
                           DeprecationWarning)

        save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

        self.__lr_finder_dump_params(model)

        # Prevent going into infinite loop
        self.auto_lr_find = False

        # Initialize lr finder object (stores results)
        lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

        # Use special lr logger callback
        self.callbacks = [_LRCallback(num_training,
                                      early_stop_threshold,
                                      progress_bar_refresh_rate=1)]

        # No logging
        self.logger = DummyLogger()

        # Max step set to number of iterations
        self.max_steps = num_training

        # Disable standard progress bar for fit
        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Disable standard checkpoint & early stopping
        self.checkpoint_callback = False
        self.early_stop_callback = None
        self.enable_early_stop = False

        # Required for saving the model
        self.optimizers, self.schedulers = [], []
        self.model = model

        # Dump model checkpoint
        self.save_checkpoint(str(save_path))

        # Configure optimizer and scheduler
        optimizers, _, _ = self.init_optimizers(model)

        if len(optimizers) != 1:
            raise MisconfigurationException(
                f'`model.configure_optimizers()` returned {len(optimizers)}, but'
                ' learning rate finder only works with single optimizer')
        model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

        # Fit, lr & loss logged in callback
        self.fit(model,
                 train_dataloader=train_dataloader,
                 val_dataloaders=val_dataloaders)

        # Prompt if we stopped early
        if self.global_step != num_training:
            log.info('LR finder stopped early due to diverging loss.')

        # Transfer results from callback to lr finder object
        lr_finder.results.update({'lr': self.callbacks[0].lrs,
                                  'loss': self.callbacks[0].losses})
        lr_finder._total_batch_idx = self.total_batch_idx  # for debug purpose

        # Reset model state
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__lr_finder_restore_params(model)
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return lr_finder
Example no. 12
0
def lr_find(
    trainer: 'pl.Trainer',
    model: 'pl.LightningModule',
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn(
            'Skipping learning rate finder since fast_dev_run is enabled.',
            UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir,
                             'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [
        _LRCallback(num_training,
                    early_stop_threshold,
                    progress_bar_refresh_rate=1)
    ]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(
        model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(
            f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.'
        )

    # Transfer results from callback to lr finder object
    lr_finder.results.update({
        'lr': trainer.callbacks[0].lrs,
        'loss': trainer.callbacks[0].losses
    })
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(
            str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
Example no. 13
0
def lr_find(
        trainer,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = 'exponential',
        early_stop_threshold: float = 4.0,
        datamodule: Optional[LightningDataModule] = None,
):
    r"""
    lr_find enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch
            DataLoader with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional `LightningDataModule` which holds the training
            and validation dataloader(s). Note that the `train_dataloader` and
            `val_dataloaders` parameters cannot be used at the same time as
            this parameter, or a `MisconfigurationException` will be raised.


    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)

    """
    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training,
                                     early_stop_threshold,
                                     progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Disable standard checkpoint & early stopping
    trainer.checkpoint_callback = False
    trainer.early_stop_callback = None

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs,
                              'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return lr_finder
Example no. 14
0
def train(**args):
    params = EasyDict(args)
    params.gpu = int(params.gpu)

    config = ConfigParser()
    config.read('config.ini')
    if params.datasets == ['all']:
        params.datasets = ['imdb', 'amazon', 'yelp', 'rottentomatoes', 'hotel']

    is_tokenizer_length_dataset_specific = Models(params.model) == Models.distilbert and (
            params.tokenizer_length is None or params.tokenizer_length)
    is_number_prototypes_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_number_of_prototypes is None or params.pc_number_of_prototypes == -1)
    is_sep_loss_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_sep_loss_weight is None or params.pc_sep_loss_weight == -1)
    is_ce_loss_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_ce_loss_weight is None or params.pc_ce_loss_weight == -1)

    for dataset in params.datasets:
        params.data_set = dataset
        seed_everything(params.seed)

        if is_tokenizer_length_dataset_specific:
            params.tokenizer_length = dataset_tokens_length[params.data_set]

        if is_number_prototypes_dataset_specific:
            params.pc_number_of_prototypes = dataset_to_number_of_prototypes[params.data_set]

        if is_sep_loss_dataset_specific:
            params.pc_sep_loss_weight = dataset_to_separation_loss[params.data_set]

        if is_ce_loss_dataset_specific:
            weight = 1 - (params.pc_cls_loss_weight + params.pc_sep_loss_weight + params.pc_l1_loss_weight)
            assert weight > 0, f'Weight {weight} of cross entropy loss cannot be less than or equal to 0'
            params.pc_ce_loss_weight = weight

        logger = DummyLogger()
        if params.logger:
            comet_config = EasyDict(config['cometml'])
            project_name = params.project_name if params.project_name else comet_config.projectname
            logger = CometLogger(api_key=comet_config.apikey, project_name=project_name,
                                 workspace=comet_config.workspace)

        # logger.experiment.log_code(folder='src')
        logger.log_hyperparams(params)
        base_callbacks = [LearningRateMonitor(logging_interval='epoch')]

        df_dataset = pd.read_csv(f'data/{params.data_set}/tokenized_data.csv')
        n_splits = get_n_splits(dataset=df_dataset, x_label='text', y_label='label', folds=params.fold)
        log_splits(n_splits, logger)

        embeddings = GloVe('42B', cache=params.cache) if Models(params.model) != Models.distilbert else None

        best_models_scores, number_of_prototypes = [], []
        for fold_id, (train_index, val_index, test_index) in enumerate(n_splits):
            i = str(fold_id)

            model_checkpoint = ModelCheckpoint(
                filepath='checkpoints/fold_' + i + '_{epoch:02d}-{val_loss_' + i + ':.4f}-{val_acc_' + i + ':.4f}',
                save_weights_only=True, save_top_k=1, monitor='val_acc_' + i,
                period=params.pc_project_prototypes_every_n
            )
            early_stop = EarlyStopping(monitor=f'val_loss_{i}', patience=10, verbose=True, mode='min', min_delta=0.005)
            callbacks = deepcopy(base_callbacks) + [model_checkpoint, early_stop]

            lit_module = model_to_litmodule[params.model]
            train_df, valid_df = df_dataset.iloc[train_index + val_index], df_dataset.iloc[test_index]
            model, train_loader, val_loader, *utils = lit_module.from_params_and_dataset(train_df, valid_df, params,
                                                                                         fold_id, embeddings)
            trainer = Trainer(auto_lr_find=params.find_lr, logger=logger, max_epochs=params.epoch, callbacks=callbacks,
                              gpus=params.gpu, deterministic=True, fast_dev_run=params.fast_dev_run,
                              num_sanity_val_steps=0)

            trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)
            trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

            for absolute_path in model_checkpoint.best_k_models.keys():
                logger.experiment.log_model(Path(absolute_path).name, absolute_path)

            if model_checkpoint.best_model_score:
                best_models_scores.append(model_checkpoint.best_model_score.tolist())
                logger.log_metrics({'best_model_score_' + i: model_checkpoint.best_model_score.tolist()}, step=0)

            if Models(params.model) == Models.protoconv and model_checkpoint.best_model_path:
                best_model = lit_module.load_from_checkpoint(model_checkpoint.best_model_path)
                saved_number_of_prototypes = sum(best_model.enabled_prototypes_mask.tolist())
                number_of_prototypes.append(saved_number_of_prototypes)
                logger.log_hyperparams({
                    f'saved_prototypes_{fold_id}': saved_number_of_prototypes,
                    f'best_model_path_{fold_id}': str(Path(model_checkpoint.best_model_path).name)
                })

                if params.pc_visualize:
                    data_visualizer = DataVisualizer(best_model)
                    logger.experiment.log_html(f'<h1>Split {fold_id}</h1><br> <h3>Prototypes:</h3><br>'
                                               f'{data_visualizer.visualize_prototypes()}<br>')
                    logger.experiment.log_figure(f'Prototypes similarity_{fold_id}',
                                                 data_visualizer.visualize_similarity().figure)
                    logger.experiment.log_html(f'<h3>Random prediction explanations:</h3><br>'
                                               f'{data_visualizer.visualize_random_predictions(val_loader, n=15)}')

        if len(best_models_scores) >= 1:
            avg_best, std_best = float(np.mean(np.array(best_models_scores))), float(
                np.std(np.array(best_models_scores)))
            table_entry = f'{avg_best:.3f} ($\\pm${std_best:.3f})'

            logger.log_hyperparams({
                'avg_best_scores': avg_best,
                'std_best_scores': std_best,
                'table_entry': table_entry
            })

        if len(number_of_prototypes) >= 1:
            logger.log_hyperparams({'avg_saved_prototypes': float(np.mean(np.array(number_of_prototypes)))})

        logger.experiment.end()
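# Hypothetical invocation of the training entry point above. Every value below
# is a placeholder; the keyword names simply mirror the attributes that `train`
# reads from `params`, and the real defaults or CLI wiring are not shown in
# this example.
if __name__ == '__main__':
    train(
        gpu='0',
        seed=0,
        datasets=['imdb'],
        model='protoconv',
        logger=False,
        project_name='',
        cache='.vector_cache',
        fold=5,
        epoch=30,
        find_lr=False,
        fast_dev_run=False,
        tokenizer_length=None,
        pc_number_of_prototypes=-1,
        pc_sep_loss_weight=-1,
        pc_ce_loss_weight=-1,
        pc_cls_loss_weight=0.9,
        pc_l1_loss_weight=0.01,
        pc_project_prototypes_every_n=4,
        pc_visualize=False,
    )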
Example no. 15
0
def test_dummylogger_noop_method_calls():
    """Test that the DummyLogger methods can be called with arbitrary arguments."""
    logger = DummyLogger()
    logger.log_hyperparams("1", 2, three="three")
    logger.log_metrics("1", 2, three="three")
Example no. 16
0
def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch
            ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.

    Raises:
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, or
            if you are using `more than one optimizer` with learning rate finder.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)

    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
Example no. 17
0
def test_dummylogger_support_indexing():
    logger = DummyLogger()
    assert logger[0] == logger