Example #1
0
 def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
     """Register one subcommand parser on ``parser`` for every entry of ``self.subcommands()``.

     Args:
         parser: The top-level parser on which the subcommands action is created.
         **kwargs: Optional per-subcommand keyword arguments, looked up by subcommand name and forwarded
             to ``self._prepare_subcommand_parser`` (an empty dict is used when a name is absent).
     """
     subcommands_action = parser.add_subcommands()

     # ``self.trainer_class`` may be a class or a builder function; normalize it to a class
     if isinstance(self.trainer_class, type):
         trainer_cls = self.trainer_class
     else:
         trainer_cls = class_from_function(self.trainer_class)

     for name in self.subcommands():
         sub_kwargs = kwargs.get(name, {})
         sub_parser = self._prepare_subcommand_parser(trainer_cls, name, **sub_kwargs)
         # the subcommand's help message is the first-line description of the matching trainer method's docstring
         method = getattr(trainer_cls, name)
         help_text = _get_short_description(method)
         subcommands_action.add_subcommand(name, sub_parser, help=help_text)
Example #2
0
    def add_lightning_class_args(
        self,
        lightning_class: Union[
            Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
            Type[Trainer],
            Type[LightningModule],
            Type[LightningDataModule],
            Type[Callback],
        ],
        nested_key: str,
        subclass_mode: bool = False,
        required: bool = True,
    ) -> List[str]:
        """Register the arguments of a Lightning class under a nested key of the parser.

        A callable that is not itself a class (e.g. a builder function) is first converted into a class
        with ``class_from_function``.

        Args:
            lightning_class: A callable, or a subclass of one of ``Trainer``, ``LightningModule``,
                ``LightningDataModule`` or ``Callback``.
            nested_key: Name of the nested namespace in which the arguments are stored.
            subclass_mode: If ``True``, allow any subclass of ``lightning_class`` (registered through
                ``add_subclass_arguments``); otherwise register the class itself through ``add_class_arguments``.
            required: Whether the argument group is required. Only forwarded when ``subclass_mode`` is ``True``.

        Returns:
            The value returned by ``add_subclass_arguments`` / ``add_class_arguments``, documented as the list
            with the names of the class arguments added.

        Raises:
            MisconfigurationException: If ``lightning_class`` does not resolve to a subclass of ``Trainer``,
                ``LightningModule``, ``LightningDataModule`` or ``Callback``.
        """
        supported_bases = (Trainer, LightningModule, LightningDataModule, Callback)

        # builder functions are turned into a class; real classes are used as-is
        if not isinstance(lightning_class, type) and callable(lightning_class):
            lightning_class = class_from_function(lightning_class)

        is_supported = isinstance(lightning_class, type) and issubclass(lightning_class, supported_bases)
        if not is_supported:
            raise MisconfigurationException(
                f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
                "Trainer, LightningModule, LightningDataModule, or Callback.")

        # callback keys are tracked separately on ``self.callback_keys``
        if issubclass(lightning_class, Callback):
            self.callback_keys.append(nested_key)

        if subclass_mode:
            return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)

        # Trainer subclasses are registered with ``instantiate=False``; every other supported class with ``True``
        skip_instantiation = issubclass(lightning_class, Trainer)
        return self.add_class_arguments(
            lightning_class,
            nested_key,
            fail_untyped=False,
            instantiate=not skip_instantiation,
            sub_configs=True,
        )