Example #1
0
    def train(self):
        """Run the training loop over all remaining epochs.

        Runs the sanity check, puts the model in train mode with gradients
        enabled, resets the train/val dataloaders and iterates epochs from
        ``self.current_epoch`` up to ``self.max_epochs``.

        ``self.train_loop.on_train_end()`` is called on normal completion, when
        ``max_steps`` is reached, and when ``should_stop`` is set with the
        minimum epochs/steps met. On the first ``KeyboardInterrupt`` the trainer
        is marked interrupted and ``on_train_end`` is called as well. Repeated
        interrupts only produce the warning.
        """
        # fetch the model once; the same object is used for the sanity check and training
        model = self.get_model()
        self.run_sanity_check(model)

        # enable train mode
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                # stop once the step budget is exhausted (``<=`` also covers overshooting it)
                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.optimizer_connector.update_learning_rates(
                    interval='epoch')

                # early stopping; a falsy min_epochs / min_steps imposes no requirement
                met_min_epochs = (epoch >= self.min_epochs - 1) if self.min_epochs else True
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        self.train_loop.on_train_end()
                        return
                    log.info(
                        'Trainer was signaled to stop but required minimum epochs'
                        f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                        ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()
Example #2
0
    def run_train(self):
        """Run the training loop until max epochs/steps are reached or training is stopped.

        Performs the pre-training routine and sanity check, switches the model
        to train mode with gradients enabled and resets the dataloaders. It then
        iterates epochs from ``self.current_epoch`` up to ``self.max_epochs``,
        or indefinitely when ``max_epochs`` is falsy.

        ``self.train_loop.on_train_end()`` is invoked exactly once, from the
        ``finally`` clause. That covers every exit: normal completion, skipped
        training, ``max_steps`` reached, early stop, keyboard interrupt and any
        other exception.
        """

        self._pre_training_routine()

        if not self.is_global_zero and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self.run_sanity_check(self.lightning_module)

        # set stage for logging
        self._running_stage = RunningStage.TRAINING

        self.checkpoint_connector.has_trained = False

        # enable train mode
        model = self.lightning_module
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            if self.train_loop.should_skip_training():
                return
            # run all epochs (unbounded when max_epochs is not set)
            epochs = range(self.current_epoch,
                           self.max_epochs) if self.max_epochs else count(
                               self.current_epoch)
            for epoch in epochs:

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                with self.profiler.profile("run_training_epoch"):
                    # run train epoch
                    self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
                    return

                # early stopping
                met_min_epochs = (
                    epoch >= self.min_epochs - 1) if self.min_epochs else True
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...')

            # NOTE: the end-of-training hook is intentionally NOT called here; the
            # ``finally`` clause below handles it so it runs exactly once.

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()
        finally:
            # hook (single call site for every exit path)
            self.train_loop.on_train_end()
 def kth_best_model(self):
     """Deprecated alias of ``kth_best_model_path``.

     Emits a ``DeprecationWarning`` (via ``rank_zero_warn``) and returns
     ``self.kth_best_model_path``.
     """
     deprecation_msg = (
         "Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0"
         " and will be removed in v0.10.0"
     )
     rank_zero_warn(deprecation_msg, DeprecationWarning)
     return self.kth_best_model_path
Example #4
0
    def run_evaluation(self, on_epoch=False):
        """Run a full evaluation pass (validation, test or sanity check).

        Puts the model in eval mode with gradients disabled. Every batch of
        every evaluation dataloader goes through the evaluation step/step_end
        and the batch hooks, and the per-dataloader outputs are collected. The
        epoch-end method and hooks then run, after which train mode is restored
        and gradients are re-enabled.

        Args:
            on_epoch: if truthy, epoch-interval LR schedulers are updated after
                the evaluation epoch-end hooks have run.

        Returns:
            ``self.logger_connector.get_evaluate_epoch_results()``, or ``([], [])``
            when ``self.evaluation_loop.should_skip_evaluation`` is true.
        """
        # called outside an evaluation / sanity-check stage: warn and fall back to
        # the validating stage instead of failing
        if not (self.evaluating or self.sanity_checking):
            rank_zero_warn(
                f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}."
                " This should not happen normally. Setting it to `RunningStage.VALIDATING`",
                RuntimeWarning)
            self.validating = True

        # prepare dataloaders
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
        )

        # check if we want to skip this evaluation
        # (this returns before eval mode / grad state are changed below)
        if self.evaluation_loop.should_skip_evaluation(max_batches):
            return [], []

        # enable eval mode + no grads
        # NOTE(review): grads are only re-enabled at the end of this method and there is
        # no try/finally, so an exception raised below leaves gradients disabled
        self.evaluation_loop.on_evaluation_model_eval()
        # ref model
        model = self.lightning_module
        model.zero_grad()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(max_batches, dataloaders)

        # hook
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                # skip empty batches (they still advance batch_idx)
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(
                    batch, batch_idx, dataloader_idx)

                # lightning module methods
                with self.profiler.profile("evaluation_step_and_end"):
                    output = self.evaluation_loop.evaluation_step(
                        batch, batch_idx, dataloader_idx)
                    output = self.evaluation_loop.evaluation_step_end(output)

                # hook + store predictions
                self.evaluation_loop.on_evaluation_batch_end(
                    output, batch, batch_idx, dataloader_idx)

                # log batch metrics
                self.evaluation_loop.log_evaluation_step_metrics(batch_idx)

                # track epoch level outputs
                dl_outputs = self.track_output_for_epoch_end(
                    dl_outputs, output)

            # store batch level output per dataloader
            self.evaluation_loop.outputs.append(dl_outputs)

        # one entry per dataloader, as appended in the loop above
        outputs = self.evaluation_loop.outputs

        # reset outputs
        self.evaluation_loop.outputs = []

        # with a single dataloader don't pass a 2D list
        if self.evaluation_loop.num_dataloaders == 1:
            outputs = outputs[0]

        # lightning module method
        self.evaluation_loop.evaluation_epoch_end(outputs)

        # hook
        self.evaluation_loop.on_evaluation_epoch_end(outputs)

        # update epoch-level lr_schedulers
        if on_epoch:
            self.optimizer_connector.update_learning_rates(interval='epoch')

        # hook
        self.evaluation_loop.on_evaluation_end()

        # log epoch metrics
        eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

        # save predictions to disk
        self.evaluation_loop.predictions.to_disk()

        # enable train mode again
        self.evaluation_loop.on_evaluation_model_train()

        # reset cached results
        self.logger_connector.reset()

        torch.set_grad_enabled(True)

        return eval_loop_results
Example #5
0
 def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
     """Build a ``UNet`` with ``num_classes`` outputs, forwarding ``kwargs``.

     The first positional argument is ignored. Going by the warning text it is
     the backbone, which UNet does not use. A ``UserWarning`` is emitted to say so.
     """
     ignored_backbone_msg = "The UNet model does not require a backbone, so the backbone will be ignored."
     rank_zero_warn(ignored_backbone_msg, UserWarning)
     unet = UNet(num_classes, **kwargs)
     return unet
