Example #1
        def lightning_scheduler_dict_to_det(lrs: dict) -> _LRScheduler:
            """
            Wrap a user-defined lr_scheduler and swap its attached optimizer
            for the wrapped version.

            input_dict = {
                'scheduler': None,
                'name': None,  # no custom name
                'interval': 'epoch',  # after epoch is over
                'frequency': 1,  # every epoch/batch
                'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
                'monitor': monitor,  # value to monitor for ReduceLROnPlateau
                'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
            }
            """
            if lrs["reduce_on_plateau"]:
                raise InvalidModelException(
                    "LRScheduler reduce_on_plateau is not supported")
            if lrs["monitor"] is not None:
                raise InvalidModelException(
                    "LRScheduler monitor is not supported")

            step_mode = (LRScheduler.StepMode.STEP_EVERY_EPOCH
                         if lrs["interval"] == "epoch" else
                         LRScheduler.StepMode.STEP_EVERY_BATCH)

            opt_key = cast(Optimizer,
                           getattr(lrs["scheduler"], "optimizer", None))
            wrapped_opt = optimizers_dict.get(opt_key)
            if wrapped_opt is None:
                raise InvalidModelException(
                    "An LRScheduler is returned in `configure_optimizers` without having "
                    "returned the optimizer itself. Please follow PyTorchLightning's documenation"
                    "to make sure you're returning one of the expected values."
                    "- Single optimizer.\n"
                    "- List or Tuple - List of optimizers.\n"
                    "- Two lists - The first list has multiple optimizers, the second a list of"
                    "LRSchedulers (or lr_dict).\n"
                    "- Dictionary, with an ‘optimizer’ key, and (optionally) a ‘lr_scheduler’ key"
                    "whose value is a single LR scheduler or lr_dict.\n"
                    "- Tuple of dictionaries as described, with an optional ‘frequency’ key.\n"
                )

            check.check_isinstance(
                lrs["scheduler"].optimizer,
                Optimizer,
                "A returned LRScheduler from `configure_optimizers` is "
                "missing the optimizer attribute.",
            )

            # switch the user's unwrapped optimizer with the wrapped version.
            lrs["scheduler"].optimizer = wrapped_opt
            return self._pls.context.wrap_lr_scheduler(
                lrs["scheduler"], step_mode, frequency=lrs["frequency"])
Example #2
def check_compatibility(lm: pl.LightningModule) -> None:
    prefix = "Unsupported usage in PLAdapter: "
    unsupported_members = {
        "backward",
        "get_progress_bar_dict",
        "manual_backward",
        "on_fit_end",
        "on_fit_start",
        "on_load_checkpoint",
        "on_pretrain_routine_end",
        "on_pretrain_routine_start",
        "on_save_checkpoint",
        "on_test_batch_end",
        "on_test_batch_start",
        "on_test_epoch_end",
        "on_test_epoch_start",
        "on_train_epoch_end",
        "optimizer_step",
        "optimizer_zero_grad",
        "setup",
        "tbptt_split_batch",
        "teardown",
        "test_dataloader",
        "test_epoch_end",
        "test_step",
        "test_step_end",
        "training_step_end",
        "transfer_batch_to_device",
        "validation_step_end",
    }

    members = inspect.getmembers(lm, predicate=inspect.ismethod)
    overridden_members = {
        name for name, _ in members if is_overridden(name, lm)
    }

    matches = unsupported_members & overridden_members
    if matches:
        raise InvalidModelException(prefix + f"{matches}")

    for member in overridden_members:
        if has_param(getattr(lm, member), "dataloader_idx"):
            raise InvalidModelException(
                prefix
                + f'multiple dataloaders and `dataloader_idx` are not supported in "{member}"'
            )

    if has_param(lm.training_step, "hiddens", 4):
        raise InvalidModelException(prefix + '`hiddens` argument in "training_step"')

    if lm.trainer is not None:
        raise InvalidModelException(prefix + "Lightning Trainer")
Example #3
def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
    if len(args) != 0 or len(kwargs) != 0:
        raise InvalidModelException(
            f"unsupported arguments to LightningModule.log {args} {kwargs}"
        )
    for metric, value in a_dict.items():
        # Forward only plain numeric metrics to the TensorBoard writer;
        # other value types are skipped.
        if isinstance(value, (int, float)):
            writer.add_scalar(metric, value, context.current_train_batch())
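
A usage sketch, assuming this closure's `writer` (a TensorBoard SummaryWriter) and `context` are already bound in the enclosing scope:

# Illustrative only: non-numeric values are skipped, and any extra
# positional/keyword arguments raise InvalidModelException.
lm_log_dict({"train_loss": 0.42, "stage": "warmup"})  # logs train_loss only
lm_log_dict({"train_loss": 0.42}, prog_bar=True)      # raises InvalidModelException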
Example #4
        def lightning_scheduler_dict_to_det(lrs: dict) -> _LRScheduler:
            """
            input_dict = {
                'scheduler': None,
                'name': None,  # no custom name
                'interval': 'epoch',  # after epoch is over
                'frequency': 1,  # every epoch/batch
                'reduce_on_plateau': False,  # most often not ReduceLROnPlateau scheduler
                'monitor': monitor,  # value to monitor for ReduceLROnPlateau
                'strict': True,  # enforce that the monitor exists for ReduceLROnPlateau
            }
            """
            if lrs["reduce_on_plateau"]:
                raise InvalidModelException("LRScheduler reduce_on_plateaue is not supported")
            if lrs["monitor"] is not None:
                raise InvalidModelException("LRScheduler monitor is not supported")

            step_mode = (
                LRScheduler.StepMode.STEP_EVERY_EPOCH
                if lrs["interval"] == "epoch"
                else LRScheduler.StepMode.STEP_EVERY_BATCH
            )
            return context.wrap_lr_scheduler(lrs["scheduler"], step_mode)
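
A usage sketch, assuming `context` is the trial context visible to this helper; `model` and the scheduler settings are placeholders:

# Illustrative only: `model` stands in for the user's LightningModule.
import torch

opt = torch.optim.SGD(model.parameters(), lr=0.1)
user_lr_dict = {
    "scheduler": torch.optim.lr_scheduler.StepLR(opt, step_size=5),
    "name": None,
    "interval": "epoch",
    "frequency": 1,
    "reduce_on_plateau": False,
    "monitor": None,
    "strict": True,
}
wrapped = lightning_scheduler_dict_to_det(user_lr_dict)  # steps once per epoch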
Example #5
    def _build_train_args(self, batch: TorchData, batch_idx: int,
                          opt_idx: int) -> List[Any]:
        # taken from pytorch_lightning
        args = [batch, batch_idx]

        if len(self._pls.optimizers) > 1:
            if has_param(self._pls.lm.training_step, "optimizer_idx"):
                args.append(opt_idx)
            else:
                num_opts = len(self._pls.optimizers)
                raise InvalidModelException(
                    f"Your LightningModule defines {num_opts} optimizers but "
                    f'training_step is missing the "optimizer_idx" argument.')

        return args
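
For reference, a sketch of a multi-optimizer `training_step` signature that satisfies this check; the helper method names are hypothetical:

# Illustrative only: with more than one optimizer configured, training_step
# must accept `optimizer_idx`, which _build_train_args appends after
# (batch, batch_idx).
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        return self._generator_loss(batch)      # hypothetical helper
    return self._discriminator_loss(batch)      # hypothetical helper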