def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
    """Register one subcommand parser per entry of ``self.subcommands()`` on ``parser``.

    Args:
        parser: The top-level parser that receives the subcommands.
        **kwargs: Optional per-subcommand keyword arguments, looked up by subcommand name and
            unpacked into ``self._prepare_subcommand_parser``. A missing name means no extra
            arguments.
    """
    subparsers = parser.add_subcommands()

    # ``self.trainer_class`` may be a builder function instead of a class; turn it into a class
    # so that its methods (one per subcommand) can be looked up below.
    if isinstance(self.trainer_class, type):
        resolved_trainer = self.trainer_class
    else:
        resolved_trainer = class_from_function(self.trainer_class)

    for name in self.subcommands():
        extra_kwargs = kwargs.get(name, {})
        subparser = self._prepare_subcommand_parser(resolved_trainer, name, **extra_kwargs)
        # The short description of the trainer method's docstring becomes the CLI help text.
        help_text = _get_short_description(getattr(resolved_trainer, name))
        subparsers.add_subcommand(name, subparser, help=help_text)
def add_lightning_class_args(
    self,
    lightning_class: Union[
        Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
        Type[Trainer],
        Type[LightningModule],
        Type[LightningDataModule],
        Type[Callback],
    ],
    nested_key: str,
    subclass_mode: bool = False,
    required: bool = True,
) -> List[str]:
    """Add the arguments of a Lightning class under ``nested_key`` in this parser.

    Args:
        lightning_class: Either a subclass of Trainer, LightningModule, LightningDataModule or
            Callback, or a non-class callable (builder function) that is first converted to a class
            with ``class_from_function``.
        nested_key: Name of the nested namespace that will hold the class arguments.
        subclass_mode: When true, accept any subclass of ``lightning_class``
            (delegates to ``add_subclass_arguments``).
        required: Whether the argument group is required. Only forwarded in subclass mode.

    Returns:
        A list with the names of the class arguments added.

    Raises:
        MisconfigurationException: If, after the optional conversion, ``lightning_class`` is not a
            subclass of Trainer, LightningModule, LightningDataModule or Callback.
    """
    # Builder functions are wrapped into a class so their signature can be introspected.
    if callable(lightning_class) and not isinstance(lightning_class, type):
        lightning_class = class_from_function(lightning_class)

    is_supported = isinstance(lightning_class, type) and issubclass(
        lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)
    )
    if not is_supported:
        raise MisconfigurationException(
            f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
            "Trainer, LightningModule, LightningDataModule, or Callback."
        )

    # Remember which nested keys hold callbacks.
    if issubclass(lightning_class, Callback):
        self.callback_keys.append(nested_key)

    if subclass_mode:
        return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)

    # Every supported class is instantiated by the parser except the Trainer.
    should_instantiate = not issubclass(lightning_class, Trainer)
    return self.add_class_arguments(
        lightning_class,
        nested_key,
        fail_untyped=False,
        instantiate=should_instantiate,
        sub_configs=True,
    )