Example #1
    def __init_monitor_mode(self):
        # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
        if self.mode not in self.mode_dict and self.mode != 'auto':
            if self.verbose > 0:
                rank_zero_warn(
                    f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.',
                    RuntimeWarning,
                )
            self.mode = 'auto'

        if self.mode == 'auto':
            rank_zero_warn(
                "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
                " Default value for mode with be 'min' in v1.3.",
                DeprecationWarning)

            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.mode = 'max'
            else:
                self.mode = 'min'

            if self.verbose > 0:
                rank_zero_info(
                    f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.'
                )
Example #2
    def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str,
                                                                       Any]):
        if self.monitor is None or self.save_top_k == 0:
            return

        current = monitor_candidates.get(self.monitor)
        epoch = monitor_candidates.get("epoch")
        step = monitor_candidates.get("step")

        # when `val_loss` is being logged and no ModelCheckpoint is provided,
        # `val_loss` will be selected as the monitor and needs to be reduced
        # to prevent the processes from diverging
        # TODO: Move this logic to logger_connector. This also needs to be fixed for any
        # other monitored logged value which isn't produced from a Metric.
        if self.monitor == "val_loss":
            current = trainer.training_type_plugin.reduce(current,
                                                          reduce_op="mean")

        if self.check_monitor_top_k(current):
            self._update_best_and_save(current, epoch, step, trainer,
                                       monitor_candidates)
        elif self.verbose:
            rank_zero_info(
                f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
            )
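The `val_loss` branch above averages the monitored value across processes so every rank ranks checkpoints against the same number. A rough standalone equivalent using `torch.distributed` directly (a sketch, not the plugin's implementation; it assumes a process group may or may not be initialized):

    import torch
    import torch.distributed as dist

    def reduce_mean(value: torch.Tensor) -> torch.Tensor:
        # Average a scalar tensor across all ranks; no-op outside distributed runs.
        if not dist.is_available() or not dist.is_initialized():
            return value
        value = value.clone()
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        return value / dist.get_world_size()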
Example #3
    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """
        Called before model state is restored. Explicitly handles old model
        states so we can resume training from D2Go checkpoints transparently.

        Args:
            checkpointed_state: The raw checkpoint state as returned by torch.load
                or equivalent.
        """
        # If this is a non-Lightning checkpoint, we need to convert it.
        if not _is_lightning_checkpoint(checkpointed_state) and not _is_d2go_checkpoint(
            checkpointed_state
        ):
            raise ValueError(
                f"Invalid checkpoint state with keys: {checkpointed_state.keys()}"
            )
        if not _is_lightning_checkpoint(checkpointed_state):
            _convert_to_lightning(checkpointed_state)

        if self.ema_state:
            if "model_ema" not in checkpointed_state:
                rank_zero_info(
                    "EMA is enabled but EMA state is not found in given checkpoint"
                )
            else:
                self.ema_state = EMAState()
                self.ema_state.load_state_dict(checkpointed_state["model_ema"])
                if not self.ema_state.device:
                    # EMA state device not given, move to module device
                    self.ema_state.to(self.device)
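`_is_lightning_checkpoint` and `_convert_to_lightning` are not shown here. As a hedged illustration only, a Lightning checkpoint can usually be recognized by the keys the Trainer writes, roughly:

    from typing import Any, Dict

    def _is_lightning_checkpoint_sketch(checkpoint: Dict[str, Any]) -> bool:
        # Heuristic only: Lightning stores the model weights under "state_dict"
        # and records the library version next to them.
        return "state_dict" in checkpoint and "pytorch-lightning_version" in checkpoint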
Example #4
    def _initialize_model_specific_parameters(self):
        task_specific_params = self.model.config.task_specific_params

        if task_specific_params:
            pars = task_specific_params.get(self.task, {})
            rank_zero_info(f"Overriding model paramameters for {self.task} as defined within the model:\n {pars}")
            self.model.config.update(pars)
Example #5
 def on_validation_end(self, trainer: pl.Trainer,
                       pl_module: pl.LightningModule):
     rank_zero_info("***** Validation results *****")
     metrics = trainer.callback_metrics
     # Log results
     for key in sorted(metrics):
         rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
Example #6
 def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
     rank_zero_info("***** Validation results *****")
     metrics = trainer.callback_metrics
     # Log results
     for key in sorted(metrics):
         if key not in ["log", "progress_bar"]:
             rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
Example #7
    def resume_start(self) -> None:
        """
        Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

        1. from HPC weights if found
        2. from `resume_from_checkpoint` file if provided
        3. don't restore

        Raises:
            FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist.
        """
        self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
        checkpoint_path = self.resume_checkpoint_path
        if not checkpoint_path:
            return

        # clear cache before restore
        torch.cuda.empty_cache()

        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
        fs = get_filesystem(checkpoint_path)
        if not fs.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

        rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
        self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
Example #8
    def _register_function(self,
                           fn: Callable,
                           name: Optional[str] = None,
                           override: bool = False,
                           metadata: Optional[Dict[str, Any]] = None):
        if not isinstance(fn, FunctionType) and not isinstance(fn, partial):
            raise MisconfigurationException(
                f"You can only register a function, found: {fn}")

        name = name or fn.__name__

        if self._verbose:
            rank_zero_info(
                f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}"
            )

        item = {"fn": fn, "name": name, "metadata": metadata or {}}

        matching_index = self._find_matching_index(item)
        if override and matching_index is not None:
            self.functions[matching_index] = item
        else:
            if matching_index is not None:
                raise MisconfigurationException(
                    f"Function with name: {name} and metadata: {metadata} is already present within {self}."
                    " HINT: Use `override=True`.")
            self.functions.append(item)
Example #9
    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
        """
        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
        All restored states are listed in the return value description of `dump_checkpoint`.
        """
        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
        fs = get_filesystem(checkpoint_path)
        if not fs.exists(checkpoint_path):
            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
            return False

        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

        # acquire the model
        model = self.trainer.get_model()

        # restore model and datamodule state
        self.restore_model_state(model, checkpoint)

        if on_gpu:
            model.cuda(self.trainer.root_gpu)

        # restore training state
        self.restore_training_state(checkpoint)

        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
        return True
Example #10
    def configure_optimizers(self) -> Dict:
        if self.instantiator is None:
            rank_zero_warn(
                "You haven't specified an optimizer or lr scheduler. "
                "Defaulting to AdamW with an lr of 1e-5 and linear warmup for 10% of steps. "
                "To change this, either use Hydra configs or override ``configure_optimizers`` in the Task."
                "For more information: <todo>")
            self._set_default_optimizer_scheduler()
            return super().configure_optimizers()

        self.optimizer = self.instantiator.optimizer(self.model,
                                                     self.optimizer_cfg)
        # compute_warmup needs the datamodule to be available when `self.num_training_steps`
        # is called, which is why this is done here and not in `__init__`
        self.scheduler_cfg.num_training_steps, self.scheduler_cfg.num_warmup_steps = self.compute_warmup(
            num_training_steps=self.scheduler_cfg.num_training_steps,
            num_warmup_steps=self.scheduler_cfg.num_warmup_steps,
        )
        rank_zero_info(
            f"Inferring number of training steps, set to {self.scheduler_cfg.num_training_steps}"
        )
        rank_zero_info(
            f"Inferring number of warmup steps from ratio, set to {self.scheduler_cfg.num_warmup_steps}"
        )
        self.scheduler = self.instantiator.scheduler(self.scheduler_cfg,
                                                     self.optimizer)
        return super().configure_optimizers()
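`compute_warmup` is not shown; the log messages suggest it resolves the warmup from either an absolute step count or a ratio of the total training steps. A hedged sketch of that resolution (names are illustrative):

    from typing import Tuple, Union

    def compute_warmup_sketch(num_training_steps: int,
                              num_warmup_steps: Union[int, float]) -> Tuple[int, int]:
        # A float in (0, 1) is read as a fraction of the total steps,
        # anything else as an absolute number of warmup steps.
        if isinstance(num_warmup_steps, float) and 0.0 < num_warmup_steps < 1.0:
            num_warmup_steps = num_training_steps * num_warmup_steps
        return num_training_steps, int(num_warmup_steps)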
Example #11
    def on_train_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """ Applies model transforms at as specified during training. """
        apply_only_once = []
        current_step = trainer.global_step
        for i, transform in enumerate(self.transforms):
            if (transform.step is not None and transform.step <= current_step) or (
                transform.interval is not None
                and current_step % transform.interval == 0
            ):
                self.prepared.apply(transform.fn)
                rank_zero_info(
                    f"[QAT] {transform.message} at step={trainer.global_step}."
                )
            if transform.step is not None and transform.step <= current_step:
                apply_only_once.append(i)

        if apply_only_once:
            self.transforms = [
                transform
                for i, transform in enumerate(self.transforms)
                if i not in set(apply_only_once)
            ]
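Each entry in `self.transforms` above is expected to carry either a one-shot `step` or a repeating `interval`, plus the function to apply and a log message. A hypothetical record type, for illustration only:

    from dataclasses import dataclass
    from typing import Callable, Optional

    import torch

    @dataclass
    class ModelTransformSketch:
        # Apply `fn` once at `step`, or every `interval` steps; exactly one
        # of the two fields is expected to be set.
        fn: Callable[[torch.nn.Module], None]
        message: str
        step: Optional[int] = None
        interval: Optional[int] = None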
Example #12
 def __validate_init_configuration(self):
     if self.save_top_k is not None and self.save_top_k < -1:
         raise MisconfigurationException(
             f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1'
         )
     if self._every_n_train_steps < 0:
         raise MisconfigurationException(
             f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
         )
     if self._every_n_val_epochs < 0:
         raise MisconfigurationException(
             f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
         )
     if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
         raise MisconfigurationException(
             f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
              f' and every_n_val_epochs={self._every_n_val_epochs}.'
             ' Both cannot be enabled at the same time.')
     if self.monitor is None:
         # None: save last epoch, -1: save all epochs, 0: nothing is saved
         if self.save_top_k not in (None, -1, 0):
             raise MisconfigurationException(
                 f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                 ' configuration. No quantity for top_k to track.')
         if self.save_last:
             rank_zero_warn(
                 'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                 ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
             )
         if self.save_top_k == -1 and self.save_last:
             rank_zero_info(
                 'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
                 ' will duplicate the last checkpoint saved.')
Example #13
    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
        """
        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
        All restored states are listed in the return value description of `dump_checkpoint`.
        """
        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
        fs = get_filesystem(checkpoint_path)
        if not fs.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint at {checkpoint_path} not found. Aborting training."
            )

        checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
            checkpoint_path, map_location=lambda storage, loc: storage)

        model = self.trainer.lightning_module

        if on_gpu:
            model.cuda(self.trainer.root_gpu)

        # restore training state
        self.restore_training_state(checkpoint, load_optimizer_states)

        rank_zero_info(
            f"Restored states from the checkpoint file at {checkpoint_path}")
        return True
Example #14
    def restore_weights(self, model: LightningModule) -> None:
        """
        Attempt to restore a checkpoint (e.g. weights) in this priority:
        1. from HPC weights
        2. from `resume_from_checkpoint` file
        3. don't restore
        """
        # clear cache before restore
        if self.trainer.on_gpu:
            torch.cuda.empty_cache()

        # 1. Attempt to restore states from HPC checkpoint
        dir_path_hpc = str(self.trainer.weights_save_path)
        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
        if max_suffix is not None:
            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
            self.hpc_load(checkpoint_path, self.trainer.on_gpu)
            rank_zero_info(f'restored hpc model from: {checkpoint_path}')

        # 2. Attempt to restore states from `resume_from_checkpoint` file
        elif self.trainer.resume_from_checkpoint is not None:
            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

        # wait for all to catch up
        self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')

        # clear cache after restore
        if self.trainer.on_gpu:
            torch.cuda.empty_cache()
Example #15
    def configure_slurm_ddp(self):
        # extract SLURM flag vars
        # whenever we have the correct number of tasks, we let slurm manage processes
        # otherwise we launch the required number of processes
        if self.use_ddp or self.use_ddp2:
            num_requested_gpus = self.num_gpus * self.num_nodes
            num_slurm_tasks = 0
            try:
                num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
                self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus

                # enable slurm cpu
                if num_requested_gpus == 0:
                    self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes

                # in interactive mode we don't manage tasks
                job_name = os.environ["SLURM_JOB_NAME"]
                if job_name == "bash":
                    self.is_slurm_managing_tasks = False

            except Exception:
                # likely not on slurm, so set the slurm managed flag to false
                self.is_slurm_managing_tasks = False

        # used for tests only, set this flag to simulate slurm managing a task
        try:
            should_fake = int(os.environ["FAKE_SLURM_MANAGING_TASKS"])
            if should_fake:
                self.is_slurm_managing_tasks = True
        except Exception:
            pass

        # notify the user that SLURM is managing tasks
        if self.is_slurm_managing_tasks:
            rank_zero_info("Multi-processing is handled by Slurm.")
Example #16
    def on_epoch_end(self, trainer, pl_module):
        logs = trainer.callback_metrics
        self.epochs_since_last_check += 1

        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            current = logs.get(self.monitor)

            if current is None:
                warnings.warn(
                    f"Can save best module state only with {self.monitor} available,"
                    " skipping.",
                    RuntimeWarning,
                )
            else:
                if isinstance(current, torch.Tensor):
                    current = current.item()
                if self.check_monitor_top(current):
                    self.best_module_state = deepcopy(pl_module.module.state_dict())
                    self.best_module_metric_val = current

                    if self.verbose:
                        rank_zero_info(
                            f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached."
                            f" Module best state updated."
                        )
Example #17
    def _check_and_init_precision(self) -> PrecisionPlugin:
        self._validate_precision_choice()
        if isinstance(self._precision_plugin_flag, PrecisionPlugin):
            return self._precision_plugin_flag

        if isinstance(self.accelerator, IPUAccelerator):
            return IPUPrecisionPlugin(self._precision_flag)  # type: ignore
        if isinstance(self.accelerator, HPUAccelerator):
            return HPUPrecisionPlugin(self._precision_flag)  # type: ignore
        if isinstance(self.accelerator, TPUAccelerator):
            if self._precision_flag == 32:
                return TPUPrecisionPlugin()
            elif self._precision_flag in (16, "bf16"):
                if self._precision_flag == 16:
                    rank_zero_warn(
                        "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                        " is not supported with TPUs. Using `precision='bf16'` instead."
                    )
                return TPUBf16PrecisionPlugin()
        if isinstance(self.strategy, DeepSpeedStrategy):
            return DeepSpeedPrecisionPlugin(
                self._precision_flag,
                self._amp_type_flag,
                self._amp_level_flag  # type: ignore
            )

        if self._precision_flag == 32:
            return PrecisionPlugin()
        if self._precision_flag == 64:
            return DoublePrecisionPlugin()

        if self._precision_flag == 16 and self._accelerator_flag == "cpu":
            rank_zero_warn(
                "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
                " Using `precision='bf16'` instead.")
            self._precision_flag = "bf16"

        if self._precision_flag in (16, "bf16"):
            rank_zero_info(
                f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
                if self._precision_flag == 16
                else "Using bfloat16 Automatic Mixed Precision (AMP)"
            )

            if self._amp_type_flag == AMPType.NATIVE:
                device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

                if isinstance(self.strategy,
                              (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                    return ShardedNativeMixedPrecisionPlugin(
                        self._precision_flag, device)
                if isinstance(self.strategy, DDPFullyShardedStrategy):
                    return FullyShardedNativeMixedPrecisionPlugin(
                        self._precision_flag, device)
                return NativeMixedPrecisionPlugin(self._precision_flag, device)

            if self._amp_type_flag == AMPType.APEX:
                self._amp_level_flag = self._amp_level_flag or "O2"
                return ApexMixedPrecisionPlugin(self._amp_level_flag)

        raise RuntimeError("No precision set")
Example #18
    def on_init_start(self, limit_train_batches, limit_val_batches,
                      limit_test_batches, val_check_interval, overfit_batches,
                      fast_dev_run):

        self.trainer.fast_dev_run = fast_dev_run
        if self.trainer.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.trainer.num_sanity_val_steps = 0
            self.trainer.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                ' val and test loop using a single batch')

        self.trainer.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.trainer.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.trainer.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.trainer.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.trainer.overfit_batches = _determine_batch_limits(
            overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.trainer.overfit_batches)
Example #19
def infer_test_tta(config: Config, test_source: DataSource, transforms,
                   experiment_name):
    experiment_root = get_my_isic2020_experiments_root() / experiment_name
    experiment_index_path = experiment_root / "index.json"
    experiment = json.load(experiment_index_path.open("r"))

    test_loader = DataLoader(
        MelanomaDataset(test_source, train=False, transforms=transforms),
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )

    sample_submission_path = get_my_isic2020_csv_root(
    ) / "sample_submission.csv"
    label = cast(pd.DataFrame, pd.read_csv(
        sample_submission_path))["image_name"].iloc[:len(test_source.df)]

    inferences = []
    model_type = experiment["model_type"]
    for ckpt in experiment["checkpoints"]:
        fold_index = ckpt["fold_index"]
        n_fold = ckpt["n_fold"]
        ckpt_path = experiment_root / ckpt["file"]

        rank_zero_info(
            f"Infer test data using {str(ckpt_path)} - {fold_index} / {n_fold}"
        )
        model = load_from_checkpoint(model_type, ckpt_path)
        inference = Classifier(
            model, tta_epochs=config.tta_epochs).predict(test_loader)
        inferences.append((f"fold-{fold_index}", inference.ravel()))
    result = pd.concat([label, pd.DataFrame(dict(inferences))], axis=1)
    result.to_csv(f"cv_test_{model_type}.csv", index=False)
Example #20
    def _attach_model_callbacks(self) -> None:
        """Attaches the callbacks defined in the model.

        If a callback returned by the model's configure_callback method has the same type as one or several
        callbacks already present in the trainer callbacks list, it will replace them.
        In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
        will be pushed to the end of the list, ensuring they run last.
        """
        model_callbacks = self.trainer.call_hook("configure_callbacks")
        if not model_callbacks:
            return
        model_callback_types = {type(c) for c in model_callbacks}
        trainer_callback_types = {type(c) for c in self.trainer.callbacks}
        override_types = model_callback_types.intersection(
            trainer_callback_types)
        if override_types:
            rank_zero_info(
                "The following callbacks returned in `LightningModule.configure_callbacks` will override"
                " existing callbacks passed to Trainer:"
                f" {', '.join(sorted(t.__name__ for t in override_types))}")
        # remove all callbacks with a type that occurs in model callbacks
        all_callbacks = [
            c for c in self.trainer.callbacks if type(c) not in override_types
        ]
        all_callbacks.extend(model_callbacks)
        all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
        # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
        self.trainer.callbacks = all_callbacks
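`CallbackConnector._reorder_callbacks` is referenced but not shown; per the docstring, it pushes every ModelCheckpoint to the end of the list while preserving the relative order of everything else. A minimal sketch under that assumption:

    from typing import List

    from pytorch_lightning.callbacks import Callback, ModelCheckpoint

    def reorder_callbacks_sketch(callbacks: List[Callback]) -> List[Callback]:
        # Checkpoint callbacks run last; all other callbacks keep their order.
        checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
        others = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
        return others + checkpoints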
Example #21
    def resume_end(self) -> None:
        """Signal the connector that all states have resumed and memory for the checkpoint object can be
        released."""
        assert self.trainer.state.fn is not None
        if self.resume_checkpoint_path:
            if self.trainer.state.fn == TrainerFn.FITTING:
                rank_zero_info(
                    f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}"
                )
            elif self.trainer.state.fn in (TrainerFn.VALIDATING,
                                           TrainerFn.TESTING,
                                           TrainerFn.PREDICTING):
                rank_zero_info(
                    f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}"
                )
        # TODO: remove resume_from_checkpoint_fit_path in v1.7
        if (self.trainer.state.fn == TrainerFn.FITTING
                and self.resume_checkpoint_path
                == self.resume_from_checkpoint_fit_path):
            self.resume_from_checkpoint_fit_path = None
        self.resume_checkpoint_path = None
        self._loaded_checkpoint = {}

        # clear cache after restore
        torch.cuda.empty_cache()

        # wait for all to catch up
        self.trainer.strategy.barrier("CheckpointConnector.resume_end")
Example #22
    def __validate_init_configuration(self) -> None:
        if self.save_top_k < -1:
            raise MisconfigurationException(
                f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1"
            )
        if self._every_n_train_steps < 0:
            raise MisconfigurationException(
                f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
            )
        if self._every_n_epochs < 0:
            raise MisconfigurationException(
                f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0"
            )

        every_n_train_steps_triggered = self._every_n_train_steps >= 1
        every_n_epochs_triggered = self._every_n_epochs >= 1
        train_time_interval_triggered = self._train_time_interval is not None
        if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
            raise MisconfigurationException(
                f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
                f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
                "should be mutually exclusive.")

        if self.monitor is None:
            # -1: save all epochs, 0: nothing is saved, 1: save last epoch
            if self.save_top_k not in (-1, 0, 1):
                raise MisconfigurationException(
                    f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
                    " configuration. No quantity for top_k to track.")

            if self.save_top_k == -1 and self.save_last:
                rank_zero_info(
                    "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
                    " will duplicate the last checkpoint saved.")
Example #23
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if trainer.current_epoch == self.swa_start:
            # move the average model to the requested device.
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if self._swa_lrs is None:
                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            self._swa_scheduler = SWALR(
                optimizer,
                swa_lr=self._swa_lrs,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
            )
            default_scheduler_cfg = _get_default_scheduler_config()
            assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
            default_scheduler_cfg["scheduler"] = self._swa_scheduler

            if trainer.lr_schedulers:
                scheduler_cfg = trainer.lr_schedulers[0]
                if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
                    rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`"
                )
                trainer.lr_schedulers[0] = default_scheduler_cfg
            else:
                trainer.lr_schedulers.append(default_scheduler_cfg)

            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches
Example #24
 def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
     if max_time is None:
         return
     if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
         rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
         return
     timer = Timer(duration=max_time, interval="step")
     self.trainer.callbacks.append(timer)
Example #25
 def print_weight_summary(self):
     string = 'Summary Teacher:\n'
     for j in range(len(self.models)):
         weight_sum = 0
         for param in self.models[j].parameters():
             weight_sum += param[0].sum()
         string += f'   Teacher Level {j}: WeightSum == {weight_sum}\n'
     rank_zero_info(string)
Example #26
    def setup(self):
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('No TPU devices found.')

        #  COLAB_GPU is an env var available by default in Colab environments.
        self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
Example #27
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        save_json(pl_module.metrics, pl_module.metrics_save_path)

        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar", "preds"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
Example #28
    def on_init_start(
        self,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run,
    ):
        if not isinstance(fast_dev_run, (bool, int)):
            raise MisconfigurationException(
                f'fast_dev_run={fast_dev_run} is not a valid configuration.'
                ' It should be either a bool or an int >= 0')

        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
            raise MisconfigurationException(
                f'fast_dev_run={fast_dev_run} is not a'
                ' valid configuration. It should be >= 0.')

        self.trainer.fast_dev_run = fast_dev_run
        fast_dev_run = int(fast_dev_run)

        # set fast_dev_run=True when it is 1, used while logging
        if fast_dev_run == 1:
            self.trainer.fast_dev_run = True

        if fast_dev_run:
            limit_train_batches = fast_dev_run
            limit_val_batches = fast_dev_run
            limit_test_batches = fast_dev_run
            limit_predict_batches = fast_dev_run
            self.trainer.fit_loop.max_steps = fast_dev_run
            self.trainer.num_sanity_val_steps = 0
            self.trainer.fit_loop.max_epochs = 1
            val_check_interval = 1.0
            self.trainer.check_val_every_n_epoch = 1
            self.trainer.logger = DummyLogger()

            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                f' val, test and prediction loop using {fast_dev_run} batch(es).'
            )

        self.trainer.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.trainer.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.trainer.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.trainer.limit_predict_batches = _determine_batch_limits(
            limit_predict_batches, 'limit_predict_batches')
        self.trainer.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.trainer.overfit_batches = _determine_batch_limits(
            overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.trainer.overfit_batches)
Example #29
    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
        current = metrics.get(self.monitor)
        epoch = metrics.get("epoch")
        step = metrics.get("step")

        if self.check_monitor_top_k(current):
            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
        elif self.verbose:
            rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
Example #30
 def _load_backbone(self,
                    repo_or_dir: str,
                    model_name: str,
                    model_path: Optional[str] = None):
     if model_path:
         return BarlowTwins.load_from_checkpoint(model_path).backbone
     rank_zero_info(
         "No model provided, loading torch hub pre-trained model")
     return torch.hub.load(repo_or_dir, model_name)