Example #6
0
 def use_ddp2(self, val: bool) -> None:
     """Deprecated internal setter for the DDP2 distributed type.

     Always emits a ``DeprecationWarning``. When ``val`` is truthy, it sets the
     accelerator connector's distributed type to ``DistributedType.DDP2``. A
     falsy ``val`` leaves the distributed type unchanged.
     """
     rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
     if not val:
         return
     self.accelerator_connector._distrib_type = DistributedType.DDP2
    def _reset_eval_dataloader(
            self, model: LightningModule,
            mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Requests the ``{mode}_dataloader``, or the train dataloader when
        ``overfit_batches > 0``. It then handles random samplers, drops ``None``
        loaders, adds samplers and computes how many batches to run per loader
        from ``limit_{mode}_batches``.

        The limit is interpreted as follows:
        - an int is a number of batches;
        - ``0.0`` disables evaluation;
        - any other float is a fraction of the loader length, which requires a
          loader with ``__len__``, except ``1.0``.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders)

        Raises:
            MisconfigurationException: if a fractional limit other than ``0.0``/``1.0``
                is used with a loader that has no length, or if a positive fractional
                limit rounds down to zero batches.
        """
        # use the training loader as val and test when overfitting
        if self.overfit_batches > 0:
            dataloaders = self.request_dataloader(model.train_dataloader)
        else:
            dataloaders = self.request_dataloader(
                getattr(model, f'{mode}_dataloader'))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        for loader_i, loader in enumerate(dataloaders):
            # shuffling in val and test set is bad practice
            if mode in ('val', 'test') and isinstance(
                    getattr(loader, 'sampler', None), RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0:
                    rank_zero_warn(
                        'You requested to overfit but enabled training dataloader shuffling.'
                        ' We are turning it off for you.')
                    dataloaders[loader_i] = self.replace_sampler(
                        loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(
                        f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                        ' this off for validation and test dataloaders.')

        if any(dl is None for dl in dataloaders):
            rank_zero_warn(
                "One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, train=False) for dl in dataloaders
            if dl is not None
        ]

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+ (the loop simply does nothing for none)
        for i, dataloader in enumerate(dataloaders):
            num_batches = len(dataloader) if _has_len(
                dataloader) else float('inf')
            self._worker_check(dataloader, f'{mode} dataloader {i}')

            # percent or num_steps
            limit_eval_batches = getattr(self, f'limit_{mode}_batches')

            if num_batches != float('inf'):
                self._check_batch_limits(f'limit_{mode}_batches')

            # limit num batches either as num steps (int) or as a percent (float).
            # An int or 0.0 also applies to infinite loaders (min(inf, k) == k); previously
            # they were ignored for such loaders and evaluation ran without limit.
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float('inf'):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int.'
                )

            # only reachable for finite loaders, so ``len(dataloader)`` is safe here
            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(
                    limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                    f' {limit_eval_batches}*{len(dataloader)} < 1. Please increase the limit_{mode}_batches.'
                    f' Try at least limit_{mode}_batches={min_pct}')

            loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders
Example #8
0
    def restore_training_state(self, checkpoint):
        """Restore the trainer state from ``checkpoint``.

        Restores, in order:
        - the AMP scaling state (native or apex, when present);
        - callback states (via ``trainer.on_load_checkpoint``);
        - ``global_step`` and ``current_epoch``;
        - optimizer states, moved to ``trainer.root_gpu`` when it is set;
        - LR scheduler states.

        Args:
            checkpoint: the loaded checkpoint dict.

        Raises:
            KeyError: if the checkpoint has no ``optimizer_states`` or ``lr_schedulers``
                (e.g. it was saved with ``ModelCheckpoint.save_weights_only=True``).
            ValueError: if the checkpoint contains keys from a deprecated schema.
            MisconfigurationException: if the restored epoch exceeds ``Trainer(max_epochs)``.
        """
        # validation
        if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
            raise KeyError(
                'Trying to restore training state but checkpoint contains only the model.'
                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
            )

        if any(key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file.")

        # restore amp scaling
        if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
            self.trainer.scaler.load_state_dict(
                checkpoint['native_amp_scaling_state'])
        elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
            amp.load_state_dict(checkpoint['amp_scaling_state'])

        # restore callback states
        self.trainer.on_load_checkpoint(checkpoint)

        self.trainer.global_step = checkpoint['global_step']
        self.trainer.current_epoch = checkpoint['epoch']

        # crash if max_epochs is lower than the current epoch from the checkpoint
        # (a None max_epochs has no upper bound to violate)
        if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
            m = f"""
            you restored a checkpoint with current_epoch={self.trainer.current_epoch}
            but the Trainer(max_epochs={self.trainer.max_epochs})
            """
            raise MisconfigurationException(m)

        # Division deals with global step stepping once per accumulated batch
        # Inequality deals with different global step for odd vs even num_training_batches
        # Skip the check for empty or infinite epochs: with an infinite epoch,
        # ``global_step % inf == global_step`` would always trigger a spurious warning.
        num_training_batches = self.trainer.num_training_batches
        if num_training_batches != 0 and num_training_batches != float('inf'):
            n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
            expected_steps = num_training_batches / n_accum
            if self.trainer.global_step % expected_steps > 1:
                rank_zero_warn(
                    "You're resuming from a checkpoint that ended mid-epoch."
                    " Training will start from the beginning of the next epoch."
                    " This can cause unreliable results if further training is done,"
                    " consider using an end of epoch checkpoint.")

        # restore the optimizers
        optimizer_states = checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.trainer.optimizers,
                                        optimizer_states):
            optimizer.load_state_dict(opt_state)

            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        # replacing values (not keys) while iterating is safe
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.trainer.root_gpu)

        # restore the lr schedulers
        lr_schedulers = checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.trainer.lr_schedulers,
                                        lr_schedulers):
            scheduler['scheduler'].load_state_dict(lrs_state)
    def train(self):
        """Run the full training loop (legacy trainer).

        Emits a warning about progress-bar epoch numbering and loads the train
        and val dataloaders. The train loader is skipped here when it is
        reloaded every epoch. It then runs the ``on_train_start`` callbacks and
        hooks and trains epoch by epoch from ``self.current_epoch`` to
        ``self.max_epochs``.

        Training ends early when ``global_step`` reaches ``max_steps``, or when
        the early-stopping callback requests a stop after the minimum
        epochs/steps are met. ``run_training_teardown()`` runs on normal
        completion, on these early exits and on ``KeyboardInterrupt``.
        """
        rank_zero_warn(
            'Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
            ' but will start from "0" in v0.8.0.', RuntimeWarning)

        # get model
        model = self.get_model()

        # load data
        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
        if not self.reload_dataloaders_every_epoch:
            self.reset_train_dataloader(model)
        self.reset_val_dataloader(model)

        # Train start events
        with self.profiler.profile('on_train_start'):
            # callbacks
            self.on_train_start()
            # initialize early stop callback
            if self.early_stop_callback is not None:
                self.early_stop_callback.on_train_start(self, self.get_model())
            # model hooks
            model.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)
                # set seed for distributed sampler (enables shuffling for each epoch)
                if self.use_ddp \
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch

                total_val_batches = 0
                is_val_epoch = False
                if not self.disable_validation and self.num_training_batches != float(
                        'inf'):
                    # val can be checked multiple times in epoch
                    is_val_epoch = (self.current_epoch +
                                    1) % self.check_val_every_n_epoch == 0
                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                    total_val_batches = self.num_val_batches * val_checks_per_epoch

                # total batches includes multiple val checks
                self.total_batches = self.num_training_batches + total_val_batches

                # changing gradient according accumulation_scheduler
                self.accumulation_scheduler.on_epoch_start(
                    self, self.get_model())

                # stores accumulated grad fractions per batch
                self.batch_loss_value = TensorRunningAccum(
                    window_length=self.accumulate_grad_batches)

                if self.fast_dev_run:
                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                    num_iterations = 2
                elif self.total_batches == float('inf'):
                    # for infinite train or val loader, the progress bar never ends
                    num_iterations = None
                else:
                    num_iterations = self.total_batches

                # reset progress bar
                # .reset() doesn't work on disabled progress bar so we should check
                if not self.main_progress_bar.disable:
                    self.main_progress_bar.reset(num_iterations)
                desc = f'Epoch {epoch + 1}'
                self.main_progress_bar.set_description(desc)

                # -----------------
                # RUN TNG EPOCH
                # -----------------
                self.run_training_epoch()

                # update LR schedulers
                self.update_learning_rates(interval='epoch')

                # ``>=`` rather than ``==``: if global_step overshoots max_steps,
                # an equality test would never fire and training would not stop
                if self.max_steps and self.global_step >= self.max_steps:
                    self.run_training_teardown()
                    return

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                # TODO wrap this logic into the callback
                if self.enable_early_stop:
                    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                        should_stop = self.early_stop_callback.on_epoch_end(
                            self, self.get_model())
                        # stop training
                        stop = should_stop and met_min_epochs
                        if stop:
                            self.run_training_teardown()
                            return

            self.run_training_teardown()

        except KeyboardInterrupt:
            if self.proc_rank == 0:
                log.info(
                    'Detected KeyboardInterrupt, attempting graceful shutdown...'
                )
            self.interrupted = True
            self.run_training_teardown()
    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
        etc.).

        Args:
            model: The `LightningModule` if calling this outside of the trainer scope.

        Raises:
            MisconfigurationException: if ``limit_train_batches`` is a float other than ``0.0``/``1.0`` and the
                train dataloader has no length, or if ``val_check_interval`` is a float other than ``1.0``
                for such a dataloader.
            ValueError: if an int ``val_check_interval`` exceeds the number of training batches.
        """
        self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

        if self.overfit_batches > 0:
            self.train_dataloader = self._resolve_overfit_batches(self.train_dataloader)

        # automatically add samplers
        self.train_dataloader = apply_to_collection(
            self.train_dataloader, DataLoader, self.prepare_dataloader, shuffle=True, mode=RunningStage.TRAINING
        )

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank)

        # add collate_fn to collect metadata for fault tolerant training
        if _fault_tolerant_training():
            apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

        module = model or self.lightning_module or self.datamodule
        # evaluate the all-ranks length check once; the dataloader does not change
        # below, so the same answer is reused for the val-check computation
        train_has_len = has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module)
        self.num_training_batches = len(self.train_dataloader) if train_has_len else float("inf")

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float("inf"):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                "When using an IterableDataset for `limit_train_batches`,"
                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                " `num_training_batches` to use."
            )

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of the training batches ({self.num_training_batches}). "
                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
                )
        elif not train_has_len:
            if self.val_check_interval == 1.0:
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            # validate at least once per epoch
            self.val_check_batch = max(1, int(self.num_training_batches * self.val_check_interval))

        if self.logger and self.num_training_batches < self.log_every_n_steps:
            rank_zero_warn(
                f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                " you want to see logs for the training epoch."
            )
    def _reset_eval_dataloader(
        self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            mode: The running stage of the ``Trainer``
            model: The ``LightningModule`` if calling this outside of the trainer scope.

        Returns:
            Tuple (num_batches, dataloaders)

        Raises:
            MisconfigurationException: if ``limit_{prefix}_batches`` is a float other than ``0.0``/``1.0`` for a
                dataloader without a length, or if a positive fractional limit rounds down to zero batches.
        """
        assert mode.evaluating or mode == RunningStage.PREDICTING

        # always get the loaders first so we can count how many there are
        dataloaders = self.request_dataloader(mode, model=model)

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting, use the training loader as val and test
        # duplicate it the number of times needed to match the eval loaders
        if self.overfit_batches > 0:
            train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
            dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

        for loader_i, loader in enumerate(dataloaders):
            if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0 and mode.evaluating:
                    rank_zero_warn(
                        "You requested to overfit but enabled val/test dataloader shuffling."
                        " We are turning it off for you."
                    )
                    dataloaders[loader_i] = self._update_dataloader(
                        loader, SequentialSampler(loader.dataset), mode=mode
                    )
                else:
                    rank_zero_warn(
                        f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                        " it is strongly recommended that you turn this off for val/test/predict dataloaders."
                    )

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(
            dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank
        )

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+ (the loop simply does nothing for none)
        module = model or self.lightning_module or self.datamodule
        for i, dataloader in enumerate(dataloaders):
            num_batches = (
                len(dataloader)
                if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                else float("inf")
            )
            self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

            # percent or num_steps
            limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float("inf"):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                raise MisconfigurationException(
                    f"When using an IterableDataset for `limit_{mode.dataloader_prefix}_batches`,"
                    f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                    f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                )

            # only reachable for finite loaders, so ``len(dataloader)`` is safe here;
            # report the loader length (``num_batches`` is already truncated to 0)
            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                    f" {limit_eval_batches}*{len(dataloader)} < 1. Please increase the"
                    f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                    f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                )

            loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders
Example #12
0
    def run_train(self) -> None:
        """Run the training loop until max epochs/steps are reached or training is stopped.

        Performs the pre-training routine and sanity check, switches the model
        to train mode with gradients enabled and resets the dataloaders. It then
        iterates epochs from ``self.current_epoch`` up to ``self.max_epochs``,
        or indefinitely when ``max_epochs`` is falsy.

        ``self.train_loop.on_train_end()`` is invoked exactly once, from the
        ``finally`` clause, on every exit path.

        Raises:
            RuntimeError, AssertionError: re-raised after their traceback is
                printed. Previously they were swallowed and training appeared
                to succeed.
        """

        self._pre_training_routine()

        if not self.is_global_zero and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self.run_sanity_check(self.lightning_module)

        self.checkpoint_connector.has_trained = False

        # enable train mode
        model = self.lightning_module
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            if self.train_loop.should_skip_training():
                return
            # run all epochs (unbounded when max_epochs is not set)
            epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch)
            for epoch in epochs:

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                with self.profiler.profile("run_training_epoch"):
                    # run train epoch
                    self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
                    return

                # early stopping
                met_min_epochs = (epoch >= self.min_epochs - 1) if self.min_epochs else True
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...'
                        )

            # NOTE: the end-of-training hook is intentionally NOT called here; the
            # ``finally`` clause below handles it so it runs exactly once.

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
            # user could press Ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()
        except (RuntimeError, AssertionError):
            # if an exception is raised, the finally block is executed and can hide the actual exception
            # that was initially raised if `on_train_end` also raises an exception. we want to avoid that
            # for assertions and other runtime errors so we aren't misled while debugging.
            # Print the traceback now, then re-raise so the failure is not silently swallowed.
            print_exc()
            raise
        finally:
            # hook (single call site for every exit path)
            self.train_loop.on_train_end()
Example #13
0
    def init_optimizers(
            self,
            model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
        """Call ``configure_optimizers`` and normalise whatever it returns.

        Recognised return formats, checked in this order:

        * a single ``Optimizer``;
        * a 2-sequence ``([optimizers], scheduler_or_list_of_schedulers)``;
        * a dict with ``"optimizer"`` and optional ``"lr_scheduler"`` / ``"monitor"``;
        * a list/tuple of such dicts, each optionally carrying ``"frequency"``;
        * a list/tuple of optimizers.

        A ``None`` return triggers a warning and a ``_MockOptimizer`` stand-in.

        Args:
            model: fallback module used when ``self.lightning_module`` is falsy.

        Returns:
            ``(optimizers, lr_schedulers, optimizer_frequencies)``. The schedulers are first
            normalised by ``configure_schedulers`` and then checked by
            ``_validate_scheduler_optimizer``.

        Raises:
            ValueError: in the list-of-dicts format, only some optimizers give a frequency.
            MisconfigurationException: the return value matches none of the formats above.
        """
        pl_module = self.lightning_module or model
        self._lightning_optimizers = None
        conf = self.call_hook("configure_optimizers", pl_module=pl_module)
        if conf is None:
            rank_zero_warn(
                "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
                UserWarning,
            )
            conf = _MockOptimizer()

        optimizers, schedulers, frequencies = [], [], []
        monitor = None
        is_sequence = isinstance(conf, (list, tuple))

        if isinstance(conf, Optimizer):
            # one optimizer, no scheduler
            optimizers = [conf]
        elif (
            is_sequence
            and len(conf) == 2
            and isinstance(conf[0], list)
            and all(isinstance(candidate, Optimizer) for candidate in conf[0])
        ):
            # ([optimizers], scheduler) or ([optimizers], [schedulers])
            optimizers, sched_part = conf
            schedulers = sched_part if isinstance(sched_part, list) else [sched_part]
        elif isinstance(conf, dict):
            # one optimizer described by a dict
            optimizers = [conf["optimizer"]]
            monitor = conf.get("monitor", None)
            schedulers = [conf["lr_scheduler"]] if "lr_scheduler" in conf else []
        elif is_sequence and all(isinstance(entry, dict) for entry in conf):
            # several optimizers, each described by its own dict
            optimizers = [entry["optimizer"] for entry in conf]
            schedulers = []
            for opt_idx, entry in enumerate(conf):
                if "lr_scheduler" not in entry:
                    continue
                sched = entry["lr_scheduler"]
                # tag every scheduler with the index of the optimizer it belongs to
                if isinstance(sched, dict):
                    schedulers.append(dict(sched, opt_idx=opt_idx))
                else:
                    schedulers.append({"scheduler": sched, "opt_idx": opt_idx})
            frequencies = [
                entry["frequency"] for entry in conf
                if entry.get("frequency", None) is not None
            ]
            # frequencies are all-or-nothing across optimizers
            if frequencies and len(frequencies) != len(optimizers):
                raise ValueError(
                    "A frequency must be given to each optimizer.")
        elif is_sequence and all(isinstance(candidate, Optimizer) for candidate in conf):
            # plain list/tuple of optimizers
            optimizers = list(conf)
        else:
            raise MisconfigurationException(
                "Unknown configuration for model optimizers."
                " Output from `model.configure_optimizers()` should either be:\n"
                " * `torch.optim.Optimizer`\n"
                " * [`torch.optim.Optimizer`]\n"
                " * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
            )

        manual = not pl_module.automatic_optimization
        schedulers = self.configure_schedulers(schedulers, monitor, manual)
        _validate_scheduler_optimizer(optimizers, schedulers)

        return optimizers, schedulers, frequencies
Example #14
0
    def configure_schedulers(
            self, schedulers: list, monitor: Optional[str],
            is_manual_optimization: bool) -> List[Dict[str, Any]]:
        """Convert each scheduler into dict structure with relevant information.

        Each entry of ``schedulers`` may be a scheduler dict (which must contain
        ``"scheduler"``), a ``ReduceLROnPlateau`` instance, or an ``_LRScheduler`` instance.
        Each result is ``_get_default_scheduler_config()`` overlaid with the entry's keys.
        Input dicts are never modified; new dicts are always built.

        Args:
            schedulers: schedulers / scheduler dicts collected from ``configure_optimizers``.
            monitor: top-level ``"monitor"`` from ``configure_optimizers``. In automatic
                optimization it is required for a bare ``ReduceLROnPlateau``.
            is_manual_optimization: when True, the keys ``interval``, ``frequency``,
                ``reduce_on_plateau``, ``monitor`` and ``strict`` are dropped with a warning.

        Returns:
            One scheduler config dict per input entry, in input order.

        Raises:
            MisconfigurationException: a scheduler dict without ``"scheduler"``, an invalid
                ``"interval"``, or a ``ReduceLROnPlateau`` without a monitor.
            ValueError: an entry that is neither a dict nor a recognised scheduler
                (automatic optimization only).
        """
        lr_schedulers = []
        default_config = _get_default_scheduler_config()
        for scheduler in schedulers:
            if is_manual_optimization:
                if isinstance(scheduler, dict):
                    invalid_keys = {
                        "interval", "frequency", "reduce_on_plateau",
                        "monitor", "strict"
                    }
                    keys_to_warn = [
                        k for k in scheduler.keys() if k in invalid_keys
                    ]

                    if keys_to_warn:
                        rank_zero_warn(
                            f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                            " You need to call `lr_scheduler.step()` manually in manual optimization.",
                            RuntimeWarning,
                        )

                    # same requirement as in automatic optimization: a config dict without a
                    # scheduler cannot be used for anything
                    if "scheduler" not in scheduler:
                        raise MisconfigurationException(
                            'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                        )

                    # build a filtered copy; the caller's dict is left untouched
                    scheduler = {
                        key: scheduler[key]
                        for key in scheduler if key not in invalid_keys
                    }
                    lr_schedulers.append({**default_config, **scheduler})
                else:
                    lr_schedulers.append({
                        **default_config, "scheduler":
                        scheduler
                    })
            else:
                if isinstance(scheduler, dict):
                    # shallow copy so the "reduce_on_plateau" write below does not leak into
                    # the dict returned by the user's `configure_optimizers`
                    scheduler = dict(scheduler)
                    # check provided keys
                    extra_keys = [
                        k for k in scheduler.keys()
                        if k not in default_config.keys()
                    ]
                    if extra_keys:
                        rank_zero_warn(
                            f"Found unsupported keys in the lr scheduler dict: {extra_keys}",
                            RuntimeWarning)
                    if "scheduler" not in scheduler:
                        raise MisconfigurationException(
                            'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                        )
                    if "interval" in scheduler and scheduler[
                            "interval"] not in ("step", "epoch"):
                        raise MisconfigurationException(
                            'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                            f' but is "{scheduler["interval"]}"')
                    scheduler["reduce_on_plateau"] = isinstance(
                        scheduler["scheduler"],
                        optim.lr_scheduler.ReduceLROnPlateau)
                    if scheduler["reduce_on_plateau"] and scheduler.get(
                            "monitor", None) is None:
                        raise MisconfigurationException(
                            "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                            ' For example: {"optimizer": optimizer, "lr_scheduler":'
                            ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                        )
                    lr_schedulers.append({**default_config, **scheduler})
                elif isinstance(scheduler,
                                optim.lr_scheduler.ReduceLROnPlateau):
                    if monitor is None:
                        raise MisconfigurationException(
                            "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                            " scheduler is used. For example:"
                            ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                        )
                    lr_schedulers.append({
                        **default_config, "scheduler": scheduler,
                        "reduce_on_plateau": True,
                        "monitor": monitor
                    })
                elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                    lr_schedulers.append({
                        **default_config, "scheduler":
                        scheduler
                    })
                else:
                    raise ValueError(
                        f'The provided lr scheduler "{scheduler}" is invalid')
        return lr_schedulers
Example #15
0
 def on_gpu(self, val: bool) -> None:
     """Deprecated setter: a truthy ``val`` sets the connector's device type to GPU."""
     rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
     if not val:
         # a falsy value leaves the current device type as it is
         return
     self.accelerator_connector._device_type = DeviceType.GPU
    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        """
        Run the training step for one batch on whichever backend is active.

        :param batch: the current batch; for single-GPU and TPU runs a shallow copy of it
            is transferred to the device before the step.
        :param batch_idx: index of the batch, always forwarded to ``training_step``.
        :param opt_idx: optimizer index, forwarded only when more than one optimizer exists.
        :param hiddens: hidden state, forwarded only when ``truncated_bptt_steps`` is set.
        :return: the step output, post-processed by ``training_step_end`` and then by the
            deprecated ``training_end`` when either is overridden.
        :raises ValueError: several optimizers exist but ``training_step`` has no
            ``optimizer_idx`` argument.
        """
        # ---------------
        # FORWARD
        # ---------------
        step_args = [batch, batch_idx]

        # opt_idx is only required in training_step when there are several optimizers
        n_optimizers = len(self.optimizers)
        if n_optimizers > 1:
            if not self.has_arg('training_step', 'optimizer_idx'):
                raise ValueError(
                    f'Your LightningModule defines {n_optimizers} optimizers but '
                    f'training_step is missing the "optimizer_idx" argument.')
            step_args.append(opt_idx)

        # pass hiddens if using tbptt
        if self.truncated_bptt_steps is not None:
            step_args.append(hiddens)

        if self.use_ddp or self.use_ddp2 or self.use_dp:
            # distributed forward: call the model object itself
            output = self.model(*step_args)
        else:
            if self.single_gpu:
                # move a copy of the batch to the first configured GPU (0 by default)
                device_ids = self.data_parallel_device_ids
                gpu_id = device_ids[0] if isinstance(device_ids, list) else 0
                step_args[0] = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
            elif self.use_tpu:
                step_args[0] = self.transfer_batch_to_tpu(copy.copy(batch))
            # single GPU, TPU and CPU all call training_step directly
            output = self.model.training_step(*step_args)

        # any mode may define training_step_end (e.g. to combine the dp outputs)
        if self.is_overriden('training_step_end'):
            module = self.get_model()
            with self.profiler.profile('training_step_end'):
                output = module.training_step_end(output)

        # deprecated hook; TODO: remove in 1.0.0
        if self.is_overriden('training_end'):
            module = self.get_model()
            with self.profiler.profile('training_end'):
                output = module.training_end(output)

            rank_zero_warn(
                '`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
                ' Use training_epoch_end instead', DeprecationWarning)

        return output
Example #17
0
 def use_ddp2(self) -> bool:
     """Deprecated: report whether the connector's distributed type is ``DDP2``."""
     rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
     distrib_type = self.accelerator_connector._distrib_type
     return distrib_type == DistributedType.DDP2
    def transform(
        self,
        sample: Any,
    ) -> Union[Classification, Classifications, Dict[str, Any]]:
        """Convert raw model predictions into FiftyOne classification label(s).

        Args:
            sample: the raw predictions, or a dict holding them under ``DataKeys.PREDS``.
                When ``return_filepath`` is set, it must be a dict that also holds
                ``DataKeys.METADATA["filepath"]``.

        Returns:
            In multi-label mode, a ``fol.Classifications`` with one entry per class whose
            sigmoid score exceeds ``self.threshold``. Otherwise, a ``fol.Classification`` for
            the argmax class, or ``None`` when ``self.threshold`` is set and the top softmax
            probability is below it. Class labels come from ``self._labels`` when it is given,
            and are the stringified class indices otherwise (with a warning). With
            ``return_filepath`` the result is wrapped as
            ``{"filepath": ..., "predictions": ...}``.
        """
        pred = sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample
        pred = torch.tensor(pred)

        # raw scores are attached only when explicitly requested
        logits = pred.tolist() if self.store_logits else None

        if self.multi_label:
            # compute the sigmoid once and reuse it for both the class mask and the confidences
            scores = pred.sigmoid()
            mask = (scores > self.threshold).int().tolist()
            classes = [index for index, value in enumerate(mask) if value == 1]
            probabilities = scores.tolist()
        else:
            classes = torch.argmax(pred, -1).tolist()
            probabilities = torch.softmax(pred, -1).tolist()

        # a single label mapping replaces the previously duplicated labels / no-labels branches
        if self._labels is not None:
            labels = self._labels

            def to_label(idx):
                return labels[idx]
        else:
            rank_zero_warn("No labels were provided, int targets will be used as label strings.", category=UserWarning)
            to_label = str

        if self.multi_label:
            classifications = [
                fol.Classification(
                    label=to_label(idx),
                    confidence=probabilities[idx],
                )
                for idx in classes
            ]
            fo_predictions = fol.Classifications(
                classifications=classifications,
                logits=logits,
            )
        else:
            confidence = max(probabilities)
            if self.threshold is not None and confidence < self.threshold:
                fo_predictions = None
            else:
                fo_predictions = fol.Classification(
                    label=to_label(classes),
                    confidence=confidence,
                    logits=logits,
                )

        if self.return_filepath:
            filepath = sample[DataKeys.METADATA]["filepath"]
            return {"filepath": filepath, "predictions": fo_predictions}
        return fo_predictions
Example #19
0
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_pct: float = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        val_percent_check: float = 1.0,
        test_percent_check: float = 1.0,
        val_check_interval: float = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        precision: int = 32,
        print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_level: str = 'O1',  # backward compatible, todo: remove in v1.0.0
        num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
        use_amp: Optional[bool] = None,  # backward compatible, todo: remove in v0.9.0
        show_progress_bar: Optional[bool] = None,  # backward compatible, todo: remove in v0.9.0
    ):
        r"""

        Customize every aspect of training via flags

        Args:
            logger: Logger (or iterable collection of loggers) for experiment tracking.

            checkpoint_callback: Callback for checkpointing.

            early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): Early-stopping
                callback or flag; it is handed to ``configure_early_stopping``.

            callbacks: Add a list of callbacks.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Falls back to the current working directory when ``None``.

            gradient_clip_val: 0 means don't clip.

            gradient_clip:
                .. warning:: .. deprecated:: 0.7.0

                    Use `gradient_clip_val` instead. Will remove 0.9.0.

            process_position: orders the progress bar when running multiple models on same machine.

            num_nodes: number of GPU nodes for distributed training.

            num_processes: number of processes to use. Only honored with
                ``distributed_backend="ddp_cpu"``; for any other backend a value other than 1
                is ignored with a warning.

            nb_gpu_nodes:
                .. warning:: .. deprecated:: 0.7.0

                    Use `num_nodes` instead. Will remove 0.9.0.

            gpus: Which GPUs to train on.

            auto_select_gpus:

                If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            num_tpu_cores: How many TPU cores to train on (1 or 8)
                .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            show_progress_bar:
                .. warning:: .. deprecated:: 0.7.2

                        Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

            overfit_pct: How much of training-, validation-, and test dataset to check.

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            check_val_every_n_epoch: Check val every n train epochs.

            fast_dev_run: runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).
                Also sets ``num_sanity_val_steps`` to 0 and ``max_epochs`` to 1.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            max_epochs: Stop training once this number of epochs is reached.

            max_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `max_epochs` instead. Will remove 0.9.0.

            min_epochs: Force training for at least this many epochs.

            min_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `min_epochs` instead. Will remove 0.9.0.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least this number of steps. Disabled by default (None).

            train_percent_check: How much of training dataset to check.

            val_percent_check: How much of validation dataset to check.

            test_percent_check: How much of test dataset to check.

            val_check_interval: How often within one training epoch to check the validation set

            log_save_interval: Writes logs to disk this often

            row_log_interval: How often to add logging rows (does not write to disk)

            add_row_log_interval:
                .. warning:: .. deprecated:: 0.7.0

                    Use `row_log_interval` instead. Will remove 0.9.0.

            distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

            use_amp:
                .. warning:: .. deprecated:: 0.7.0

                    Use `precision` instead. Will remove 0.9.0.

            precision: Full precision (32), half precision (16).

            print_nan_grads:
                .. warning:: .. deprecated:: 0.7.2

                    Has no effect. When detected, NaN grads will be printed automatically.
                    Will remove 0.9.0.

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                    for checkpoints only. Use this if for whatever reason you need the checkpoints
                    stored in a different place than the logs written in `default_root_dir`.

            amp_level: The optimization level to use (O1, O2, etc...).

            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

            truncated_bptt_steps: Truncated backpropagation through time: performs backprop every k steps
                of a much longer sequence.

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
                This can be a URL.

            profiler: To profile individual steps during training and assist in identifying bottlenecks.
                ``True`` uses a ``SimpleProfiler``; ``None``/``False`` uses a ``PassThroughProfiler``.

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

            auto_lr_find: If set to True, will `initially` run a learning rate finder,
                trying to optimize initial learning for faster convergence. Sets learning
                rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key, set a string instead of True with the key name.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement.
                If not specified, this will be toggled automatically when ddp is used.

            benchmark: If true enables cudnn.benchmark.

            deterministic: If true enables cudnn.deterministic. Also sets the
                ``HOROVOD_FUSION_THRESHOLD`` environment variable to ``0``.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
        """
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # Init callbacks
        self.prepare_data_per_node = prepare_data_per_node
        self.callbacks = callbacks or []
        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
        # float('inf') turns the 'inf' string into the infinity norm
        self.track_grad_norm = float(track_grad_norm)

        # True only when GPUs were requested and CUDA is actually available
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

        # tpu config
        if num_tpu_cores is not None:
            rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                           " and this argument will be removed in v0.9.0", DeprecationWarning)

        # the deprecated `num_tpu_cores` is used only when `tpu_cores` was not given
        if tpu_cores is None:
            tpu_cores = num_tpu_cores
        self.on_tpu = tpu_cores is not None
        self.tpu_cores = tpu_cores
        # NOTE(review): `assert` is stripped under `python -O`, so this validation disappears there
        assert self.tpu_cores in (1, 8, None) or (
            isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1
        ), '`tpu_cores` can only be 1, 8 or [<1-8>]'

        # NOTE(review): the assertion above also admits a one-element tuple/set, but only a list
        # yields a tpu_id here; confirm tuple/set inputs are handled correctly downstream
        self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        # Backward compatibility, TODO: remove in v0.9.0
        if print_nan_grads:
            rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                           " NaN grads will be printed automatically when detected.", DeprecationWarning)

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            # fast_dev_run overrides the sanity-check and epoch settings given above
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info('Running in fast_dev_run mode: will run a full train,'
                           ' val and test loop using a single batch')

        # set default save path if user didn't provide one
        self.default_root_dir = default_root_dir

        if self.default_root_dir is None:
            self.default_root_dir = os.getcwd()

        # training bookeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.init_tpu()

        # init flags for SLURM+ddp to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        # backward compatibility
        if show_progress_bar is not None:
            self.show_progress_bar = show_progress_bar

        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.overfit_pct = overfit_pct
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        # native AMP is used only if this torch build exposes torch.cuda.amp.autocast
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.init_amp(use_amp)

        # truthy when either the COLAB_GPU or KAGGLE_URL_BASE environment variable is set
        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()
