def fit(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    datamodule: Optional[LightningDataModule] = None,
):
    r"""
    Runs the full optimization routine.

    Args:
        datamodule: An instance of :class:`LightningDataModule`.

        model: Model to fit.

        train_dataloader: A PyTorch DataLoader with training samples. If the model has
            a predefined train_dataloader method this will be skipped.

        val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
            If the model has a predefined val_dataloaders method this will be skipped.
    """
    # bookkeeping
    self._state = TrainerState.RUNNING

    # ----------------------------
    # LINK DATA
    # ----------------------------
    # setup data, etc...
    self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

    # hook
    self.data_connector.prepare_data(model)

    # bookkeeping
    # we reuse fit in .test() but change its behavior using this flag
    self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

    # ----------------------------
    # SET UP TRAINING
    # ----------------------------
    self.accelerator_backend = self.accelerator_connector.select_accelerator()
    self.accelerator_backend.setup(model)

    # ----------------------------
    # INSPECT THESE FOR MAIN LOOPS
    # ----------------------------
    # assign training and eval functions... inspect these to see the train and eval loops :)
    self.accelerator_backend.train_loop = self.train
    self.accelerator_backend.validation_loop = self.run_evaluation
    self.accelerator_backend.test_loop = self.run_evaluation

    # ----------------------------
    # TRAIN
    # ----------------------------
    # hook
    self.call_hook('on_fit_start')

    results = self.accelerator_backend.train()
    self.accelerator_backend.teardown()

    # ----------------------------
    # POST-Training CLEAN UP
    # ----------------------------
    # hook
    self.call_hook('on_fit_end')

    # hook
    self.teardown('fit')
    if self.is_function_implemented('teardown'):
        model.teardown('fit')

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    if self._state != TrainerState.INTERRUPTED:
        self._state = TrainerState.FINISHED
    return results or 1
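# Hedged usage sketch (not from the source): this version of `fit` accepts a
# `LightningDataModule` in place of raw dataloaders. `MyModel` and
# `MyDataModule` are hypothetical user-defined classes.
trainer = Trainer()
model = MyModel()
datamodule = MyDataModule()
trainer.fit(model, datamodule=datamodule)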
def fit(self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[DataLoader] = None,
        test_dataloaders: Optional[DataLoader] = None):
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloader: A PyTorch DataLoader with training samples. If the model has
            a predefined train_dataloader method this will be skipped.

        val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
            If the model has a predefined val_dataloaders method this will be skipped.

        test_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying test samples.
            If the model has a predefined test_dataloaders method this will be skipped.

    Example::

        # Option 1,
        # Define the train_dataloader(), test_dataloader() and val_dataloader() fxs
        # in the lightningModule
        # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model)

        # Option 2
        # in production cases we might want to pass different datasets to the same model
        # Recommended for PRODUCTION SYSTEMS
        train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model, train_dataloader=train, val_dataloaders=val, test_dataloaders=test)

        # Option 1 & 2 can be mixed, for example the training set can be
        # defined as part of the model, and validation/test can then be
        # fed to .fit()
    """
    # bind logger and other properties
    model.logger = self.logger
    self.copy_trainer_model_properties(model)

    # set up the passed in dataloaders (if needed)
    self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

    # download the data and do whatever transforms we need
    # do before any spawn calls so that the model can assign properties
    # only on proc 0 because no spawn has happened yet
    model.prepare_data()

    # route to appropriate start method
    # when using multi-node or DDP within a node start each module in a separate process
    if self.use_ddp2:
        task = int(os.environ['SLURM_LOCALID'])
        self.ddp_train(task, model)

    elif self.use_ddp:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)
        else:
            self.__set_random_port()

            # track for predict
            self.model = model

            # train
            mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

            # load weights if not interrupted
            self.load_spawn_weights(model)
            self.model = model

    # 1 gpu or dp option triggers training using DP module
    # easier to avoid NCCL issues
    elif self.use_dp:
        self.dp_train(model)

    elif self.single_gpu:
        self.single_gpu_train(model)

    elif self.use_tpu:  # pragma: no-cover
        log.info(f'training on {self.num_tpu_cores} TPU cores')

        # COLAB_GPU is an env var available by default in Colab environments.
        start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'

        # track for predict
        self.model = model

        # train
        xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

        # load weights if not interrupted
        self.load_spawn_weights(model)
        self.model = model

    # ON CPU
    else:
        # run through amp wrapper
        if self.use_amp:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
            self.init_optimizers(model.configure_optimizers())

        self.run_pretrain_routine(model)

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    return 1
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
    optim_conf = model.configure_optimizers()

    if optim_conf is None:
        rank_zero_warn(
            '`LightningModule.configure_optimizers` returned `None`, '
            'this fit will run with no optimizer', UserWarning)
        optim_conf = _MockOptimizer()

    # single output, single optimizer
    if isinstance(optim_conf, Optimizer):
        return [optim_conf], [], []

    # two lists, optimizer + lr schedulers
    elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
            and isinstance(optim_conf[0], list):
        optimizers, lr_schedulers = optim_conf
        lr_schedulers = self.configure_schedulers(lr_schedulers)
        return optimizers, lr_schedulers, []

    # single dictionary
    elif isinstance(optim_conf, dict):
        optimizer = optim_conf["optimizer"]
        monitor = optim_conf.get('monitor', None)
        lr_scheduler = optim_conf.get("lr_scheduler", [])
        if lr_scheduler:
            lr_schedulers = self.configure_schedulers([lr_scheduler], monitor)
        else:
            lr_schedulers = []
        return [optimizer], lr_schedulers, []

    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        # take the lr_scheduler only if it exists and is not None
        lr_schedulers = [
            opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
        ]
        # take the frequency only if it exists and is not None
        optimizer_frequencies = [
            opt_dict["frequency"] for opt_dict in optim_conf
            if opt_dict.get("frequency") is not None
        ]

        # clean scheduler list
        if lr_schedulers:
            lr_schedulers = self.configure_schedulers(lr_schedulers)

        # assert that if frequencies are present, they are given for all optimizers
        if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
            raise ValueError("A frequency must be given to each optimizer.")

        return optimizers, lr_schedulers, optimizer_frequencies

    # single list or tuple, multiple optimizer
    elif isinstance(optim_conf, (list, tuple)):
        return list(optim_conf), [], []

    # unknown configuration
    else:
        raise ValueError(
            'Unknown configuration for model optimizers.'
            ' Output from `model.configure_optimizers()` should either be:'
            ' * single output, single `torch.optim.Optimizer`'
            ' * single output, list of `torch.optim.Optimizer`'
            ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
            ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
            ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
            ' a list of `torch.optim.lr_scheduler`'
            ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)'
        )
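# Hedged sketch (not from the source) of the `configure_optimizers` return
# formats that `init_optimizers` above accepts. Only `configure_optimizers`
# is actually called by Lightning; the extra method names and the module
# attributes are illustrative assumptions showing alternative return shapes.
import torch

class ExampleModule:
    def __init__(self):
        self.layer = torch.nn.Linear(4, 1)

    # format 1: a single optimizer
    def configure_optimizers(self):
        return torch.optim.Adam(self.layer.parameters(), lr=1e-3)

    # format 2: two lists, optimizers + lr schedulers
    def configure_optimizers_with_scheduler(self):
        opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
        return [opt], [sch]

    # format 3: a single dictionary with an optional `lr_scheduler` key
    def configure_optimizers_as_dict(self):
        opt = torch.optim.Adam(self.layer.parameters(), lr=1e-3)
        sch = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)
        return {'optimizer': opt, 'lr_scheduler': sch}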
def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
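# Hedged usage sketch (not from the source): the same range test can also be
# triggered through the Trainer's tuning entry point. `MyModelClass` is
# hypothetical, and the flags assume a Lightning version where `auto_lr_find`
# and `trainer.tune` exist.
import pytorch_lightning as pl

model = MyModelClass(hparams)
trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(model)  # runs the lr finder and writes the suggestion to the model's lr attribute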
def fit(
    self,
    model: LightningModule,
    train_dataloader: Any = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    datamodule: Optional[LightningDataModule] = None,
):
    r"""
    Runs the full optimization routine.

    Args:
        datamodule: An instance of :class:`LightningDataModule`.

        model: Model to fit.

        train_dataloader: Either a single PyTorch DataLoader or a collection of these
            (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
            see this :ref:`page <multiple-training-dataloaders>`

        val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
            If the model has a predefined val_dataloaders method this will be skipped.
    """
    # bookkeeping
    self._state = TrainerState.RUNNING

    # bookkeeping
    # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
    if self._running_stage is None:
        self._running_stage = RunningStage.TRAINING

    # set local properties on the model
    self.model_connector.copy_trainer_model_properties(model)

    # ----------------------------
    # LINK DATA
    # ----------------------------
    # setup data, etc...
    self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

    # hook
    self.data_connector.prepare_data(model)
    self.callback_connector._attach_model_callbacks(model, self)

    # ----------------------------
    # SET UP TRAINING
    # ----------------------------
    self.call_setup_hook(model)
    self.call_hook("on_before_accelerator_backend_setup", model)
    self.accelerator.setup(self, model)
    self.setup_trainer(model)

    # ----------------------------
    # INSPECT THE CORE LOOPS
    # ----------------------------
    # Lightning internal flow looks like this.
    #
    # trainer.fit(...) or trainer.test(...) or trainer.predict(...)  ||
    #                               |                                ||
    #                      create accelerator                        ||
    #                               |                                ||
    #                       trainer.dispatch                         ||  LIGHTNING
    #                               |                                ||
    #  start_training or start_testing or start_predicting call      ||  FLOW
    #                      from `accelerator`                        ||
    #                               |                                ||  DIRECTION
    #         run_train or run_test or run_predict call              ||
    #                       from `trainer`                           ||
    #                               |                                ||
    #                            results                             \/
    # This is used to guide readers to the core loops: train, test, predict.
    # `run_predict` is the simplest to understand, use `Go to Definition` to read it :)
    # Search for `start_training` or `start_testing` or `start_predicting` in
    # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions.
    self.accelerator.train_loop = self.run_train
    self.accelerator.validation_loop = self.run_evaluation
    self.accelerator.test_loop = self.run_evaluation
    self.accelerator.predict_loop = self.run_predict

    # ----------------------------
    # TRAIN
    # ----------------------------
    # hook
    self.call_hook("on_fit_start")

    # plugin will setup fitting (e.g. ddp will launch child processes)
    self.pre_dispatch()

    # dispatch `start_training` or `start_testing` or `start_predicting`
    self.dispatch()

    # plugin will finalize fitting (e.g. ddp_spawn will load trained model)
    self.post_dispatch()

    # ----------------------------
    # POST-Training CLEAN UP
    # ----------------------------
    # hook
    self.call_hook('on_fit_end')

    # hook
    self.teardown('fit')
    if self.is_function_implemented('teardown'):
        model.teardown('fit')

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    if self._state != TrainerState.INTERRUPTED:
        self._state = TrainerState.FINISHED

    self._running_stage = None

    return self.accelerator.results or 1
def backward(self, loss, optimizer, optimizer_idx):
    return LightningModule.backward(self, loss, optimizer, optimizer_idx)
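# Hedged sketch (not from the source) of a user-side `backward` override with
# the same signature as above; overriding to pass `retain_graph` is an
# illustrative reason, not the library's default behavior.
from pytorch_lightning import LightningModule

class CustomBackwardModule(LightningModule):
    def backward(self, loss, optimizer, optimizer_idx):
        # keep the autograd graph alive, e.g. for a second backward pass elsewhere
        loss.backward(retain_graph=True)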
def _evaluate(
    self,
    model: LightningModule,
    dataloaders: List[DataLoader],
    max_batches: Union[int, List[int]],
    test_mode: bool = False
):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer or list of integers with length of the number of dataloaders. Each
            entry is the number of batches to process in the corresponding dataloader.
        test_mode:
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []

    # convert max_batches to list
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= dl_max_batches:
                break

            # callbacks
            if test_mode:
                self.on_test_batch_start()
            else:
                self.on_validation_batch_start()

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            if self.use_amp and NATIVE_AMP_AVALAIBLE:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            # on dp / ddp2 might still want to do something with the batch parts
            if test_mode:
                if self.is_overridden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
                self.on_test_batch_end()
            else:
                if self.is_overridden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)
                self.on_validation_batch_end()

            # track outputs for collation
            dl_outputs.append(output)

        outputs.append(dl_outputs)

    eval_results = {}

    # with a single dataloader don't pass an array
    if len(dataloaders) == 1:
        outputs = outputs[0]

    # give model a chance to do something with the outputs (and method defined)
    if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
        model = model.module

    if test_mode:
        if self.is_overridden('test_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.test_end(outputs)
            rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                           ' Use `test_epoch_end` instead.', DeprecationWarning)

        elif self.is_overridden('test_epoch_end', model=model):
            eval_results = model.test_epoch_end(outputs)

    else:
        if self.is_overridden('validation_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.validation_end(outputs)
            rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                           ' Use `validation_epoch_end` instead.', DeprecationWarning)

        elif self.is_overridden('validation_epoch_end', model=model):
            eval_results = model.validation_epoch_end(outputs)

    # aggregate ddp stats across processes
    has_content = eval_results is not None and len(eval_results) > 0
    if has_content and (self.use_ddp or self.use_ddp2):
        self.reduce_eval_ddp(eval_results)

    # enable train mode again
    model.train()

    # re-enable gradients
    torch.set_grad_enabled(True)

    return eval_results
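# Hedged sketch (not from the source) of the model-side hooks the `_evaluate`
# loop above drives: `validation_step` produces per-batch outputs, which are
# collated into `outputs` and handed to `validation_epoch_end`. The metric
# names are illustrative.
import torch
from pytorch_lightning import LightningModule

class ExampleValModule(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        # average the per-batch losses collected by the evaluation loop
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        return {'val_loss': avg_loss}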
def lr_find(self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[DataLoader] = None,
            min_lr: float = 1e-8,
            max_lr: float = 1,
            num_training: int = 100,
            mode: str = 'exponential',
            early_stop_threshold: float = 4.0,
            num_accumulation_steps=None):
    r"""
    lr_find enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch DataLoader with training samples. If the model has
            a predefined train_dataloader method this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        num_accumulation_steps: deprecated, number of batches to calculate loss over.
            Set trainer argument ``accumulate_grad_batches`` instead.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)
    """
    if num_accumulation_steps is not None:
        rank_zero_warn("Argument `num_accumulation_steps` has been deprecated"
                       " since v0.7.6 and will be removed in 0.9. Please"
                       " set trainer argument `accumulate_grad_batches` instead.",
                       DeprecationWarning)

    save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

    self.__lr_finder_dump_params(model)

    # Prevent going into infinite loop
    self.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    self.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    self.logger = DummyLogger()

    # Max step set to number of iterations
    self.max_steps = num_training

    # Disable standard progress bar for fit
    if self.progress_bar_callback:
        self.progress_bar_callback.disable()

    # Disable standard checkpoint & early stopping
    self.checkpoint_callback = False
    self.early_stop_callback = None
    self.enable_early_stop = False

    # Required for saving the model
    self.optimizers, self.schedulers = [], []
    self.model = model

    # Dump model checkpoint
    self.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    optimizers, _, _ = self.init_optimizers(model)

    if len(optimizers) != 1:
        raise MisconfigurationException(
            f'`model.configure_optimizers()` returned {len(optimizers)}, but'
            ' learning rate finder only works with single optimizer')
    model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

    # Fit, lr & loss logged in callback
    self.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)

    # Prompt if we stopped early
    if self.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': self.callbacks[0].lrs, 'loss': self.callbacks[0].losses})
    lr_finder._total_batch_idx = self.total_batch_idx  # for debug purpose

    # Reset model state
    self.restore(str(save_path), on_gpu=self.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    self.__lr_finder_restore_params(model)
    if self.progress_bar_callback:
        self.progress_bar_callback.enable()

    return lr_finder
def call_teardown_hook(self, model: LightningModule) -> None:
    state = self._teardown_state
    self.profiler.teardown(stage=state)
    self.teardown(stage=state)
    model.teardown(stage=state)
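# Hedged sketch (not from the source) of a model-side `teardown` hook that
# `call_teardown_hook` above invokes; the cleanup performed is illustrative,
# and the single `stage` argument assumes the hook signature used here.
from pytorch_lightning import LightningModule

class CleanupModule(LightningModule):
    def teardown(self, stage):
        # release anything acquired in `setup`, e.g. cached datasets
        if stage == 'fit':
            self.train_data = None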
def fit(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloader: A PyTorch DataLoader with training samples. If the model has
            a predefined train_dataloader method this will be skipped.

        val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
            If the model has a predefined val_dataloaders method this will be skipped.

    Example::

        # Option 1,
        # Define the train_dataloader() and val_dataloader() fxs
        # in the lightningModule
        # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model)

        # Option 2
        # in production cases we might want to pass different datasets to the same model
        # Recommended for PRODUCTION SYSTEMS
        train, val = DataLoader(...), DataLoader(...)
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model, train_dataloader=train, val_dataloaders=val)

        # Option 1 & 2 can be mixed, for example the training set can be
        # defined as part of the model, and validation can then be fed to .fit()
    """
    # bind logger and other properties
    model.logger = self.logger
    self.copy_trainer_model_properties(model)

    # clean hparams
    if hasattr(model, 'hparams'):
        parsing.clean_namespace(model.hparams)

    # set up the passed in dataloaders (if needed)
    self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

    # check that model is configured correctly
    self.check_model_configuration(model)

    # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
    # or in the case where each node needs to do its own manipulation in which case just local_rank=0
    if self.can_prepare_data():
        model.prepare_data()
        self._is_data_prepared = True

    # Run auto batch size scaling
    if self.auto_scale_batch_size:
        if isinstance(self.auto_scale_batch_size, bool):
            self.auto_scale_batch_size = 'power'
        self.scale_batch_size(model, mode=self.auto_scale_batch_size)
        model.logger = self.logger  # reset logger binding

    # Run learning rate finder:
    if self.auto_lr_find:
        self._run_lr_finder_internally(model)
        model.logger = self.logger  # reset logger binding

    # route to appropriate start method
    # when using multi-node or DDP within a node start each module in a separate process
    if self.use_ddp2:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
        # torchelastic or general non_slurm ddp2
        elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
            task = int(os.environ['LOCAL_RANK'])
        self.ddp_train(task, model)

    elif self.use_ddp:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        # torchelastic or general non_slurm ddp
        elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
            task = int(os.environ['LOCAL_RANK'])
            self.ddp_train(task, model)

        elif self.distributed_backend == 'cpu_ddp':
            self.__set_random_port()
            self.model = model
            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

        elif self.distributed_backend == 'ddp_spawn':
            model.share_memory()

            # spin up peers
            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

        elif self.distributed_backend == 'ddp':
            self.spawn_ddp_children(model)

    # 1 gpu or dp option triggers training using DP module
    # easier to avoid NCCL issues
    elif self.use_dp:
        self.dp_train(model)

    elif self.use_horovod:
        self.horovod_train(model)

    elif self.single_gpu:
        self.single_gpu_train(model)

    elif self.use_tpu:  # pragma: no-cover
        rank_zero_info(f'training on {self.tpu_cores} TPU cores')

        # COLAB_GPU is an env var available by default in Colab environments.
        start_method = 'fork' if self.on_colab_kaggle else 'spawn'

        # track for predict
        self.model = model

        # train
        if self.tpu_id is not None:
            self.tpu_train(self.tpu_id, model)
        else:
            xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

        # load weights if not interrupted
        self.load_spawn_weights(model)
        self.model = model

    # ON CPU
    else:
        # run through amp wrapper
        if self.use_amp:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        self.run_pretrain_routine(model)

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    return 1
def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
    """Run evaluation code.

    Args:
        model: PT model
        dataloaders: list of PT dataloaders
        max_batches: Scalar
        test_mode:
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= max_batches:
                break

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            # on dp / ddp2 might still want to do something with the batch parts
            if test_mode:
                if self.is_overriden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
            else:
                if self.is_overriden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)

            # track outputs for collation
            dl_outputs.append(output)

            # batch done
            if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
                if test_mode:
                    self.test_progress_bar.update(self.progress_bar_refresh_rate)
                else:
                    self.val_progress_bar.update(self.progress_bar_refresh_rate)
                    self.main_progress_bar.update(self.progress_bar_refresh_rate)

        outputs.append(dl_outputs)

    eval_results = {}

    # with a single dataloader don't pass an array
    if len(dataloaders) == 1:
        outputs = outputs[0]

    # give model a chance to do something with the outputs (and method defined)
    if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
        model = model.module

    if test_mode:
        if self.is_overriden('test_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.test_end(outputs)
            rank_zero_warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
                           ' Use `test_epoch_end` instead.', DeprecationWarning)

        elif self.is_overriden('test_epoch_end', model=model):
            eval_results = model.test_epoch_end(outputs)

    else:
        if self.is_overriden('validation_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.validation_end(outputs)
            rank_zero_warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
                           ' Use `validation_epoch_end` instead.', DeprecationWarning)

        elif self.is_overriden('validation_epoch_end', model=model):
            eval_results = model.validation_epoch_end(outputs)

    # enable train mode again
    model.train()

    # re-enable gradients
    torch.set_grad_enabled(True)

    return eval_results
def fit(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    datamodule: Optional[LightningDataModule] = None,
):
    r"""
    Runs the full optimization routine.

    Args:
        datamodule: An instance of :class:`LightningDataModule`.

        model: Model to fit.

        train_dataloader: A PyTorch DataLoader with training samples. If the model has
            a predefined train_dataloader method this will be skipped.

        val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
            If the model has a predefined val_dataloaders method this will be skipped.
    """
    self._state = TrainerState.RUNNING

    # setup data, etc...
    self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

    # hook
    self.data_connector.prepare_data(model)

    # set testing if set in environ
    self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

    # -------------------------
    # TRAIN
    # -------------------------
    self.accelerator_backend = self.accelerator_connector.select_accelerator()
    self.accelerator_backend.setup(model)

    # hook
    self.call_hook('on_fit_start')

    results = self.accelerator_backend.train()
    self.accelerator_backend.teardown()

    # -------------------------
    # POST-Training
    # -------------------------
    # hook
    self.call_hook('on_fit_end')

    # hook
    self.teardown('fit')
    if self.is_function_implemented('teardown'):
        model.teardown('fit')

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    if self._state != TrainerState.INTERRUPTED:
        self._state = TrainerState.FINISHED
    return results or 1
def on_epoch_end(self, trainer: Trainer, model: LightningModule):
    metrics = self.get_metrics(trainer, model)
    assert metrics["foo"] == self.trainer.current_epoch
    assert metrics["foo_2"] == self.trainer.current_epoch
    model.on_epoch_end_called = True
def _evaluate(
    self,
    model: LightningModule,
    dataloaders: List[DataLoader],
    max_batches: Union[int, List[int]],
    test_mode: bool = False
):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer or list of integers with length of the number of dataloaders. Each
            entry is the number of batches to process in the corresponding dataloader.
        test_mode:
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []

    # convert max_batches to list
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # --------------------------
    # ON_EVAL_EPOCH_START hook
    # --------------------------
    self.__call_eval_loop_hook_start(test_mode)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= dl_max_batches:
                break

            # callbacks
            if test_mode:
                self.on_test_batch_start()
            else:
                self.on_validation_batch_start()

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            # allow only EvalResult when using structured results (from val_step)
            if isinstance(output, Result) and not isinstance(output, EvalResult):
                m = 'only EvalResults or dicts are allowed from validation_step'
                raise MisconfigurationException(m)

            # on dp / ddp2 might still want to do something with the batch parts
            if test_mode:
                if self.is_overridden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
                self.on_test_batch_end()
            else:
                if self.is_overridden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)
                self.on_validation_batch_end()

            # track outputs for collation
            if output is not None:
                dl_outputs.append(output)

            self.__eval_add_step_metrics(output)

        outputs.append(dl_outputs)

    # ---------------------
    # EVAL_EPOCH_END
    # ---------------------
    using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
    eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)

    # log callback metrics
    self.__update_callback_metrics(eval_results, using_eval_result)

    # enable train mode again
    model.train()

    # re-enable gradients
    torch.set_grad_enabled(True)

    # --------------------------
    # ON_EVAL_EPOCH_END hook
    # --------------------------
    self.__call_eval_loop_hook_end(test_mode)

    return eval_results
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
    optim_conf = model.configure_optimizers()
    if optim_conf is None:
        rank_zero_warn(
            '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
            UserWarning,
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, optimizer_frequencies = [], [], []
    monitor = None

    # single output, single optimizer
    if isinstance(optim_conf, Optimizer):
        optimizers = [optim_conf]
    # two lists, optimizer + lr schedulers
    elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
        opt, sch = optim_conf
        optimizers = opt
        lr_schedulers = sch if isinstance(sch, list) else [sch]
    # single dictionary
    elif isinstance(optim_conf, dict):
        optimizers = [optim_conf["optimizer"]]
        monitor = optim_conf.get('monitor', None)
        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict]
        optimizer_frequencies = [
            opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
        ]
        # assert that if frequencies are present, they are given for all optimizers
        if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
            raise ValueError("A frequency must be given to each optimizer.")
    # single list or tuple, multiple optimizer
    elif isinstance(optim_conf, (list, tuple)):
        optimizers = list(optim_conf)
    # unknown configuration
    else:
        raise MisconfigurationException(
            'Unknown configuration for model optimizers.'
            ' Output from `model.configure_optimizers()` should either be:\n'
            ' * `torch.optim.Optimizer`\n'
            ' * [`torch.optim.Optimizer`]\n'
            ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
            ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
            ' * A list of the previously described dict format, with an optional "frequency" key (int)'
        )

    lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
    _validate_scheduler_optimizer(optimizers, lr_schedulers)

    return optimizers, lr_schedulers, optimizer_frequencies
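# Hedged sketch (not from the source) of the "multiple dictionaries" format
# that this version of `init_optimizers` parses, e.g. for alternating
# GAN-style updates. `self.generator` and `self.discriminator` are
# illustrative submodules.
import torch
from pytorch_lightning import LightningModule

class GANExample(LightningModule):
    def configure_optimizers(self):
        gen_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        dis_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [
            # alternate: one generator batch, then five discriminator batches
            {'optimizer': gen_opt, 'frequency': 1},
            {'optimizer': dis_opt, 'frequency': 5,
             'lr_scheduler': torch.optim.lr_scheduler.StepLR(dis_opt, step_size=10)},
        ]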