def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: bool = True,
    callbacks: Optional[Union[List[Callback], Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: Optional[int] = None,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: Union[int, bool] = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    limit_predict_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    flush_logs_every_n_steps: int = 100,
    log_every_n_steps: int = 50,
    accelerator: Optional[Union[str, Accelerator]] = None,
    sync_batchnorm: bool = False,
    precision: int = 32,
    weights_summary: Optional[str] = 'top',
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[Union[Path, str]] = None,
    profiler: Optional[Union[BaseProfiler, bool, str]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    plugins: Optional[Union[str, list]] = None,
    amp_backend: str = 'native',
    amp_level: str = 'O2',
    distributed_backend: Optional[str] = None,
    automatic_optimization: Optional[bool] = None,
    move_metrics_to_cpu: bool = False,
    enable_pl_optimizer: Optional[bool] = None,  # todo: remove in v1.3
    multiple_trainloader_mode: str = 'max_size_cycle',
):
    r"""
    Customize every aspect of training via flags

    Args:

        accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
            Can also take in an accelerator object for custom hardware.

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        amp_backend: The mixed precision backend to use ("native" or "apex")

        amp_level: The optimization level to use (O1, O2, etc...).

        auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
            trying to optimize initial learning for faster convergence. trainer.tune() method will
            set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.

        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.

        auto_select_gpus: If enabled and `gpus` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.

        benchmark: If true enables cudnn.benchmark.

        callbacks: Add a callback or list of callbacks.

        checkpoint_callback: If ``True``, enable checkpointing.
            It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.

            .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.

        check_val_every_n_epoch: Check val every n train epochs.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Default: ``os.getcwd()``.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

        deterministic: If true enables cudnn.deterministic.

        distributed_backend: deprecated. Please use 'accelerator'. If both are given,
            a truthy ``distributed_backend`` takes precedence over ``accelerator``.

        fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
            of train, val and test to find any bugs (i.e. a sort of unit test).

        flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

        gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

        gradient_clip_val: 0 means don't clip.

        limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

        limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

        limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

        limit_predict_batches: How much of prediction dataset to check (floats = percent, int = num_batches)

        logger: Logger (or iterable collection of loggers) for experiment tracking.

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance

        log_every_n_steps: How often to log within steps (defaults to every 50 steps).

        automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad
            in LightningModule. This argument has been moved to LightningModule. It is deprecated
            here in v1.1 and will be removed in v1.3. Passing any explicit value emits a warning.

        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

        process_position: orders the progress bar when running multiple models on same machine.

        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
            a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks. Passing bool value is deprecated in v1.1 and will be removed in v1.3.

        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

        plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

        precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

        max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
            If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

        min_epochs: Force training for at least these many epochs. Disabled by default (None).
            If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

        max_steps: Stop training after this number of steps. Disabled by default (None).

        min_steps: Force training for at least these number of steps. Disabled by default (None).

        num_nodes: number of GPU nodes for distributed training.

        num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

        num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
            Set it to `-1` to run all batches in all validation dataloaders. Default: 2

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

        replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
            will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
            train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
            you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

        resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
            no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
            training will start from the beginning of the next epoch.

        sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

        truncated_bptt_steps: Truncated back prop performs backprop every k steps of
            a much longer sequence.

        val_check_interval: How often to check the validation set. Use float to check within a training epoch,
            use int to check every n steps (batches).

        weights_summary: Prints a summary of the weights when training begins.

        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
            Defaults to `default_root_dir`.

        move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
            This can save some gpu memory, but can make training slower. Use with caution.

        enable_pl_optimizer: If True, each optimizer will be wrapped by
            `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to
            handle AMP, TPU, accumulated_gradients, etc.

            .. warning:: Currently deprecated and it will be removed in v1.3

        multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
            In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
            and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
            reload when reaching the minimum length of datasets.
    """
    super().__init__()
    self._running_stage = None

    # ``distributed_backend`` is the deprecated spelling of ``accelerator``. Because of ``or``,
    # a truthy ``distributed_backend`` wins when both are supplied.
    distributed_backend = distributed_backend or accelerator

    # init connectors
    self.dev_debugger = InternalDebugger(self)
    self.config_validator = ConfigValidator(self)
    self.data_connector = DataConnector(self)
    self.optimizer_connector = OptimizerConnector(self)

    self.accelerator_connector = BackendConnector(
        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
        replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
    )
    self.logger_connector = LoggerConnector(self, log_gpu_memory)
    self.model_connector = ModelConnector(self)
    self.callback_connector = CallbackConnector(self)
    self.debugging_connector = DebuggingConnector(self)
    self.training_tricks_connector = TrainingTricksConnector(self)
    self.profile_connector = ProfilerConnector(self)
    self.checkpoint_connector = CheckpointConnector(self)
    self.slurm_connector = SLURMConnector(self)
    self.tuner = Tuner(self)
    self.train_loop = TrainLoop(self, multiple_trainloader_mode)
    self.evaluation_loop = EvaluationLoop(self)
    self.predict_loop = PredictLoop(self)

    # training state
    self.weights_summary = weights_summary
    self.shown_warnings = set()

    # init callbacks; this must happen before the ``on_init_start`` hook below is fired
    self.callback_connector.on_trainer_init(
        callbacks,
        checkpoint_callback,
        progress_bar_refresh_rate,
        process_position,
        default_root_dir,
        weights_save_path,
        resume_from_checkpoint,
    )

    # hook
    self.on_init_start()

    # init optimizer + lr scheduler related flags
    self.optimizer_connector.on_trainer_init(enable_pl_optimizer)

    # init data flags
    self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node)

    # init training tricks
    self.training_tricks_connector.on_trainer_init(
        gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
    )

    # init train loop related flags
    # TODO: remove in 1.3.0
    # Passing the flag at all (True or False) is the deprecated path, so any explicit value warns.
    if automatic_optimization is None:
        automatic_optimization = True
    else:
        rank_zero_warn(
            # trailing space: the two adjacent literals are concatenated into a single message
            "Disable automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0! "
            "Please use the property on the LightningModule for disabling automatic optimization"
        )
    self.train_loop.on_trainer_init(
        max_epochs,
        min_epochs,
        max_steps,
        min_steps,
        num_sanity_val_steps,
        automatic_optimization,
        weights_summary,
    )
    self.evaluation_loop.on_trainer_init()

    # configure tuner
    self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

    # configure profiler
    self.profile_connector.on_trainer_init(profiler)

    # init logger flags
    self.logger_connector.on_trainer_init(
        logger,
        flush_logs_every_n_steps,
        log_every_n_steps,
        move_metrics_to_cpu,
    )

    # init debugging flags
    self.debugging_connector.on_init_start(
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run,
    )

    # Callback system
    self.on_init_end()
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Union[ModelCheckpoint, bool] = True,
    early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,  # todo: remove in v1.0.0
    callbacks: Optional[List[Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: int = 1,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: bool = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: int = 1000,
    min_epochs: int = 1,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    log_save_interval: int = 100,
    row_log_interval: int = 50,
    distributed_backend: Optional[str] = None,
    sync_batchnorm: bool = False,
    precision: int = 32,
    weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[str] = None,
    profiler: Optional[Union[BaseProfiler, bool]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    cluster_environment: Optional[ClusterEnvironment] = None,
    amp_backend: str = 'native',
    amp_level: str = 'O2',
    overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
):
    r"""
    Customize every aspect of training via flags

    Args:

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        amp_backend: The mixed precision backend to use ("native" or "apex")

        amp_level: The optimization level to use (O1, O2, etc...).

        auto_lr_find: If set to True, will `initially` run a learning rate finder,
            trying to optimize initial learning for faster convergence. Sets learning
            rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.

        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.

        auto_select_gpus: If enabled and `gpus` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.

        benchmark: If true enables cudnn.benchmark.

        callbacks: Add a list of callbacks.

        checkpoint_callback: Callback for checkpointing (a ModelCheckpoint instance or a bool).

        check_val_every_n_epoch: Check val every n train epochs.

        cluster_environment: Environment config to link up arbitrary clusters

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Default: ``os.getcwd()``.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

        deterministic: If true enables cudnn.deterministic.

        distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

        early_stop_callback: An :class:`pytorch_lightning.callbacks.EarlyStopping` instance or a bool.
            Deprecated since v0.10.0 and will be removed in v1.0.

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (i.e. a sort of unit test).

        gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

        gradient_clip_val: 0 means don't clip.

        limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

        limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

        limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

        logger: Logger (or iterable collection of loggers) for experiment tracking.

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance

        log_save_interval: Writes logs to disk this often

        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

        process_position: orders the progress bar when running multiple models on same machine.

        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks.

        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

        overfit_pct: Deprecated, kept only for backward compatibility; it is still forwarded to the
            debugging connector. To be removed in v1.0.0.

        precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

        max_epochs: Stop training once this number of epochs is reached. Default: 1000.

        min_epochs: Force training for at least these many epochs. Default: 1.

        max_steps: Stop training after this number of steps. Disabled by default (None).

        min_steps: Force training for at least these number of steps. Disabled by default (None).

        num_nodes: number of GPU nodes for distributed training.

        num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

        num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
            Set it to `-1` to run all batches in all validation dataloaders. Default: 2

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

        replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
            will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
            train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
            you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
            This can be a URL.

        row_log_interval: How often to add logging rows (does not write to disk)

        sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

        truncated_bptt_steps: Truncated back prop performs backprop every k steps of
            a much longer sequence.

        val_check_interval: How often to check the validation set. Use float to check within a training epoch,
            use int to check every n steps (batches).

        weights_summary: Prints a summary of the weights when training begins.

        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
            Defaults to `default_root_dir`.
    """
    super().__init__()

    # init connectors: each one is constructed with a back-reference to this trainer and is
    # configured later, in a fixed order, by its ``on_trainer_init`` / ``on_init_start`` call
    self.dev_debugger = InternalDebugger(self)
    self.config_validator = ConfigValidator(self)
    self.data_connector = DataConnector(self)
    self.optimizer_connector = OptimizerConnector(self)
    self.accelerator_connector = AcceleratorConnector(self)
    self.logger_connector = LoggerConnector(self)
    self.model_connector = ModelConnector(self)
    self.precision_connector = PrecisionConnector(self)
    self.callback_connector = CallbackConnector(self)
    self.debugging_connector = DebuggingConnector(self)
    self.training_tricks_connector = TrainingTricksConnector(self)
    self.profile_connector = ProfilerConnector(self)
    self.checkpoint_connector = CheckpointConnector(self)
    self.slurm_connector = SLURMConnector(self)
    self.tuner = Tuner(self)
    self.accelerator_backend = None
    self.evaluation_loop = EvaluationLoop(self)
    self.train_loop = TrainLoop(self)

    # training state
    self.weights_summary = weights_summary
    self.model = None
    self.shown_warnings = set()

    # init callbacks
    # Declare attributes to be set in callback_connector on_trainer_init
    self.checkpoint_callback: Union[ModelCheckpoint, bool] = checkpoint_callback
    self.early_stop_callback: Optional[Union[EarlyStopping, bool]] = early_stop_callback
    self.callback_connector.on_trainer_init(
        callbacks,
        early_stop_callback,
        checkpoint_callback,
        progress_bar_refresh_rate,
        process_position,
        default_root_dir,
        weights_save_path,
        resume_from_checkpoint
    )

    # hook (fired only after the callbacks above have been configured)
    self.on_init_start()

    # init optimizer + lr scheduler related flags
    self.optimizer_connector.on_trainer_init()

    # init data flags
    self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node)

    # init training tricks
    self.training_tricks_connector.on_trainer_init(
        gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
    )

    # init accelerator related flags
    self.accelerator_connector.on_trainer_init(
        num_processes,
        tpu_cores,
        distributed_backend,
        auto_select_gpus,
        gpus,
        num_nodes,
        log_gpu_memory,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic,
        cluster_environment
    )

    # init train loop related flags
    self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
    self.evaluation_loop.on_trainer_init()

    # configure tuner
    self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

    # configure profiler
    self.profile_connector.on_trainer_init(profiler)

    # init logger flags
    self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)

    # init debugging flags (the deprecated ``overfit_pct`` is still forwarded here)
    self.debugging_connector.on_init_start(
        overfit_pct,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run
    )

    # set precision
    self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

    # Callback system
    self.on_init_end()
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Union[ModelCheckpoint, bool] = True,
    early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
    callbacks: Optional[List[Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: int = 1,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: bool = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: int = 1000,
    min_epochs: int = 1,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    log_save_interval: int = 100,
    row_log_interval: int = 50,
    distributed_backend: Optional[str] = None,
    sync_batchnorm: bool = False,
    precision: int = 32,
    weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[str] = None,
    profiler: Optional[Union[BaseProfiler, bool]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    amp_backend: str = 'native',
    amp_level: str = 'O2',
    val_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
    test_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
    train_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
    overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
):
    r"""
    Customize every aspect of training via flags.

    The constructor builds one connector/loop object per concern and then hands each group of
    arguments to it, in a fixed order. The ``on_init_start`` hook fires only after
    ``callback_connector`` has configured the callbacks. ``on_init_end`` fires last, after
    precision has been set.

    Argument routing:
        callbacks, early_stop_callback, checkpoint_callback, progress_bar_refresh_rate,
            process_position, default_root_dir, weights_save_path, resume_from_checkpoint
            -> ``self.callback_connector.on_trainer_init``
        check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node
            -> ``self.data_connector.on_trainer_init``
        gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps,
            terminate_on_nan -> ``self.training_tricks_connector.on_trainer_init``
        num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes,
            log_gpu_memory, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic
            -> ``self.accelerator_connector.on_trainer_init``
        max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
            -> ``self.train_loop.on_trainer_init``
        auto_lr_find, auto_scale_batch_size -> ``self.tuner.on_trainer_init``
        profiler -> ``self.profile_connector.on_trainer_init``
        logger, log_save_interval, row_log_interval -> ``self.logger_connector.on_trainer_init``
        overfit_pct, val_percent_check, test_percent_check, train_percent_check,
            limit_train_batches, limit_val_batches, limit_test_batches, val_check_interval,
            overfit_batches, fast_dev_run -> ``self.debugging_connector.on_init_start``
        precision, amp_level, amp_backend -> ``self.precision_connector.on_trainer_init``
        weights_summary -> stored directly as ``self.weights_summary``

    Deprecated arguments kept only for backward compatibility: ``val_percent_check``,
    ``test_percent_check`` and ``train_percent_check`` (marked for removal in v0.10.0) and
    ``overfit_pct`` (marked for removal in v1.0.0).
    """
    super().__init__()

    # init connectors: each one holds a back-reference to this trainer and is configured
    # further down by its ``on_trainer_init`` / ``on_init_start`` call
    self.dev_debugger = InternalDebugger(self)
    self.config_validator = ConfigValidator(self)
    self.data_connector = DataConnector(self)
    self.optimizer_connector = OptimizerConnector(self)
    self.accelerator_connector = AcceleratorConnector(self)
    self.logger_connector = LoggerConnector(self)
    self.model_connector = ModelConnector(self)
    self.precision_connector = PrecisionConnector(self)
    self.callback_connector = CallbackConnector(self)
    self.debugging_connector = DebuggingConnector(self)
    self.training_tricks_connector = TrainingTricksConnector(self)
    self.profile_connector = ProfilerConnector(self)
    self.tuner = Tuner(self)
    self.accelerator_backend = None
    self.evaluation_loop = EvaluationLoop(self)
    self.train_loop = TrainLoop(self)

    # training state
    self.weights_summary = weights_summary
    self.model = None
    self.shown_warnings = set()

    # init callbacks
    self.callback_connector.on_trainer_init(
        callbacks,
        early_stop_callback,
        checkpoint_callback,
        progress_bar_refresh_rate,
        process_position,
        default_root_dir,
        weights_save_path,
        resume_from_checkpoint
    )

    # hook (fired only after the callbacks above have been configured)
    self.on_init_start()

    # init optimizer + lr scheduler related flags
    self.optimizer_connector.on_trainer_init()

    # init data flags
    self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node)

    # init training tricks
    self.training_tricks_connector.on_trainer_init(
        gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
    )

    # init accelerator related flags
    self.accelerator_connector.on_trainer_init(
        num_processes,
        tpu_cores,
        distributed_backend,
        auto_select_gpus,
        gpus,
        num_nodes,
        log_gpu_memory,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic
    )

    # init train loop related flags
    self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
    self.evaluation_loop.on_trainer_init()

    # configure tuner
    self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

    # configure profiler
    self.profile_connector.on_trainer_init(profiler)

    # init logger flags
    self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)

    # init debugging flags (the four deprecated *_pct / *_percent_check args are forwarded here)
    self.debugging_connector.on_init_start(
        overfit_pct,
        val_percent_check,
        test_percent_check,
        train_percent_check,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run
    )

    # set precision
    self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

    # Callback system
    self.on_init_end()
class Trainer(
    TrainerProperties,
    TrainerIOMixin,
    TrainerCallbackHookMixin,
    TrainerModelHooksMixin,
    TrainerOptimizersMixin,
    TrainerDDPMixin,
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
    TrainerCallbackConfigMixin,
    TrainerDeprecatedAPITillVer0_10,
):
    """Drive fitting, tuning, evaluation and testing of a ``LightningModule``.

    Most behaviour lives in the mixins listed above and in the connector and
    loop objects created in ``__init__``; this class wires them together.
    """

    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        # deprecated since v0.8.0, backward compatible, todo: remove in v0.10.0
        val_percent_check: Optional[float] = None,
        test_percent_check: Optional[float] = None,
        train_percent_check: Optional[float] = None,
        overfit_pct: Optional[float] = None,
    ):
        """Store the training configuration and build the trainer's helper objects.

        Construction has process-wide side effects:
        - It sets ``torch.backends.cudnn.deterministic`` and ``torch.backends.cudnn.benchmark``.
        - It may set the ``HOROVOD_FUSION_THRESHOLD`` environment variable.
        - It reads the ``LOCAL_RANK`` and ``PL_TRAINER_GPUS`` environment variables.

        Args:
            logger: Logger(s) or a bool; passed to ``configure_logger``.
            checkpoint_callback: A ``ModelCheckpoint`` or a bool. It is resolved by
                ``configure_checkpoint_callback``; a truthy result is appended to
                ``callbacks`` last and kept on ``self.checkpoint_callback``.
            early_stop_callback: An ``EarlyStopping`` or a bool. It is resolved by
                ``configure_early_stopping``; a truthy result is appended to
                ``callbacks`` and kept on ``self.early_stop_callback``.
            callbacks: User callbacks. The list is copied, so the caller's list is
                never modified.
            default_root_dir: Falls back to ``os.getcwd()``.
            weights_save_path: Falls back to the resolved ``default_root_dir``.
            gpus: Replaced by the ``PL_TRAINER_GPUS`` environment variable when that
                is set. When ``auto_select_gpus`` is on and ``gpus`` is an int, the
                tuner picks the devices.
            tpu_cores: Parsed by ``device_parser.parse_tpu_cores``. Any non-None
                result switches the distributed backend to ``'tpu'``.
            num_processes: Triggers a warning unless
                ``distributed_backend == "ddp_cpu"``.
            track_grad_norm: An int, a float or ``'inf'``; stored as a float.
            num_sanity_val_steps: ``-1`` means unlimited (``float('inf')``).
            fast_dev_run: If truthy, runs one batch for train, val and test, skips
                sanity steps and forces ``max_epochs`` to 1.
            profiler: ``True`` becomes ``SimpleProfiler()``; a falsy value becomes
                ``PassThroughProfiler()``.
            limit_train_batches, limit_val_batches, limit_test_batches,
            val_check_interval, overfit_batches: Normalised by
                ``_determine_batch_limits``.
            val_percent_check, test_percent_check, train_percent_check, overfit_pct:
                Deprecated. When not None, each replaces its ``limit_*_batches`` or
                ``overfit_batches`` counterpart with a ``DeprecationWarning``. These
                are applied after ``fast_dev_run``, so they take precedence over it.

            Arguments not listed here are stored on ``self`` under the same name or
            passed unchanged to a helper. For example, ``amp_backend`` goes to
            ``Initializer.init_amp`` and ``process_position`` goes to
            ``configure_progress_bar``.

        Raises:
            MisconfigurationException: If ``track_grad_norm`` is not an int, a float
                or ``'inf'``.
        """
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # Initialise the default rank if the launcher provided one.
        # This must happen now; otherwise NVIDIA flags and other messages printed
        # during init would appear on every rank instead of only rank 0.
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # tracks internal state for debugging
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.lr_scheduler_connector = LRSchedulerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.initializer = Initializer(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None

        # loops
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # when true, prints test results
        self.verbose_test = True

        # when .test() is called, it sets this
        self.tested_ckpt_path = None

        # training state
        self.model = None
        self.datamodule = None
        self.testing = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self._state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

        # init callbacks
        # Copy the list: default callbacks are appended below, and appending to the
        # caller's own list would leak them into other Trainer instances.
        self.callbacks = list(callbacks) if callbacks else []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.sync_batchnorm = sync_batchnorm

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
            )
        self.track_grad_norm = float(track_grad_norm)

        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.on_tpu = self.tpu_cores is not None

        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        # -1 is the documented sentinel for "run the whole validation set"
        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                ' val and test loop using a single batch')

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = self.tuner.pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
        self.root_gpu = device_parser.determine_root_gpu_device(self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        self.on_gpu = bool(self.data_parallel_device_ids and torch.cuda.is_available())

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = 'tpu'
            self.init_tpu()

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

        # logging
        self.configure_logger(logger)
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # Deprecated aliases are applied after fast_dev_run, so they override it.
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.initializer.init_amp(amp_backend)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the batch-size scaler and/or learning-rate finder, as configured.

        ``auto_scale_batch_size=True`` is treated as ``'power'``. After each tuner
        runs, ``model.logger`` is re-bound to ``self.logger``.

        Args:
            model: The module to tune.
            train_dataloader: Optional training dataloader.
            val_dataloaders: Optional validation dataloader(s).
            datamodule: Optional datamodule supplying the data instead.
        """
        # TODO: temporary, need to decide if tune or separate object

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.tuner.scale_batch_size(
                model,
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self.tuner.internal_find_lr(self, model)
            model.logger = self.logger  # reset logger binding

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Train ``model`` (or, in testing mode, evaluate it) via the selected accelerator.

        Testing mode is on when ``self.testing`` is already truthy or when the
        ``PL_TESTING_MODE`` environment variable is set.

        Args:
            model: The module to fit.
            train_dataloader: Optional training dataloader.
            val_dataloaders: Optional validation dataloader(s).
            datamodule: Optional datamodule supplying the data instead.

        Returns:
            Whatever ``accelerator_backend.train()`` returned, or ``1`` if that was falsy.
        """
        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return results or 1

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        """Bind trainer properties to ``model``, attach its data and validate the loop configuration."""
        # bind logger and other properties
        self.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

    def setup_training(self, model: LightningModule):
        """Sanity check a few things before starting actual training.

        This also creates the native-AMP grad scaler, logs hyper-parameters,
        synchronises distributed workers, prints the model summary and restores
        weights.

        Args:
            model: The model to run sanity test on. Under data parallelism it is
                the wrapper, and ``model.module`` is the underlying module.

        Raises:
            MisconfigurationException: If ``weights_summary`` is not None and not
                one of ``ModelSummary.MODES``.
        """
        # --------------------------
        # Setup??
        # --------------------------
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(ref_model)

        # init amp. Must be done here instead of __init__ to allow ddp to work
        if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
            self.scaler = torch.cuda.amp.GradScaler()

        # log hyper-parameters
        if self.logger is not None:
            # save exp to get started
            self.logger.log_hyperparams(ref_model.hparams)
            self.logger.log_graph(ref_model)
            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            torch_distrib.barrier()

        # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
            # wait for all processes to catch up
            torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")

        elif self.use_horovod:
            # wait for all processes to catch up
            hvd.join()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        self.on_pretrain_routine_start(ref_model)
        if self.is_function_implemented('on_pretrain_routine_start'):
            ref_model.on_pretrain_routine_start()

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            if self.weights_summary in ModelSummary.MODES:
                ref_model.summarize(mode=self.weights_summary)
            else:
                raise MisconfigurationException(
                    "weights_summary can be None, " + ", ".join(ModelSummary.MODES))

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc is called
        self.restore_weights(model)

        # on pretrain routine end
        self.on_pretrain_routine_end(ref_model)
        if self.is_function_implemented('on_pretrain_routine_end'):
            ref_model.on_pretrain_routine_end()

    def train(self):
        """Run the sanity check, then train epochs from ``current_epoch`` up to ``max_epochs``.

        Training stops early when the following conditions hold:
        - ``max_steps`` is reached.
        - ``should_stop`` is set and both ``min_epochs`` and ``min_steps`` are satisfied.

        A ``KeyboardInterrupt`` is caught and triggers a single graceful shutdown.
        """
        self.run_sanity_check(self.get_model())

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.lr_scheduler_connector.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        """Run one validation (or test) pass over every evaluation dataloader.

        Gradients are disabled and the model is in eval mode for the duration.
        Afterwards, train mode and gradients are re-enabled.

        Args:
            test_mode: Evaluate on the test dataloaders instead of validation.
            max_batches: Optional per-dataloader batch limits. These are passed to
                ``evaluation_loop.get_evaluation_dataloaders``.

        Returns:
            Tuple ``(eval_loop_results, eval_results)``, or ``([], [])`` when
            evaluation is skipped.
        """
        # bookkeeping
        self.evaluation_loop.testing = test_mode
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
            return [], []

        # enable eval mode + no grads
        model = self.get_model()
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        # TODO: should this be inside the dataloader loop?
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator_backend.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                # lightning module methods
                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
                self.evaluation_loop.log_step_metrics(output, batch_idx)

                # track epoch level metrics
                if output is not None:
                    dl_outputs.append(output)

            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

        # bookkeeping
        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
        self.evaluation_loop.predictions.to_disk()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # enable train mode again
        model.train()
        torch.set_grad_enabled(True)

        # hook
        self.evaluation_loop.on_evaluation_end()

        return eval_loop_results, eval_results

    def run_test(self):
        """Run a test-mode evaluation pass.

        Returns:
            ``1`` if nothing was produced. Otherwise, the per-dataloader results,
            with any tensor values inside dict results converted in place to Python
            scalars via ``.cpu().item()``.
        """
        eval_loop_results, _ = self.run_evaluation(test_mode=True)

        if len(eval_loop_results) == 0:
            return 1

        # remove the tensors from the eval results
        for result in eval_loop_results:
            if isinstance(result, dict):
                for k, v in result.items():
                    if isinstance(v, torch.Tensor):
                        result[k] = v.cpu().item()

        return eval_loop_results

    def train_or_test(self):
        """Dispatch to ``run_test`` when ``self.testing`` is truthy, else to ``train``."""
        if self.testing:
            results = self.run_test()
        else:
            results = self.train()
        return results

    def run_sanity_check(self, ref_model):
        """Run a few validation batches before training so that validation bugs surface early.

        The check only runs when the model overrides ``validation_step`` and both
        ``num_sanity_val_steps`` and ``limit_val_batches`` are positive.
        """
        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
            self.num_sanity_val_batches = [
                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
            ]

            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

            # run eval step
            _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
                # when we get a list back, used only the last item
                if isinstance(eval_results, list):
                    eval_results = eval_results[-1]

                if isinstance(eval_results, EvalResult):
                    callback_metrics = eval_results.callback_metrics
                else:
                    _, _, _, callback_metrics, _ = self.process_output(eval_results)
                self.logger_connector.callback_metrics = callback_metrics

            self.on_sanity_check_end()
            self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the test loop, on ``model`` if given, else on the trainer's model with loaded weights.

        Only rank 0 runs; other ranks return ``None`` immediately.

        Args:
            model: Module to test. If None, the trainer's own model is used and the
                weights from ``ckpt_path`` are loaded.
            test_dataloaders: Optional test dataloader(s). This cannot be combined
                with ``datamodule``.
            ckpt_path: ``'best'``, an explicit checkpoint path, or None to keep the
                current weights. Only used when ``model`` is None.
            verbose: Stored as ``self.verbose_test``.
            datamodule: Optional datamodule supplying the test data.

        Raises:
            MisconfigurationException: If both ``test_dataloaders`` and
                ``datamodule`` are given.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose

        if self.global_rank != 0:
            return

        # If you supply a datamodule you can't supply test_dataloaders
        if test_dataloaders and datamodule:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

        if model is not None:
            results = self.__test_given_model(model, test_dataloaders)
        else:
            results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Test the trainer's own model, after loading weights from ``ckpt_path`` unless it is None.

        Returns:
            The result of ``fit`` in testing mode. Returns ``{}`` when ``'best'``
            resolved to an empty path.

        Raises:
            MisconfigurationException: If ``ckpt_path == 'best'`` but checkpointing
                is disabled or ``save_top_k <= 0``.
        """
        model = self.get_model()

        # If the user requests the best checkpoint but we don't track one, error.
        # ``checkpoint_callback`` is falsy when checkpointing was disabled.
        if ckpt_path == 'best' and (not self.checkpoint_callback or self.checkpoint_callback.save_top_k <= 0):
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path

            if len(ckpt_path) == 0:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                )
                return {}

            ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

        # run tests
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # Always reset. ``fit`` reads PL_TESTING_MODE, so a leftover value would
            # turn every later fit() in this process into a test run.
            self.testing = False
            os.environ.pop('PL_TESTING_MODE', None)

        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')

        return results

    def __test_given_model(self, model, test_dataloaders):
        """Test the explicitly provided ``model`` with its current weights.

        Returns:
            The result of ``fit`` in testing mode.
        """
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval
        self.testing = True
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # reset even if fit() raises so the trainer is not left in testing mode
            self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def call_setup_hook(self, model):
        """Call ``setup`` with ``'test'`` or ``'fit'`` on the datamodule (once per stage), the trainer and ``model``."""
        # call setup after the ddp process has connected
        stage_name = 'test' if self.testing else 'fit'
        if self.datamodule is not None:
            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
            if not called:
                self.datamodule.setup(stage_name)
        self.setup(stage_name)
        model.setup(stage_name)

    def call_hook(self, hook_name, *args, **kwargs):
        """Invoke ``hook_name`` on the trainer, then on the model or the accelerator, under the profiler.

        The hook is called in this order:
        1. The trainer's own ``hook_name``, if it exists.
        2. The model's override, if the model overrides it.
        3. Otherwise, the accelerator backend's ``hook_name``, if it has one.

        Returns:
            The output of the model or accelerator hook, or None.
        """
        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

            return output
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Union[ModelCheckpoint, bool] = True,
    early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
    callbacks: Optional[List[Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: int = 1,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: bool = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: int = 1000,
    min_epochs: int = 1,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    log_save_interval: int = 100,
    row_log_interval: int = 50,
    distributed_backend: Optional[str] = None,
    sync_batchnorm: bool = False,
    precision: int = 32,
    weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[str] = None,
    profiler: Optional[Union[BaseProfiler, bool]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    amp_backend: str = 'native',
    amp_level: str = 'O2',
    # deprecated since v0.8.0, backward compatible, todo: remove in v0.10.0
    val_percent_check: Optional[float] = None,
    test_percent_check: Optional[float] = None,
    train_percent_check: Optional[float] = None,
    overfit_pct: Optional[float] = None,
):
    """Store the training configuration and build the trainer's helper objects.

    Construction has process-wide side effects:
    - It sets ``torch.backends.cudnn.deterministic`` and ``torch.backends.cudnn.benchmark``.
    - It may set the ``HOROVOD_FUSION_THRESHOLD`` environment variable.
    - It reads the ``LOCAL_RANK`` and ``PL_TRAINER_GPUS`` environment variables.

    Args:
        logger: Logger(s) or a bool; passed to ``configure_logger``.
        checkpoint_callback: A ``ModelCheckpoint`` or a bool. It is resolved by
            ``configure_checkpoint_callback``; a truthy result is appended to
            ``callbacks`` last and kept on ``self.checkpoint_callback``.
        early_stop_callback: An ``EarlyStopping`` or a bool. It is resolved by
            ``configure_early_stopping``; a truthy result is appended to
            ``callbacks`` and kept on ``self.early_stop_callback``.
        callbacks: User callbacks. NOTE(review): a non-empty list is used as-is,
            not copied, so the default callbacks are appended to the caller's own
            list.
        default_root_dir: Falls back to ``os.getcwd()``.
        weights_save_path: Falls back to the resolved ``default_root_dir``.
        gpus: Replaced by the ``PL_TRAINER_GPUS`` environment variable when that is
            set. When ``auto_select_gpus`` is on and ``gpus`` is an int, the tuner
            picks the devices.
        tpu_cores: Parsed by ``device_parser.parse_tpu_cores``. Any non-None result
            switches the distributed backend to ``'tpu'``.
        num_processes: Triggers a warning unless ``distributed_backend == "ddp_cpu"``.
        track_grad_norm: An int, a float or ``'inf'``; stored as a float.
        num_sanity_val_steps: ``-1`` means unlimited (``float('inf')``).
        fast_dev_run: If truthy, runs one batch for train, val and test, skips
            sanity steps and forces ``max_epochs`` to 1.
        profiler: ``True`` becomes ``SimpleProfiler()``; a falsy value becomes
            ``PassThroughProfiler()``.
        limit_train_batches, limit_val_batches, limit_test_batches,
        val_check_interval, overfit_batches: Normalised by
            ``_determine_batch_limits``.
        val_percent_check, test_percent_check, train_percent_check, overfit_pct:
            Deprecated. When not None, each replaces its ``limit_*_batches`` or
            ``overfit_batches`` counterpart with a ``DeprecationWarning``. These are
            applied after ``fast_dev_run``, so they take precedence over it.

        Arguments not listed here are stored on ``self`` under the same name or
        passed unchanged to a helper. For example, ``amp_backend`` goes to
        ``Initializer.init_amp`` and ``process_position`` goes to
        ``configure_progress_bar``.

    Raises:
        MisconfigurationException: If ``track_grad_norm`` is not an int, a float or
            ``'inf'``.
    """
    super().__init__()

    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # Initialise the default rank if the launcher provided one.
    # This must happen now; otherwise NVIDIA flags and other messages printed
    # during init would appear on every rank instead of only rank 0.
    if 'LOCAL_RANK' in os.environ:
        rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

    # helper objects; each receives the trainer so it can read and write trainer state
    self.dev_debugger = InternalDebugger(self)
    self.config_validator = ConfigValidator(self)
    self.data_connector = DataConnector(self)
    self.lr_scheduler_connector = LRSchedulerConnector(self)
    self.accelerator_connector = AcceleratorConnector(self)
    self.logger_connector = LoggerConnector(self)
    self.model_connector = ModelConnector(self)
    self.initializer = Initializer(self)
    self.tuner = Tuner(self)
    self.accelerator_backend = None

    # loops
    self.evaluation_loop = EvaluationLoop(self)
    self.train_loop = TrainLoop(self)

    # training bookkeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.num_training_batches = 0
    self.num_val_batches = []
    self.num_sanity_val_batches = []
    self.num_test_batches = []
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # when true, prints test results
    self.verbose_test = True

    # when .test() is called, it sets this
    self.tested_ckpt_path = None

    # training state
    self.model = None
    self.datamodule = None
    self.testing = False
    self.prepare_data_per_node = prepare_data_per_node
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False
    self.should_stop = False
    self.running_sanity_check = False
    self._state = TrainerState.INITIALIZING

    self._default_root_dir = default_root_dir or os.getcwd()
    self._weights_save_path = weights_save_path or self._default_root_dir

    # init callbacks
    # NOTE(review): aliases the caller's list when it is non-empty; the appends
    # below then modify that list as well.
    self.callbacks = callbacks or []

    # configure early stop callback
    # creates a default one if none passed in
    early_stop_callback = self.configure_early_stopping(early_stop_callback)
    if early_stop_callback:
        self.callbacks.append(early_stop_callback)

    # configure checkpoint callback
    # it is important that this is the last callback to run
    # pass through the required args to figure out defaults
    checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
    if checkpoint_callback:
        self.callbacks.append(checkpoint_callback)

    # TODO refactor codebase (tests) to not directly reach into these callbacks
    self.checkpoint_callback = checkpoint_callback
    self.early_stop_callback = early_stop_callback

    # hook: callbacks are registered at this point, so they see on_init_start
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    # sync-bn backend
    self.sync_batchnorm = sync_batchnorm

    self.gradient_clip_val = gradient_clip_val
    self.check_val_every_n_epoch = check_val_every_n_epoch

    # only numeric norms or the string 'inf' are accepted; float('inf') handles the latter
    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise MisconfigurationException(
            "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
        )
    self.track_grad_norm = float(track_grad_norm)

    self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
    self.on_tpu = self.tpu_cores is not None

    # a single specific core is only known when the parser returned a list
    self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn(
            "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
        )
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    # -1 is the sentinel for "no limit" on sanity-check batches
    if num_sanity_val_steps == -1:
        self.num_sanity_val_steps = float('inf')
    else:
        self.num_sanity_val_steps = num_sanity_val_steps

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    # fast_dev_run overrides the local limit_* values; the deprecated *_percent_check
    # arguments handled further below can still override them again
    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        limit_train_batches = 1
        limit_val_batches = 1
        limit_test_batches = 1
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        rank_zero_info(
            'Running in fast_dev_run mode: will run a full train,'
            ' val and test loop using a single batch')

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # override with environment flag (the env value, if present, is a string)
    gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = self.tuner.pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
    self.root_gpu = device_parser.determine_root_gpu_device(self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.distributed_backend = 'tpu'
        self.init_tpu()

    # init flags for SLURM+DDP to work
    self.world_size = 1
    self.interactive_ddp_procs = []
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()
    self.local_rank = self.determine_local_rank()
    self.global_rank = 0

    # NVIDIA setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # logging
    self.configure_logger(logger)
    self.log_save_interval = log_save_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    # TODO: remove in 0.10.0
    if overfit_pct is not None:
        rank_zero_warn(
            "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        overfit_batches = overfit_pct

    # TODO: remove in 0.10.0
    if val_percent_check is not None:
        rank_zero_warn(
            "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_val_batches = val_percent_check

    # TODO: remove in 0.10.0
    if test_percent_check is not None:
        rank_zero_warn(
            "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_test_batches = test_percent_check

    # TODO: remove in 0.10.0
    if train_percent_check is not None:
        rank_zero_warn(
            "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_train_batches = train_percent_check

    # normalise every batch limit through the same helper (the second argument is the name)
    self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
    self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
    self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
    self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
    self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
    self.determine_data_use_amount(self.overfit_batches)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    self.precision = precision
    self.scaler = None

    self.amp_level = amp_level
    self.initializer.init_amp(amp_backend)

    # truthy (the env value) when either the COLAB_GPU or KAGGLE_URL_BASE variable is set
    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
class Trainer(
    TrainerProperties,
    TrainerIOMixin,
    TrainerCallbackHookMixin,
    TrainerModelHooksMixin,
    TrainerOptimizersMixin,
    TrainerDDPMixin,
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
    TrainerDeprecatedAPITillVer0_10,
):
    """Orchestrates tuning, fitting and testing of a ``LightningModule``.

    Most configuration is delegated to per-concern objects created in
    ``__init__``. These are the connectors (data, optimizer, accelerator,
    logger, model, precision, callbacks, debugging, training tricks, profiler),
    the train/evaluation loops and the tuner.
    """

    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_backend: str = 'native',
        amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
        val_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        test_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        train_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
    ):
        """Build the connectors and loops, then route each group of flags to its owner.

        Each flag is forwarded unchanged to the connector/loop that owns it.
        The deprecated ``*_percent_check`` / ``overfit_pct`` arguments go to
        ``debugging_connector.on_init_start`` together with their ``limit_*`` /
        ``overfit_batches`` replacements.
        """
        super().__init__()

        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None  # selected lazily in fit()
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.model = None
        self.shown_warnings = set()

        # init callbacks
        self.callback_connector.on_trainer_init(
            callbacks,
            early_stop_callback,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            resume_from_checkpoint
        )

        # hook -- fired only after the callbacks above are registered
        # (presumably so they can receive it)
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(
            check_val_every_n_epoch,
            reload_dataloaders_every_epoch,
            prepare_data_per_node
        )

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val,
            track_grad_norm,
            accumulate_grad_batches,
            truncated_bptt_steps,
            terminate_on_nan
        )

        # init accelerator related flags
        self.accelerator_connector.on_trainer_init(
            num_processes,
            tpu_cores,
            distributed_backend,
            auto_select_gpus,
            gpus,
            num_nodes,
            log_gpu_memory,
            sync_batchnorm,
            benchmark,
            replace_sampler_ddp,
            deterministic
        )

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)

        # init debugging flags (also resolves the deprecated *_percent_check / overfit_pct args)
        self.debugging_connector.on_init_start(
            overfit_pct,
            val_percent_check,
            test_percent_check,
            train_percent_check,
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run
        )

        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

        # Callback system
        self.on_init_end()

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the batch-size scaler and/or learning-rate finder, as configured.

        A boolean ``True`` for ``auto_scale_batch_size`` is converted to the
        ``'power'`` mode before scaling. After each tuner step, ``model.logger``
        is re-bound to ``self.logger``.
        """
        # TODO: temporary, need to decide if tune or separate object

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.tuner.scale_batch_size(
                model,
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self.tuner.internal_find_lr(self, model)
            model.logger = self.logger  # reset logger binding

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Attach data, select an accelerator and run it on ``model``.

        ``self.testing`` is overridden by the ``PL_TESTING_MODE`` environment
        variable when that variable is set (``test()`` uses this to reuse
        ``fit``).

        Returns:
            The result of ``accelerator_backend.train()``, or ``1`` when that
            result is falsy (so callers can tell that the run finished).
        """
        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return results or 1

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        """Bind trainer properties to ``model``, attach its data and validate the loop config."""
        # bind logger and other properties
        self.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

    def setup_training(self, model: LightningModule):
        """Sanity check a few things before starting actual training.

        This also logs hyper-parameters, synchronizes distributed processes,
        prints the model summary and restores weights.

        Args:
            model: The model to run sanity test on.

        Raises:
            MisconfigurationException: if ``weights_summary`` is not one of
                ``ModelSummary.MODES`` (checked on global rank zero, outside testing).
        """
        # --------------------------
        # Setup
        # --------------------------
        # when wrapped for data parallelism, hooks/properties live on the inner module
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(ref_model)

        # init amp. Must be done here instead of __init__ to allow ddp to work
        if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
            self.scaler = torch.cuda.amp.GradScaler()

        # log hyper-parameters
        if self.logger is not None:
            # save exp to get started
            self.logger.log_hyperparams(ref_model.hparams)
            self.logger.log_graph(ref_model)
            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            torch_distrib.barrier()

        # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
            # wait for all processes to catch up
            torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")

        elif self.use_horovod:
            # wait for all processes to catch up
            hvd.join()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        self.on_pretrain_routine_start(ref_model)
        if self.is_function_implemented('on_pretrain_routine_start'):
            ref_model.on_pretrain_routine_start()

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            if self.weights_summary in ModelSummary.MODES:
                ref_model.summarize(mode=self.weights_summary)
            else:
                raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc is called
        self.restore_weights(model)

        # on pretrain routine end
        self.on_pretrain_routine_end(ref_model)
        if self.is_function_implemented('on_pretrain_routine_end'):
            ref_model.on_pretrain_routine_end()

    def train(self):
        """Run the sanity check, then epochs from ``current_epoch`` up to ``max_epochs``.

        Training ends early when ``max_steps`` is reached, or when
        ``should_stop`` is set and both ``min_epochs`` and ``min_steps`` are met.
        A ``KeyboardInterrupt`` is caught. On the first interrupt the state
        becomes ``INTERRUPTED`` and ``on_train_end`` runs.
        """
        self.run_sanity_check(self.get_model())

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.optimizer_connector.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info('Trainer was signaled to stop but required minimum epochs'
                                 f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                                 ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        """Run the validation (or, with ``test_mode``, test) loop over every eval dataloader.

        Gradients are disabled and the model is in eval mode for the loop. Both
        are re-enabled afterwards.

        Returns:
            ``(eval_loop_results, eval_results)``, or ``([], [])`` when the
            evaluation loop decides to skip evaluation.
        """
        # bookkeeping
        self.evaluation_loop.testing = test_mode
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
            return [], []

        # enable eval mode + no grads
        model = self.get_model()
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        # TODO: should this be inside the dataloader loop?
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator_backend.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                # lightning module methods
                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
                self.evaluation_loop.log_step_metrics(output, batch_idx)

                # track epoch level metrics
                if output is not None:
                    dl_outputs.append(output)

            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

        # bookkeeping
        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
        self.evaluation_loop.predictions.to_disk()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # enable train mode again
        model.train()
        torch.set_grad_enabled(True)

        # hook
        self.evaluation_loop.on_evaluation_end()

        return eval_loop_results, eval_results

    def run_test(self):
        """Run the test loop and convert tensor values in result dicts to Python scalars.

        Returns:
            The list of per-dataloader results, or ``1`` when there are none.
        """
        # only load test dataloader for testing
        eval_loop_results, _ = self.run_evaluation(test_mode=True)

        if not eval_loop_results:
            return 1

        # remove the tensors from the eval results (in place; only existing keys are
        # overwritten, so mutating while iterating items() is safe)
        for result in eval_loop_results:
            if isinstance(result, dict):
                for k, v in result.items():
                    if isinstance(v, torch.Tensor):
                        result[k] = v.cpu().item()

        return eval_loop_results

    def train_or_test(self):
        """Dispatch to ``run_test`` when ``self.testing`` is truthy, else to ``train``."""
        if self.testing:
            results = self.run_test()
        else:
            results = self.train()
        return results

    def run_sanity_check(self, ref_model):
        """Run a few validation batches before training so val-loop errors surface early.

        This runs only if the model has a ``val_dataloader``, overrides
        ``validation_step``, and both ``num_sanity_val_steps`` and
        ``limit_val_batches`` are positive. The resulting callback metrics are
        stored on ``logger_connector``.
        """
        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
            self.num_sanity_val_batches = [
                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
            ]

            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

            # run eval step
            _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
                # when we get a list back, use only the last item
                if isinstance(eval_results, list):
                    eval_results = eval_results[-1]

                if isinstance(eval_results, EvalResult):
                    callback_metrics = eval_results.callback_metrics
                else:
                    _, _, _, callback_metrics, _ = self.process_output(eval_results)
                self.logger_connector.callback_metrics = callback_metrics

            self.on_sanity_check_end()
            self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the test loop on ``model``, or on the trainer's model loaded from ``ckpt_path``.

        Only global rank 0 runs the test; other ranks return ``None``.

        Raises:
            MisconfigurationException: if both ``test_dataloaders`` and
                ``datamodule`` are given.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose

        if self.global_rank != 0:
            return

        # If you supply a datamodule you can't also supply test_dataloaders
        if test_dataloaders and datamodule:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

        if model is not None:
            results = self.__test_given_model(model, test_dataloaders)
        else:
            results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Test the trainer's current model, optionally after loading weights from ``ckpt_path``.

        ``ckpt_path='best'`` resolves to ``checkpoint_callback.best_model_path``.
        ``ckpt_path=None`` tests the in-memory weights.

        Returns:
            The ``fit`` result in testing mode, or ``{}`` (with a warning) when no
            checkpoint path could be resolved.

        Raises:
            MisconfigurationException: if ``'best'`` is requested but no
                checkpoint callback is configured to keep a best model.
        """
        model = self.get_model()

        # if user requests the best checkpoint but we don't have it, error
        # (checkpoint_callback is None when checkpointing is disabled)
        if ckpt_path == 'best' and (self.checkpoint_callback is None or self.checkpoint_callback.save_top_k <= 0):
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path

            if not ckpt_path:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                )
                return {}

            # map_location keeps tensors on the device they are deserialized to (CPU storage)
            ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

        # run tests; fit() reads PL_TESTING_MODE to switch into testing.
        # Always restore the flag and env var so a failed test cannot leak
        # testing mode into a later fit() call.
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        try:
            results = self.fit(model)
        finally:
            self.testing = False
            os.environ.pop('PL_TESTING_MODE', None)

        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')

        return results

    def __test_given_model(self, model, test_dataloaders):
        """Test an explicitly supplied ``model`` by running ``fit`` in testing mode."""
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval; reset even if fit() raises
        self.testing = True
        self.model = model
        try:
            results = self.fit(model)
        finally:
            self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def call_setup_hook(self, model):
        """Call ``setup`` for the current stage (``'test'`` or ``'fit'``) on datamodule, trainer and model.

        The datamodule's ``setup`` is skipped if it already ran for this stage.
        """
        # call setup after the ddp process has connected
        stage_name = 'test' if self.testing else 'fit'
        if self.datamodule is not None:
            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
            if not called:
                self.datamodule.setup(stage_name)
        self.setup(stage_name)
        model.setup(stage_name)

    def call_hook(self, hook_name, *args, **kwargs):
        """Call ``hook_name`` on the trainer, then on the model or (as fallback) the accelerator.

        Everything runs inside ``self.profiler.profile(hook_name)``.

        Returns:
            The model hook's return value if the model overrides it. Otherwise
            the accelerator hook's return value if the accelerator has it.
            Otherwise ``None``. The trainer hook's return value is discarded.
        """
        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

            return output