Example #20
0
 def accelerator_backend(self) -> Accelerator:
     """Deprecated alias for ``Trainer.accelerator``; warns on every access."""
     message = (
         "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`"
         " since 1.2 and will be removed in v1.4."
     )
     rank_zero_warn(message, DeprecationWarning)
     return self.accelerator
Example #21
0
 def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
     """Warn when ``dirpath`` is an existing directory that already has entries.

     No check is made at all when ``save_top_k`` is 0.
     """
     if self.save_top_k == 0:
         return
     if not self._fs.isdir(dirpath):
         return
     if len(self._fs.ls(dirpath)) > 0:
         rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Example #22
0
 def get_model(self) -> LightningModule:
     """Deprecated accessor; warns, then returns ``self.lightning_module``."""
     message = (
         "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
         " and will be removed in v1.4."
     )
     rank_zero_warn(message, DeprecationWarning)
     return self.lightning_module
Example #23
0
def lr_find(
        trainer,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = 'exponential',
        early_stop_threshold: Optional[float] = 4.0,
        datamodule: Optional[LightningDataModule] = None,
):
    r"""
    `lr_find` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        trainer: The trainer used to run the search. While searching, its
            ``auto_lr_find``, ``callbacks``, ``logger``, ``max_steps``,
            ``optimizers``, ``schedulers`` and ``model`` are overwritten.
            ``__lr_finder_dump_params`` is called before the search and
            ``__lr_finder_restore_params`` after it (also when fitting raises)
            to save and put back the trainer's settings.

        model: Model to do range testing for

        train_dataloader: A PyTorch
            `DataLoader` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        val_dataloaders: Optional PyTorch `DataLoader` (or list of them) with
            validation samples; forwarded to ``trainer.fit``.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional `LightningDataModule` which holds the training
            and validation dataloader(s). Note that the `train_dataloader` and
            `val_dataloaders` parameters cannot be used at the same time as
            this parameter, or a `MisconfigurationException` will be raised.

    Returns:
        An ``_LRFinder`` whose ``results`` hold the tested learning rates
        (``'lr'``) and the corresponding losses (``'loss'``), or ``None`` when
        ``trainer.fast_dev_run`` is enabled and the search is skipped.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)

    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return None

    # os.path.join always yields a str here, so no further str() conversion is needed
    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Everything below mutates the trainer; the finally-clause guarantees its
    # settings are put back even if checkpointing or fitting raises.
    try:
        # Prevent going into infinite loop
        trainer.auto_lr_find = False

        # Initialize lr finder object (stores results)
        lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

        # Use special lr logger callback
        trainer.callbacks = [_LRCallback(num_training,
                                         early_stop_threshold,
                                         progress_bar_refresh_rate=1)]

        # No logging
        trainer.logger = DummyLogger()

        # Max step set to number of iterations
        trainer.max_steps = num_training

        # Disable standard progress bar for fit
        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.disable()

        # Required for saving the model
        trainer.optimizers, trainer.schedulers = [], []
        trainer.model = model

        # Dump model checkpoint so the original weights can be restored after the search
        trainer.save_checkpoint(save_path)

        # Configure optimizer and scheduler
        model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

        # Fit, lr & loss logged in callback
        trainer.fit(model,
                    train_dataloader=train_dataloader,
                    val_dataloaders=val_dataloaders,
                    datamodule=datamodule)

        # Prompt if we stopped early
        if trainer.global_step != num_training:
            log.info('LR finder stopped early due to diverging loss.')

        # Transfer results from callback to lr finder object
        lr_finder.results.update({'lr': trainer.callbacks[0].lrs,
                                  'loss': trainer.callbacks[0].losses})
        lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

        # Reset model state
        if trainer.is_global_zero:
            trainer.checkpoint_connector.restore(save_path, on_gpu=trainer.on_gpu)
            fs = get_filesystem(save_path)
            if fs.exists(save_path):
                fs.rm(save_path)
    finally:
        # Finish by resetting variables so trainer is ready to fit model
        __lr_finder_restore_params(trainer, model)
        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.enable()

    return lr_finder
Example #24
0
 def use_tpu(self) -> bool:
     """Deprecated internal flag; reads ``self.on_tpu``.

     A ``DeprecationWarning`` is emitted through ``rank_zero_warn`` before the read.
     """
     message = "Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4."
     rank_zero_warn(message, DeprecationWarning)
     return self.on_tpu
Example #25
0
    def _evaluate(
        self,
        model: LightningModule,
        dataloaders: List[DataLoader],
        max_batches: Union[int, List[int]],
        test_mode: bool = False
    ):
        """Run evaluation code.

        Args:
            model: The model to evaluate. It may be wrapped in
                ``LightningDistributedDataParallel`` / ``LightningDataParallel``; in that case the
                epoch-end hooks are looked up on the wrapped ``module``.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: An integer or list of integers with length of the number of dataloaders. Each
                entry is the number of batches to process in the corresponding dataloader.
            test_mode: If ``True`` run the test hooks (``test_step_end``, ``test_end`` /
                ``test_epoch_end``), otherwise the validation hooks.

        Returns:
            The value returned by the overridden epoch-end hook, or ``{}`` if none is overridden.
            With a single dataloader the hook receives the flat list of step outputs; with several
            it receives one list per dataloader.

        On exit (normal or via an exception) the passed-in model is put back into train mode and
        gradients are re-enabled.
        """
        # Keep the object exactly as it was passed in (possibly a DP/DDP wrapper): ``model`` is
        # rebound to the inner module below, but train mode must be restored on the same object
        # that was switched to eval mode here.
        passed_model = model

        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        try:
            # bookkeeping
            outputs = []

            # convert max_batches to list
            if isinstance(max_batches, int):
                max_batches = [max_batches] * len(dataloaders)

            # run validation
            for dataloader_idx, dataloader in enumerate(dataloaders):
                dl_outputs = []

                # on TPU we have to wrap it under the ParallelLoader
                if self.use_tpu:
                    device = xm.xla_device(self.tpu_id)
                    dataloader = xla_pl.ParallelLoader(dataloader, [device])
                    dataloader = dataloader.per_device_loader(device)

                # each dataloader has a max num batches
                dl_max_batches = max_batches[dataloader_idx]

                for batch_idx, batch in enumerate(dataloader):
                    if batch is None:
                        continue

                    # stop short when on fast_dev_run (sets max_batch=1)
                    if batch_idx >= dl_max_batches:
                        break

                    # callbacks
                    if test_mode:
                        self.on_test_batch_start()
                    else:
                        self.on_validation_batch_start()

                    # -----------------
                    # RUN EVALUATION STEP
                    # -----------------
                    if self.use_amp and NATIVE_AMP_AVALAIBLE:
                        with torch.cuda.amp.autocast():
                            output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
                    else:
                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                    # on dp / ddp2 might still want to do something with the batch parts
                    if test_mode:
                        if self.is_overridden('test_step_end'):
                            model_ref = self.get_model()
                            with self.profiler.profile('test_step_end'):
                                output = model_ref.test_step_end(output)
                        self.on_test_batch_end()
                    else:
                        if self.is_overridden('validation_step_end'):
                            model_ref = self.get_model()
                            with self.profiler.profile('validation_step_end'):
                                output = model_ref.validation_step_end(output)
                        self.on_validation_batch_end()

                    # track outputs for collation
                    dl_outputs.append(output)

                outputs.append(dl_outputs)

            eval_results = {}

            # with a single dataloader don't pass an array
            if len(dataloaders) == 1:
                outputs = outputs[0]

            # give model a chance to do something with the outputs (and method defined)
            if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
                model = model.module

            if test_mode:
                if self.is_overridden('test_end', model=model):
                    # TODO: remove in v1.0.0
                    eval_results = model.test_end(outputs)
                    rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.'
                                   ' Use `test_epoch_end` instead.', DeprecationWarning)

                elif self.is_overridden('test_epoch_end', model=model):
                    eval_results = model.test_epoch_end(outputs)

            else:
                if self.is_overridden('validation_end', model=model):
                    # TODO: remove in v1.0.0
                    eval_results = model.validation_end(outputs)
                    rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.'
                                   ' Use `validation_epoch_end` instead.', DeprecationWarning)

                elif self.is_overridden('validation_epoch_end', model=model):
                    eval_results = model.validation_epoch_end(outputs)

            # aggregate ddp stats across
            has_content = eval_results is not None and len(eval_results) > 0
            if has_content and (self.use_ddp or self.use_ddp2):
                self.reduce_eval_ddp(eval_results)

            return eval_results
        finally:
            # enable train mode again on the object that was put into eval mode; for a DP/DDP
            # wrapper, Module.train() also propagates to the wrapped module
            passed_model.train()

            # re-enable gradients for the training that follows (also after a failure, so the
            # caller does not silently continue without autograd)
            torch.set_grad_enabled(True)
Example #26
0
 def use_tpu(self, val: bool) -> None:
     """Deprecated internal setter; stores ``val`` in ``self.on_tpu``.

     A ``DeprecationWarning`` is emitted through ``rank_zero_warn`` before the write.
     """
     message = "Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4."
     rank_zero_warn(message, DeprecationWarning)
     self.on_tpu = val
 def best(self):
     """Deprecated alias of ``best_model_score`` (renamed in v0.8.0).

     Emits a ``DeprecationWarning`` through ``rank_zero_warn``, then returns
     ``self.best_model_score``.
     """
     renamed_notice = ("Attribute `best` has been renamed to `best_model_score` since v0.8.0"
                       " and will be removed in v0.10.0")
     rank_zero_warn(renamed_notice, DeprecationWarning)
     score = self.best_model_score
     return score
Example #28
0
 def on_gpu(self) -> bool:
     """Deprecated internal flag: whether the accelerator connector's device type is GPU.

     A ``DeprecationWarning`` is emitted through ``rank_zero_warn`` before the check.
     """
     notice = "Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4."
     rank_zero_warn(notice, DeprecationWarning)
     device_type = self.accelerator_connector._device_type
     return device_type == DeviceType.GPU
    def on_validation_end(self, trainer, pl_module):
        """Write checkpoint(s) after validation; only global rank 0 does anything.

        Nothing is saved when ``save_top_k == 0`` or when fewer than ``period`` epochs
        have passed since the last check. Otherwise ``<prefix>last.ckpt`` is written
        first if ``save_last`` is set. Then:

        * ``save_top_k != -1``: the monitored metric is read from
          ``trainer.callback_metrics``; if it qualifies per ``check_monitor_top_k`` the
          candidate is handed to ``_do_check_save``.
        * ``save_top_k == -1``: the checkpoint is saved unconditionally.

        Args:
            trainer: The trainer; ``global_rank``, ``callback_metrics`` and
                ``current_epoch`` are read from it.
            pl_module: The module being trained, forwarded to the save helpers.
        """
        # only run on main process
        if trainer.global_rank != 0:
            return

        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        # support structured results
        if metrics.get('checkpoint_on') is not None:
            self.monitor = 'checkpoint_on'

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epoch_last_check is not None and (
                epoch - self.epoch_last_check) < self.period:
            # skipping in this term
            return

        self.epoch_last_check = epoch

        if self.save_last:
            filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
            self._save_model(filepath, trainer, pl_module)

        # pick a file name that does not exist yet; on a collision retry with ver=0, 1, 2, ...
        filepath = self.format_checkpoint_name(epoch, metrics)
        version_cnt = 0
        while os.path.isfile(filepath):
            filepath = self.format_checkpoint_name(epoch,
                                                   metrics,
                                                   ver=version_cnt)
            # this epoch called before
            version_cnt += 1

        if self.save_top_k != -1:
            current = metrics.get(self.monitor)

            # A missing metric gets exactly one (accurate) warning; previously it also
            # triggered the "must be a torch.Tensor" warning below.
            if current is None:
                rank_zero_warn(
                    f'Can save best model only with {self.monitor} available, skipping.',
                    RuntimeWarning)
                return

            if not isinstance(current, torch.Tensor):
                # The value is converted and may still be saved, so the warning must not
                # claim the checkpoint was skipped.
                rank_zero_warn(
                    f'The metric you returned {current} must be a `torch.Tensor` instance,'
                    ' converting it with `torch.tensor`.'
                    f' HINT: what is the value of {self.monitor} in validation_epoch_end()?',
                    RuntimeWarning)
                current = torch.tensor(current)

            if self.check_monitor_top_k(current):
                self._do_check_save(filepath, current, epoch, trainer,
                                    pl_module)
            elif self.verbose > 0:
                log.info(
                    f'\nEpoch {epoch:05d}: {self.monitor}  was not in top {self.save_top_k}'
                )
            return

        # save_top_k == -1: keep every checkpoint
        if self.verbose > 0:
            log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

        # internal invariant: the early return above already restricts us to rank 0
        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
        self._save_model(filepath, trainer, pl_module)
Example #30
0
    def _reset_eval_dataloader(
            self, model: LightningModule,
            mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders).

            ``dataloaders`` are the model's non-``None`` loaders after ``auto_add_sampler``.

            ``num_batches`` depends on the loaders and ``{mode}_percent_check``:

            * all loaders define ``__len__``: the sum of their lengths scaled by
              ``{mode}_percent_check``;
            * some loader has no ``__len__`` and the check is ``1.0``: ``float('inf')``;
            * some loader has no ``__len__`` and the check is ``0.0``: ``0``;
            * no loaders at all: ``0``.

        Raises:
            MisconfigurationException: if some dataloader has no ``__len__`` and
                ``{mode}_percent_check`` is neither ``0.0`` nor ``1.0``.
        """
        dataloaders = self.request_dataloader(
            getattr(model, f'{mode}_dataloader'))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # shuffling in val and test set is bad practice
        for loader in dataloaders:
            if mode in ('val', 'test') and hasattr(
                    loader, 'sampler') and isinstance(loader.sampler,
                                                      RandomSampler):
                rank_zero_warn(
                    f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
                    ' this off for validation and test dataloaders.')

        if any(dl is None for dl in dataloaders):
            rank_zero_warn(
                "One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, train=False) for dl in dataloaders
            if dl is not None
        ]

        num_batches = 0

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                self._worker_check(dataloader, f'{mode} dataloader {i}')
                if not _has_len(dataloader):
                    num_batches = float('inf')

            percent_check = getattr(self, f'{mode}_percent_check')

            if num_batches != float('inf'):
                self._percent_range_check(f'{mode}_percent_check')

                num_batches = sum(
                    len(dataloader) for dataloader in dataloaders)
                num_batches = int(num_batches * percent_check)
            elif percent_check == 0.0:
                # An infinite loader with a 0.0 percent check means "skip evaluation";
                # leaving num_batches at inf would instead evaluate forever.
                num_batches = 0
            elif percent_check != 1.0:
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
                    f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.'
                )
        return num_batches, dataloaders