class Closure(AbstractClosure):
    """An implementation of :class:`AbstractClosure` for optimization in Lightning that combines three
    elementary closures into one: ``training_step``, ``backward`` and ``zero_grad``.

    The Closure is created by the training loop(s) and is then passed to the
    :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and
    optionally doing something with its output.

    Args:
        step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step`
            wrapped with processing for its outputs. It must return either a mapping of keyword arguments
            used to build a :class:`ClosureResult`, or a falsy value (e.g. ``None``) when there is nothing
            to optimize.
        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the
            loss value. Can be set to ``None`` to skip the backward operation.
        zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
            when accumulating gradients.
        profiler: A profiler for profiling the actions of the passed in closure functions. When ``None``, a
            :class:`PassThroughProfiler` is used.

    Example::

        closure = Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)
        optimizer = torch.optim.Adam(...)
        optimizer.step(closure)
    """

    # Shared across instances so the "training_step returned None" message is cached once per process.
    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], Optional[Dict]],
        backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
        profiler: Optional[BaseProfiler] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn
        self._profiler = PassThroughProfiler() if profiler is None else profiler

    def closure(self, *args: Any, **kwargs: Any) -> Optional[ClosureResult]:
        """Run the step, then ``zero_grad``, then ``backward``, inside one profiling section.

        ``*args`` and ``**kwargs`` are accepted for interface compatibility and are not used.

        Returns:
            The :class:`ClosureResult` built from ``step_fn``'s output, whose ``closure_loss`` is replaced by
            the value returned from ``backward_fn`` when backward ran; ``None`` if ``step_fn`` returned a
            falsy value.
        """
        with self._profiler.profile("training_step_and_backward"):
            step_output = self._step_fn()
            # A falsy result (``None`` or an empty mapping) means the step produced nothing to optimize.
            step_output = ClosureResult(**step_output) if step_output else None

            if step_output is None:
                self.warning_cache.warn(
                    "training_step returned None. If this was on purpose, ignore this warning..."
                )

            # Gradients are zeroed after the forward step and before backward; this also runs when the
            # step returned None.
            if self._zero_grad_fn is not None:
                with self._profiler.profile("zero_grad"):
                    self._zero_grad_fn()

            if self._backward_fn is not None and step_output is not None and step_output.closure_loss is not None:
                with self._profiler.profile("backward"):
                    step_output.closure_loss = self._backward_fn(step_output.closure_loss)

        return step_output
def __init__(
    self,
    step_fn: Callable[[], dict],
    backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
    zero_grad_fn: Optional[Callable[[], None]] = None,
    profiler: Optional[BaseProfiler] = None,
):
    """Store the callables that the closure will invoke.

    Args:
        step_fn: zero-argument callable running the training step; annotated to return a ``dict``.
        backward_fn: optional callable taking a ``Tensor`` and returning a ``Tensor``; may be ``None``.
        zero_grad_fn: optional zero-argument callable; may be ``None``.
        profiler: profiler kept on the instance; a new ``PassThroughProfiler`` is created when ``None``.
    """
    super().__init__()
    self._step_fn, self._backward_fn, self._zero_grad_fn = step_fn, backward_fn, zero_grad_fn
    if profiler is None:
        profiler = PassThroughProfiler()
    self._profiler = profiler
def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
    """Turn the ``profiler`` argument of ``Trainer`` into a profiler stored on ``self.trainer.profiler``.

    Resolution rules:
      * ``True`` (deprecated) -> ``SimpleProfiler()``; ``False`` (deprecated) -> ``PassThroughProfiler()``.
      * a ``str`` -> lower-cased, looked up in ``PROFILERS`` and the found class is instantiated.
      * a ``BaseProfiler`` instance -> used as is.
      * ``None`` -> ``PassThroughProfiler()``.

    Raises:
        MisconfigurationException: for a truthy value that is not a bool, str or ``BaseProfiler``.
        ValueError: for a string whose lower-cased form is not a key of ``PROFILERS``.
    """
    accepted_types = (bool, str, BaseProfiler)
    if profiler and not isinstance(profiler, accepted_types):
        # TODO: Update exception on removal of bool
        raise MisconfigurationException(
            "Only None, bool, str and subclasses of `BaseProfiler`"
            " are valid values for `Trainer`'s `profiler` parameter."
            f" Received {profiler} which is of type {type(profiler)}."
        )

    if isinstance(profiler, bool):
        rank_zero_warn(
            "Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
            " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
            DeprecationWarning,
        )
        if profiler:
            profiler = SimpleProfiler()
    elif isinstance(profiler, str):
        lookup_key = profiler.lower()
        if lookup_key not in PROFILERS:
            raise ValueError(
                "When passing string value for the `profiler` parameter of"
                " `Trainer`, it can only be 'simple' or 'advanced'"
            )
        profiler = PROFILERS[lookup_key]()

    # ``False``/``None`` fall through to the no-op profiler.
    self.trainer.profiler = profiler or PassThroughProfiler()
def build_profiler(name):
    """Create a profiler from its short name.

    Args:
        name: ``'inference'``, ``'pytorch'`` or ``None`` (compared exactly, case-sensitive).

    Returns:
        ``PassThroughProfiler()`` for ``None``; ``InferenceProfiler()`` for ``'inference'``; for
        ``'pytorch'``, a ``PyTorchProfiler`` with ``use_cuda=True``, ``profile_memory=True`` and
        ``row_limit=100``.

    Raises:
        ValueError: for any other ``name``.
    """
    if name is None:
        return PassThroughProfiler()
    if name == 'inference':
        return InferenceProfiler()
    if name != 'pytorch':
        raise ValueError(f'Invalid profiler: {name}')
    # The import is deferred to this branch, so pytorch_lightning.profiler is only imported when requested.
    from pytorch_lightning.profiler import PyTorchProfiler
    return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
def build_profiler(name):
    """Map a profiler name to a profiler instance.

    ``'inference'`` gives ``InferenceProfiler()``, ``'pytorch'`` gives a ``PyTorchProfiler`` configured with
    ``use_cuda=True, profile_memory=True, row_limit=100``, and ``None`` gives ``PassThroughProfiler()``.
    Names are compared exactly (case-sensitive).

    Raises:
        ValueError: when ``name`` is none of the above.
    """
    if name == 'inference':
        profiler = InferenceProfiler()
    elif name == 'pytorch':
        from pytorch_lightning.profiler import PyTorchProfiler
        # TODO(zehong): this profiler will be introduced after upgrading the pl dependency to 1.3.0
        profiler = PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    elif name is None:
        profiler = PassThroughProfiler()
    else:
        raise ValueError(f'Invalid profiler: {name}')
    return profiler
def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
    """Resolve ``Trainer``'s ``profiler`` argument and store the result on ``self.trainer.profiler``.

    A ``str`` is lower-cased and looked up in ``PROFILERS`` (the found class is instantiated); a
    ``BaseProfiler`` instance is used as is; ``None`` yields ``PassThroughProfiler()``.

    Raises:
        ValueError: for a string whose lower-cased form is not a key of ``PROFILERS``.
        MisconfigurationException: for a truthy value that is neither a str nor a ``BaseProfiler``.
    """
    if isinstance(profiler, str):
        registry_key = profiler.lower()
        if registry_key not in PROFILERS:
            raise ValueError(
                "When passing string value for the `profiler` parameter of"
                " `Trainer`, it can only be 'simple' or 'advanced'"
            )
        profiler = PROFILERS[registry_key]()
    elif profiler and not isinstance(profiler, BaseProfiler):
        raise MisconfigurationException(
            "Only None, str and subclasses of `BaseProfiler`"
            " are valid values for `Trainer`'s `profiler` parameter."
            f" Received {profiler} which is of type {type(profiler)}."
        )

    self.trainer.profiler = profiler or PassThroughProfiler()
def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
        callbacks: Optional[List[Callback]] = None,
        default_save_path: Optional[str] = None,
        gradient_clip_val: float = 0,
        gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
        process_position: int = 0,
        nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
        num_nodes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        num_tpu_cores: Optional[int] = None,
        log_gpu_memory: Optional[str] = None,
        show_progress_bar: bool = True,
        progress_bar_refresh_rate: int = 50,
        overfit_pct: float = 0.0,
        track_grad_norm: int = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        val_percent_check: float = 1.0,
        test_percent_check: float = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 10,
        add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
        distributed_backend: Optional[str] = None,
        use_amp=False,  # backward compatible, todo: remove in v0.8.0
        precision: int = 32,
        print_nan_grads: bool = False,
        weights_summary: Optional[str] = 'full',
        weights_save_path: Optional[str] = None,
        amp_level: str = 'O1',
        nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
        num_sanity_val_steps: int = 5,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        reload_dataloaders_every_epoch: bool = False,
):
    r"""
    Customize every aspect of training via flags

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.

            Example::

                from pytorch_lightning.loggers import TensorBoardLogger

                # default logger used by trainer
                logger = TensorBoardLogger(
                    save_dir=os.getcwd(),
                    version=self.slurm_job_id,
                    name='lightning_logs'
                )

                Trainer(logger=logger)

        checkpoint_callback: Callback for checkpointing.

            Example::

                from pytorch_lightning.callbacks import ModelCheckpoint

                # default used by the Trainer
                checkpoint_callback = ModelCheckpoint(
                    filepath=os.getcwd(),
                    save_best_only=True,
                    verbose=True,
                    monitor='val_loss',
                    mode='min',
                    prefix=''
                )

                trainer = Trainer(checkpoint_callback=checkpoint_callback)

        early_stop_callback: Callback for early stopping.
            If set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
            Will raise an error if ``'val_loss'`` is not found.
            If set to ``False``, then early stopping will be disabled.
            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
            If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
            Default: ``None``.

            Example::

                from pytorch_lightning.callbacks import EarlyStopping

                # default used by the Trainer
                early_stop_callback = EarlyStopping(
                    monitor='val_loss',
                    patience=3,
                    strict=False,
                    verbose=False,
                    mode='min'
                )

                trainer = Trainer(early_stop_callback=early_stop_callback)

        callbacks: Add a list of callbacks. Default: ``None``, which gives every Trainer its own
            empty list.

            Example::

                from pytorch_lightning.callbacks import Callback

                class PrintCallback(Callback):
                    def on_train_start(self):
                        print("Training is started!")

                    def on_train_end(self):
                        print(f"Training is done. The logs are: {self.trainer.logs}")

                # a list of callbacks
                callbacks = [PrintCallback()]
                trainer = Trainer(callbacks=callbacks)

        default_save_path: Default path for logs and weights when no logger/ckpt_callback passed.
            When ``None``, ``os.getcwd()`` is used.

            Example::

                # default used by the Trainer
                trainer = Trainer(default_save_path=os.getcwd())

        gradient_clip_val: 0 means don't clip.

            Example::

                # default used by the Trainer
                trainer = Trainer(gradient_clip_val=0.0)

        gradient_clip:
            .. warning:: .. deprecated:: 0.5.0

                Use `gradient_clip_val` instead. Will remove 0.8.0.

        process_position: orders the tqdm bar when running multiple models on the same machine.

            Example::

                # default used by the Trainer
                trainer = Trainer(process_position=0)

        num_nodes: number of GPU nodes for distributed training.

            Example::

                # default used by the Trainer
                trainer = Trainer(num_nodes=1)

                # to train on 8 nodes
                trainer = Trainer(num_nodes=8)

        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.5.0

                Use `num_nodes` instead. Will remove 0.8.0.

        gpus: Which GPUs to train on.

            Example::

                # default used by the Trainer (ie: train on CPU)
                trainer = Trainer(gpus=None)

                # int: train on 2 gpus
                trainer = Trainer(gpus=2)

                # list: train on GPUs 1, 4 (by bus ordering)
                trainer = Trainer(gpus=[1, 4])
                trainer = Trainer(gpus='1, 4')  # equivalent

                # -1: train on all gpus
                trainer = Trainer(gpus=-1)
                trainer = Trainer(gpus='-1')  # equivalent

                # combine with num_nodes to train on multiple GPUs across nodes
                trainer = Trainer(gpus=2, num_nodes=4)  # uses 8 gpus in total

        num_tpu_cores: How many TPU cores to train on (1 or 8).
            A single TPU v2 or v3 has 8 cores. A TPU pod has up to 2048 cores.
            A slice of a POD means you get as many cores as you request.

            You MUST use DistributedDataSampler with your dataloader for this to work.
            Your effective batch size is batch_size * total tpu cores.

            This parameter can only be ``None``, 1 or 8; any other value fails an assertion.

            Example::

                # your_trainer_file.py

                # default used by the Trainer (ie: train on CPU)
                trainer = Trainer(num_tpu_cores=None)

                # int: train on a single core
                trainer = Trainer(num_tpu_cores=1)

                # int: train on all cores
                trainer = Trainer(num_tpu_cores=8)

                # for 8+ cores must submit via xla script with
                # a max of 8 cores specified. The XLA script
                # will duplicate script onto each TPU in the POD
                trainer = Trainer(num_tpu_cores=8)

            To train on more than 8 cores (ie: a POD),
            submit this script using the xla_dist script.

            Example::

                $ python -m torch_xla.distributed.xla_dist
                --tpu=$TPU_POD_NAME
                --conda-env=torch-xla-nightly
                --env=XLA_USE_BF16=1
                -- python your_trainer_file.py

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance
            because it uses the output of nvidia-smi.

            Example::

                # default used by the Trainer
                trainer = Trainer(log_gpu_memory=None)

                # log all the GPUs (on master node only)
                trainer = Trainer(log_gpu_memory='all')

                # log only the min and max memory on the master node
                trainer = Trainer(log_gpu_memory='min_max')

        show_progress_bar: If true shows tqdm progress bar

            Example::

                # default used by the Trainer
                trainer = Trainer(show_progress_bar=True)

        progress_bar_refresh_rate: How often to refresh progress bar (in steps)

        overfit_pct: uses this much data of all datasets.

            Example::

                # default used by the Trainer
                trainer = Trainer(overfit_pct=0.0)

                # use only 1% of the train, test, val datasets
                trainer = Trainer(overfit_pct=0.01)

        track_grad_norm: -1 no tracking. Otherwise tracks that norm

            Example::

                # default used by the Trainer
                trainer = Trainer(track_grad_norm=-1)

                # track the 2-norm
                trainer = Trainer(track_grad_norm=2)

        check_val_every_n_epoch: Check val every n train epochs.

            Example::

                # default used by the Trainer
                trainer = Trainer(check_val_every_n_epoch=1)

                # run val loop every 10 training epochs
                trainer = Trainer(check_val_every_n_epoch=10)

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
            When enabled, ``num_sanity_val_steps`` and ``max_epochs`` are both forced to 1.

            Example::

                # default used by the Trainer
                trainer = Trainer(fast_dev_run=False)

                # runs 1 train, val, test batch and program ends
                trainer = Trainer(fast_dev_run=True)

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            Example::

                # default used by the Trainer (no accumulation)
                trainer = Trainer(accumulate_grad_batches=1)

                # accumulate every 4 batches (effective batch size is batch*4)
                trainer = Trainer(accumulate_grad_batches=4)

                # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
                trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

        max_epochs: Stop training once this number of epochs is reached.

            Example::

                # default used by the Trainer
                trainer = Trainer(max_epochs=1000)

        max_nb_epochs:
            .. warning:: .. deprecated:: 0.5.0

                Use `max_epochs` instead. Will remove 0.8.0.

        min_epochs: Force training for at least these many epochs

            Example::

                # default used by the Trainer
                trainer = Trainer(min_epochs=1)

        min_nb_epochs:
            .. warning:: .. deprecated:: 0.5.0

                Use `min_epochs` instead. Will remove 0.8.0.

        max_steps: Stop training after this number of steps. Disabled by default (None).
            Training stops as soon as either max_steps or max_epochs is reached, whichever comes first.

            Example::

                # Stop after 100 steps
                trainer = Trainer(max_steps=100)

        min_steps: Force training for at least these number of steps. Disabled by default (None).
            Trainer will train model for at least min_steps or min_epochs (latest).

            Example::

                # Run at least for 100 steps (disable min_epochs)
                trainer = Trainer(min_steps=100, min_epochs=0)

        train_percent_check: How much of training dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.

            Example::

                # default used by the Trainer
                trainer = Trainer(train_percent_check=1.0)

                # run through only 25% of the training set each epoch
                trainer = Trainer(train_percent_check=0.25)

        val_percent_check: How much of validation dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.

            Example::

                # default used by the Trainer
                trainer = Trainer(val_percent_check=1.0)

                # run through only 25% of the validation set each epoch
                trainer = Trainer(val_percent_check=0.25)

        test_percent_check: How much of test dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.

            Example::

                # default used by the Trainer
                trainer = Trainer(test_percent_check=1.0)

                # run through only 25% of the test set each epoch
                trainer = Trainer(test_percent_check=0.25)

        val_check_interval: How often within one training epoch to check the validation set.
            If float, fraction of the training epoch. If int, check every n batches.

            Example::

                # default used by the Trainer
                trainer = Trainer(val_check_interval=1.0)

                # check validation set 4 times during a training epoch
                trainer = Trainer(val_check_interval=0.25)

                # check validation set every 1000 training batches
                # use this when using iterableDataset and your dataset has no length
                # (ie: production cases with streaming data)
                trainer = Trainer(val_check_interval=1000)

        log_save_interval: Writes logs to disk this often

            Example::

                # default used by the Trainer
                trainer = Trainer(log_save_interval=100)

        row_log_interval: How often to add logging rows (does not write to disk)

            Example::

                # default used by the Trainer
                trainer = Trainer(row_log_interval=10)

        add_row_log_interval:
            .. warning:: .. deprecated:: 0.5.0

                Use `row_log_interval` instead. Will remove 0.8.0.

        distributed_backend: The distributed backend to use.
            Options: 'dp', 'ddp', 'ddp2'.

            Example::

                # default used by the Trainer
                trainer = Trainer(distributed_backend=None)

                # dp = DataParallel (split a batch onto k gpus on same machine).
                trainer = Trainer(gpus=2, distributed_backend='dp')

                # ddp = DistributedDataParallel
                # Each gpu trains by itself on a subset of the data.
                # Gradients sync across all gpus and all machines.
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

                # ddp2 = DistributedDataParallel + dp
                # behaves like dp on every node
                # syncs gradients across nodes like ddp
                # useful for things like increasing the number of negative samples
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

        use_amp:
            .. warning:: .. deprecated:: 0.6.1

                Use `precision` instead. Will remove 0.8.0.
                Note that ``precision=16`` always turns AMP on.

        precision: Full precision (32), half precision (16).
            Can be used on CPU, GPU or TPUs.

            If used on TPU will use torch.bfloat16 but tensor printing
            will still show torch.float32.

            Example::

                # default used by the Trainer
                trainer = Trainer(precision=32)

                # 16-bit precision
                trainer = Trainer(precision=16)

                # one day
                trainer = Trainer(precision=8|4|2)

        print_nan_grads: Prints gradients with nan values

            Example::

                # default used by the Trainer
                trainer = Trainer(print_nan_grads=False)

        weights_summary: Prints a summary of the weights when training begins.
            Options: 'full', 'top', None.

            Example::

                # default used by the Trainer (ie: print all weights)
                trainer = Trainer(weights_summary='full')

                # print only the top level modules
                trainer = Trainer(weights_summary='top')

                # don't print a summary
                trainer = Trainer(weights_summary=None)

        weights_save_path: Where to save weights if specified.

            Example::

                # default used by the Trainer
                trainer = Trainer(weights_save_path=os.getcwd())

                # save to your custom path
                trainer = Trainer(weights_save_path='my/path')

                # if checkpoint callback used, then overrides the weights path
                # **NOTE: this saves weights to some/path NOT my/path
                checkpoint_callback = ModelCheckpoint(filepath='some/path')
                trainer = Trainer(
                    checkpoint_callback=checkpoint_callback,
                    weights_save_path='my/path'
                )

        amp_level: The optimization level to use (O1, O2, etc...).
            Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)

            Example::

                # default used by the Trainer
                trainer = Trainer(amp_level='O1')

        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
            This catches any bugs in your validation without having to wait for the first validation check.
            The Trainer uses 5 steps by default. Turn it off or modify it here.

            Example::

                # default used by the Trainer
                trainer = Trainer(num_sanity_val_steps=5)

                # turn it off
                trainer = Trainer(num_sanity_val_steps=0)

        nb_sanity_val_steps:
            .. warning:: .. deprecated:: 0.5.0

                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

        truncated_bptt_steps: Truncated back prop performs backprop every k steps of
            a much longer sequence. If this is enabled, your batches will automatically get truncated
            and the trainer will apply Truncated Backprop to it. Make sure your batches have a sequence
            dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
            recurrent network trajectories."
            <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)

            Example::

                # default used by the Trainer (ie: disabled)
                trainer = Trainer(truncated_bptt_steps=None)

                # backprop every 5 steps in a batch
                trainer = Trainer(truncated_bptt_steps=5)

            Lightning takes care to split your batch along the time-dimension.

            .. note:: If you need to modify how the batch is split,
                override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

            .. note:: Using this feature requires updating your LightningModule's
                :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.

            Example::

                # default used by the Trainer
                trainer = Trainer(resume_from_checkpoint=None)

                # resume from a specific checkpoint
                trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks. ``True`` creates a ``Profiler()``; ``None``/``False``
            use a ``PassThroughProfiler``.

            Example::

                from pytorch_lightning.profiler import Profiler, AdvancedProfiler

                # default used by the Trainer
                trainer = Trainer(profiler=None)

                # to profile standard training events
                trainer = Trainer(profiler=True)

                # equivalent to profiler=True
                profiler = Profiler()
                trainer = Trainer(profiler=profiler)

                # advanced profiler for function-level stats
                profiler = AdvancedProfiler()
                trainer = Trainer(profiler=profiler)

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

        benchmark (bool): If true enables cudnn.benchmark.
            This flag is likely to increase the speed of your system if your
            input sizes don't change. However, if it does, then it will likely
            make your system slower.

            The speedup comes from allowing the cudnn auto-tuner to find the best
            algorithm for the hardware `[see discussion here]
            <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.

    .. warning:: Following arguments become deprecated and they will be removed in v0.8.0:

        - `gradient_clip`
        - `nb_gpu_nodes`
        - `max_nb_epochs`
        - `min_nb_epochs`
        - `add_row_log_interval`
        - `use_amp`
        - `nb_sanity_val_steps`

        A deprecated value (other than `use_amp`) is used only when its replacement argument is falsy
        or still at its default value; an explicitly set replacement takes precedence.
    """

    # Init callbacks. ``None`` gives each Trainer its own list; a ``[]`` default would be a single
    # list object shared by every Trainer created without explicit callbacks.
    self.callbacks = [] if callbacks is None else callbacks
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    if benchmark:
        torch.backends.cudnn.benchmark = True

    def _resolve_renamed_arg(old_name, old_value, new_name, new_value, new_default):
        # Warn about a renamed argument and return the value to use. The deprecated value is used
        # when the new argument is falsy (the original rule) OR still equals its default. The old rule
        # alone never fired for arguments whose default is truthy (1, 5, 10, 1000), so the
        # deprecated value was silently ignored.
        if old_value is None:
            return new_value
        warnings.warn(f"`{old_name}` has renamed to `{new_name}` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning, stacklevel=2)
        if not new_value or new_value == new_default:
            return old_value
        return new_value

    # Transfer params
    # Backward compatibility
    num_nodes = _resolve_renamed_arg('nb_gpu_nodes', nb_gpu_nodes, 'num_nodes', num_nodes, 1)
    self.num_gpu_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    # Backward compatibility
    gradient_clip_val = _resolve_renamed_arg(
        'gradient_clip', gradient_clip, 'gradient_clip_val', gradient_clip_val, 0)
    self.gradient_clip_val = gradient_clip_val

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
    self.progress_bar_refresh_rate = progress_bar_refresh_rate
    self.check_val_every_n_epoch = check_val_every_n_epoch
    self.track_grad_norm = track_grad_norm
    self.on_gpu = bool(gpus and torch.cuda.is_available())

    # tpu config
    self.on_tpu = num_tpu_cores is not None
    self.num_tpu_cores = num_tpu_cores
    assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'

    self.process_position = process_position
    self.weights_summary = weights_summary

    # Backward compatibility
    max_epochs = _resolve_renamed_arg('max_nb_epochs', max_nb_epochs, 'max_epochs', max_epochs, 1000)
    self.max_epochs = max_epochs

    # Backward compatibility
    min_epochs = _resolve_renamed_arg('min_nb_epochs', min_nb_epochs, 'min_epochs', min_epochs, 1)
    self.min_epochs = min_epochs

    self.max_steps = max_steps
    self.min_steps = min_steps

    # Backward compatibility
    num_sanity_val_steps = _resolve_renamed_arg(
        'nb_sanity_val_steps', nb_sanity_val_steps, 'num_sanity_val_steps', num_sanity_val_steps, 5)
    self.num_sanity_val_steps = num_sanity_val_steps

    self.print_nan_grads = print_nan_grads
    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 1
        self.max_epochs = 1
        m = ''' Running in fast_dev_run mode: will run a full train, val loop using a single batch '''
        log.info(m)

    # set default save path if user didn't provide one
    self.default_save_path = default_save_path
    if self.default_save_path is None:
        self.default_save_path = os.getcwd()

    # training bookeeping
    self.total_batch_idx = 0
    self.running_loss = []
    self.avg_loss = 0
    self.batch_idx = 0
    self.tqdm_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = 0
    self.num_training_batches = 0
    self.num_test_batches = 0
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.global_step = 0
    self.current_epoch = 0
    self.total_batches = 0

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = Profiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    self.reduce_lr_on_plateau_scheduler = None

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # allow int, string and gpu list
    self.data_parallel_device_ids = parse_gpu_ids(gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.use_ddp = False
    self.use_ddp2 = False
    self.use_dp = False
    self.single_gpu = False
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend, num_nodes)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()
        self.current_tpu_idx = None

    # init flags for SLURM+ddp to work
    self.proc_rank = 0
    self.world_size = 1
    self.node_rank = 0
    self.configure_slurm_ddp(num_nodes)

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # can't init progress bar here because starting a new process
    # means the progress_bar won't survive pickling
    self.show_progress_bar = show_progress_bar

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval

    # backward compatibility
    row_log_interval = _resolve_renamed_arg(
        'add_row_log_interval', add_row_log_interval, 'row_log_interval', row_log_interval, 10)
    self.row_log_interval = row_log_interval

    # how much of the data to use
    self.determine_data_use_amount(train_percent_check, val_percent_check,
                                   test_percent_check, overfit_pct)

    # 16 bit mixed precision training using apex
    self.amp_level = amp_level
    self.precision = precision
    if self.precision == 16:
        use_amp = True
    self.init_amp(use_amp)

    # Callback system
    self.on_init_end()
def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_pct: float = 0.0,
        track_grad_norm: int = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        val_percent_check: float = 1.0,
        test_percent_check: float = 1.0,
        val_check_interval: float = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 10,
        add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
        distributed_backend: Optional[str] = None,
        precision: int = 32,
        print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
        weights_summary: Optional[str] = 'full',
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
        amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
        default_save_path=None,  # backward compatible, todo: remove in v0.8.0
        gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
        nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
        max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        use_amp=None,  # backward compatible, todo: remove in v0.9.0
        show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
        nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
):
    r"""
    Customize every aspect of training via flags

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.

        checkpoint_callback: Callback for checkpointing.

        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
            Callback (or flag) for early stopping. It is handed to
            ``configure_early_stopping``, which creates a default callback if none is passed in.

        callbacks: Add a list of callbacks.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Falls back to the current working directory when not given.

        default_save_path:
            .. warning:: .. deprecated:: 0.7.3

                Use `default_root_dir` instead. Will remove 0.9.0.

        gradient_clip_val: 0 means don't clip.

        gradient_clip:
            .. warning:: .. deprecated:: 0.7.0

                Use `gradient_clip_val` instead. Will remove 0.9.0.

        process_position: orders the progress bar when running multiple models on same machine.

        num_nodes: number of GPU nodes for distributed training.

        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.7.0

                Use `num_nodes` instead. Will remove 0.9.0.

        num_processes: number of processes; only used with ``distributed_backend="ddp_cpu"``.
            For any other backend a warning is emitted that the value is ignored.

        gpus: Which GPUs to train on.

        auto_select_gpus: If enabled and `gpus` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.

        tpu_cores: How many TPU cores to train on (1 or 8), or a single-element
            list/tuple/set selecting one specific TPU core, e.g. ``[1]``.

        num_tpu_cores: How many TPU cores to train on (1 or 8)
            .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance

        show_progress_bar:
            .. warning:: .. deprecated:: 0.7.2

                Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.

        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

        overfit_pct: How much of training-, validation-, and test dataset to check.

        track_grad_norm: -1 no tracking. Otherwise tracks that norm

        check_val_every_n_epoch: Check val every n train epochs.

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        max_epochs: Stop training once this number of epochs is reached.

        max_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0

                Use `max_epochs` instead. Will remove 0.9.0.

        min_epochs: Force training for at least these many epochs

        min_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0

                Use `min_epochs` instead. Will remove 0.9.0.

        max_steps: Stop training after this number of steps. Disabled by default (None).

        min_steps: Force training for at least these number of steps. Disabled by default (None).

        train_percent_check: How much of training dataset to check.

        val_percent_check: How much of validation dataset to check.

        test_percent_check: How much of test dataset to check.

        val_check_interval: How often within one training epoch to check the validation set

        log_save_interval: Writes logs to disk this often

        row_log_interval: How often to add logging rows (does not write to disk)

        add_row_log_interval:
            .. warning:: .. deprecated:: 0.7.0

                Use `row_log_interval` instead. Will remove 0.9.0.
                When given, its value takes precedence over `row_log_interval`.

        distributed_backend: The distributed backend to use.

        use_amp:
            .. warning:: .. deprecated:: 0.7.0

                Use `precision` instead. Will remove 0.9.0.

        precision: Full precision (32), half precision (16).

        print_nan_grads:
            .. warning:: .. deprecated:: 0.7.2

                Has no effect. When detected, NaN grads will be printed automatically.
                Will remove 0.9.0.

        weights_summary: Prints a summary of the weights when training begins.

        weights_save_path: Where to save weights if specified. Will override default_root_dir
                for checkpoints only. Use this if for whatever reason you need the checkpoints
                stored in a different place than the logs written in `default_root_dir`.

        amp_level: The optimization level to use (O1, O2, etc...).

        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

        nb_sanity_val_steps:
            .. warning:: .. deprecated:: 0.7.0

                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

        truncated_bptt_steps: Truncated back prop performs backprop every k steps of
            a much longer sequence.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks. ``True`` uses a ``SimpleProfiler``.

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

        auto_lr_find: If set to True, will `initially` run a learning rate finder,
            trying to optimize initial learning for faster convergence. Sets learning
            rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.

        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
            If not specified, this will be toggled automatically when ddp is used.

        benchmark: If true enables cudnn.benchmark.

        deterministic: If true enables cudnn.deterministic

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.
    """
    super().__init__()

    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # Init callbacks
    self.callbacks = callbacks or []
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    # Backward compatibility, TODO: remove in v0.8.0
    if nb_gpu_nodes is not None:
        rank_zero_warn("Argument `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        # NOTE(review): presumably a deprecated-API property defined elsewhere
        # forwards this assignment to `num_nodes` — verify
        self.num_gpu_nodes = nb_gpu_nodes

    self.log_gpu_memory = log_gpu_memory

    self.gradient_clip_val = gradient_clip_val
    # Backward compatibility, TODO: remove in v0.8.0
    if gradient_clip is not None:
        rank_zero_warn("Argument `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        self.gradient_clip = gradient_clip

    self.check_val_every_n_epoch = check_val_every_n_epoch

    self.track_grad_norm = track_grad_norm
    self.on_gpu = bool(gpus and torch.cuda.is_available())

    # tpu config
    if num_tpu_cores is not None:
        rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                       " and this argument will be removed in v0.9.0", DeprecationWarning)

    if tpu_cores is None:
        tpu_cores = num_tpu_cores

    # A one-element tuple/set is accepted just like a one-element list; normalise it
    # to a list so the specific core index is extracted below for every accepted form
    # (previously only `list` yielded a `tpu_id`, so `(3,)` silently lost the core id).
    if isinstance(tpu_cores, (tuple, set)):
        tpu_cores = list(tpu_cores)
    self.on_tpu = tpu_cores is not None
    self.tpu_cores = tpu_cores
    assert self.tpu_cores in (1, 8, None) or (
        isinstance(self.tpu_cores, list) and len(self.tpu_cores) == 1
    ), '`tpu_cores` can only be 1, 8 or [<1-8>]'

    self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    # Backward compatibility, TODO: remove in v0.8.0
    if max_nb_epochs is not None:
        rank_zero_warn("Argument `max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        self.max_nb_epochs = max_nb_epochs

    self.min_epochs = min_epochs
    # Backward compatibility, TODO: remove in v0.8.0
    if min_nb_epochs is not None:
        rank_zero_warn("Argument `min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        self.min_nb_epochs = min_nb_epochs

    self.max_steps = max_steps
    self.min_steps = min_steps

    self.num_sanity_val_steps = num_sanity_val_steps
    # Backward compatibility, TODO: remove in v0.8.0
    if nb_sanity_val_steps is not None:
        rank_zero_warn("Argument `nb_sanity_val_steps` has renamed to "
                       "`num_sanity_val_steps` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        self.nb_sanity_val_steps = nb_sanity_val_steps

    # Backward compatibility, TODO: remove in v0.9.0
    if print_nan_grads:
        rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                       " NaN grads will be printed automatically when detected.", DeprecationWarning)

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        log.info('Running in fast_dev_run mode: will run a full train,'
                 ' val and test loop using a single batch')

    # set default save path if user didn't provide one
    self.default_root_dir = default_root_dir

    # Backward compatibility, TODO: remove in v0.8.0
    if default_save_path is not None:
        # warn like every other deprecated alias; previously this one was silent
        rank_zero_warn("Argument `default_save_path` has renamed to `default_root_dir` since v0.7.3"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        self.default_root_dir = default_save_path

    if self.default_root_dir is None:
        self.default_root_dir = os.getcwd()

    # training bookeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.progress_bar_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = 0
    self.num_training_batches = 0
    self.num_test_batches = 0
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()

    # init flags for SLURM+ddp to work
    self.proc_rank = 0
    self.world_size = 1
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # backward compatibility
    if show_progress_bar is not None:
        self.show_progress_bar = show_progress_bar

    self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval

    # backward compatibility
    if add_row_log_interval is not None:
        rank_zero_warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
        # `row_log_interval` defaults to a truthy value (10), so the former
        # "only if not set" check silently discarded the deprecated argument.
        # The deprecated value is now used whenever it is given.
        row_log_interval = add_row_log_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    self.overfit_pct = overfit_pct
    self.determine_data_use_amount(train_percent_check, val_percent_check,
                                   test_percent_check, overfit_pct)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    self.precision = precision
    self.scaler = None

    # TODO: remove for v0.8.0
    self.amp_level = amp_level
    self.init_amp(use_amp)

    # truthy when either environment variable is set
    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
def __init__(
        self,
        logger=True,
        checkpoint_callback=True,
        early_stop_callback=None,
        default_save_path=None,
        gradient_clip_val=0,
        gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
        process_position=0,
        nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
        num_nodes=1,
        gpus=None,
        log_gpu_memory=None,
        show_progress_bar=True,
        overfit_pct=0.0,
        track_grad_norm=-1,
        check_val_every_n_epoch=1,
        fast_dev_run=False,
        accumulate_grad_batches=1,
        max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        max_epochs=1000,
        min_epochs=1,
        train_percent_check=1.0,
        val_percent_check=1.0,
        test_percent_check=1.0,
        val_check_interval=1.0,
        log_save_interval=100,
        row_log_interval=10,
        add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
        distributed_backend=None,
        use_amp=False,
        print_nan_grads=False,
        weights_summary='full',
        weights_save_path=None,
        amp_level='O1',
        nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
        num_sanity_val_steps=5,
        truncated_bptt_steps=None,
        resume_from_checkpoint=None,
        profiler=None):
    r"""
    Customize every aspect of training via flags

    Args:
        logger (:class:`.Logger`): Logger for experiment tracking.
            Example::

                from pytorch_lightning.loggers import TensorBoardLogger

                # default logger used by trainer
                logger = TensorBoardLogger(
                    save_dir=os.getcwd(),
                    version=self.slurm_job_id,
                    name='lightning_logs'
                )

                Trainer(logger=logger)

        checkpoint_callback (:class:`CheckpointCallback`): Callback for checkpointing.
            Example::

                from pytorch_lightning.callbacks import ModelCheckpoint

                # default used by the Trainer
                checkpoint_callback = ModelCheckpoint(
                    filepath=os.getcwd(),
                    save_best_only=True,
                    verbose=True,
                    monitor='val_loss',
                    mode='min',
                    prefix=''
                )

                trainer = Trainer(checkpoint_callback=checkpoint_callback)

        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping.
            If set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
            Will raise an error if ``'val_loss'`` is not found.
            If set to ``False``, then early stopping will be disabled.
            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
            If ``'val_loss'`` is not found will work as if early stopping is disabled.
            Default: ``None``.
            Example::

                from pytorch_lightning.callbacks import EarlyStopping

                # default used by the Trainer
                early_stop_callback = EarlyStopping(
                    monitor='val_loss',
                    patience=3,
                    strict=False,
                    verbose=False,
                    mode='min'
                )

                trainer = Trainer(early_stop_callback=early_stop_callback)

        default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed.
            Falls back to the current working directory.
            Example::

                # default used by the Trainer
                trainer = Trainer(default_save_path=os.getcwd())

        gradient_clip_val (float): 0 means don't clip.
            Example::

                # default used by the Trainer
                trainer = Trainer(gradient_clip_val=0.0)

        gradient_clip (int):
            .. deprecated:: 0.5.0
                Use `gradient_clip_val` instead. Will remove 0.8.0.

        process_position (int): orders the tqdm bar when running multiple models on same machine.
            Example::

                # default used by the Trainer
                trainer = Trainer(process_position=0)

        num_nodes (int): number of GPU nodes for distributed training.
            Example::

                # default used by the Trainer
                trainer = Trainer(num_nodes=1)

                # to train on 8 nodes
                trainer = Trainer(num_nodes=8)

        nb_gpu_nodes (int):
            .. deprecated:: 0.5.0
                Use `num_nodes` instead. Will remove 0.8.0.

        gpus (list|str|int): Which GPUs to train on.
            Example::

                # default used by the Trainer (ie: train on CPU)
                trainer = Trainer(gpus=None)

                # int: train on 2 gpus
                trainer = Trainer(gpus=2)

                # list: train on GPUs 1, 4 (by bus ordering)
                trainer = Trainer(gpus=[1, 4])
                trainer = Trainer(gpus='1, 4') # equivalent

                # -1: train on all gpus
                trainer = Trainer(gpus=-1)
                trainer = Trainer(gpus='-1') # equivalent

                # combine with num_nodes to train on multiple GPUs across nodes
                trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total

        log_gpu_memory (str): None, 'min_max', 'all'. Might slow performance
            because it uses the output of nvidia-smi.
            Example::

                # default used by the Trainer
                trainer = Trainer(log_gpu_memory=None)

                # log all the GPUs (on master node only)
                trainer = Trainer(log_gpu_memory='all')

                # log only the min and max memory on the master node
                trainer = Trainer(log_gpu_memory='min_max')

        show_progress_bar (bool): If true shows tqdm progress bar
            Example::

                # default used by the Trainer
                trainer = Trainer(show_progress_bar=True)

        overfit_pct (float): uses this much data of all datasets.
            Example::

                # default used by the Trainer
                trainer = Trainer(overfit_pct=0.0)

                # use only 1% of the train, test, val datasets
                trainer = Trainer(overfit_pct=0.01)

        track_grad_norm (int): -1 no tracking. Otherwise tracks that norm
            Example::

                # default used by the Trainer
                trainer = Trainer(track_grad_norm=-1)

                # track the 2-norm
                trainer = Trainer(track_grad_norm=2)

        check_val_every_n_epoch (int): Check val every n train epochs.
            Example::

                # default used by the Trainer
                trainer = Trainer(check_val_every_n_epoch=1)

                # run val loop every 10 training epochs
                trainer = Trainer(check_val_every_n_epoch=10)

        fast_dev_run (bool): runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
            Example::

                # default used by the Trainer
                trainer = Trainer(fast_dev_run=False)

                # runs 1 train, val, test  batch and program ends
                trainer = Trainer(fast_dev_run=True)

        accumulate_grad_batches (int|dict): Accumulates grads every k batches or as set up in the dict.
            Example::

                # default used by the Trainer (no accumulation)
                trainer = Trainer(accumulate_grad_batches=1)

                # accumulate every 4 batches (effective batch size is batch*4)
                trainer = Trainer(accumulate_grad_batches=4)

                # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
                trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

        max_epochs (int): Stop training once this number of epochs is reached.
            Example::

                # default used by the Trainer
                trainer = Trainer(max_epochs=1000)

        max_nb_epochs (int):
            .. deprecated:: 0.5.0
                Use `max_epochs` instead. Will remove 0.8.0.

        min_epochs (int): Force training for at least these many epochs
            Example::

                # default used by the Trainer
                trainer = Trainer(min_epochs=1)

        min_nb_epochs (int):
            .. deprecated:: 0.5.0
                Use `min_epochs` instead. Will remove 0.8.0.

        train_percent_check (float): How much of training dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(train_percent_check=1.0)

                # run through only 25% of the training set each epoch
                trainer = Trainer(train_percent_check=0.25)

        val_percent_check (float): How much of validation dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(val_percent_check=1.0)

                # run through only 25% of the validation set each epoch
                trainer = Trainer(val_percent_check=0.25)

        test_percent_check (float): How much of test dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(test_percent_check=1.0)

                # run through only 25% of the test set each epoch
                trainer = Trainer(test_percent_check=0.25)

        val_check_interval (float|int): How often within one training epoch to check the validation set
            If float, % of tng epoch. If int, check every n batch
            Example::

                # default used by the Trainer
                trainer = Trainer(val_check_interval=1.0)

                # check validation set 4 times during a training epoch
                trainer = Trainer(val_check_interval=0.25)

                # check validation set every 1000 training batches
                # use this when using iterableDataset and your dataset has no length
                # (ie: production cases with streaming data)
                trainer = Trainer(val_check_interval=1000)

        log_save_interval (int): Writes logs to disk this often
            Example::

                # default used by the Trainer
                trainer = Trainer(log_save_interval=100)

        row_log_interval (int): How often to add logging rows (does not write to disk)
            Example::

                # default used by the Trainer
                trainer = Trainer(row_log_interval=10)

        add_row_log_interval (int):
            .. deprecated:: 0.5.0
                Use `row_log_interval` instead. Will remove 0.8.0.

        distributed_backend (str): The distributed backend to use.
            Options: 'dp', 'ddp', 'ddp2'.
            Example::

                # default used by the Trainer
                trainer = Trainer(distributed_backend=None)

                # dp = DataParallel (split a batch onto k gpus on same machine).
                trainer = Trainer(gpus=2, distributed_backend='dp')

                # ddp = DistributedDataParallel
                # Each gpu trains by itself on a subset of the data.
                # Gradients sync across all gpus and all machines.
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

                # ddp2 = DistributedDataParallel + dp
                # behaves like dp on every node
                # syncs gradients across nodes like ddp
                # useful for things like increasing the number of negative samples
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

        use_amp (bool): If true uses apex for 16bit precision
            Example::

                # default used by the Trainer
                trainer = Trainer(use_amp=False)

        print_nan_grads (bool): Prints gradients with nan values
            Example::

                # default used by the Trainer
                trainer = Trainer(print_nan_grads=False)

        weights_summary (str): Prints a summary of the weights when training begins.
            Options: 'full', 'top', None.
            Example::

                # default used by the Trainer (ie: print all weights)
                trainer = Trainer(weights_summary='full')

                # print only the top level modules
                trainer = Trainer(weights_summary='top')

                # don't print a summary
                trainer = Trainer(weights_summary=None)

        weights_save_path (str): Where to save weights if specified.
            Example::

                # default used by the Trainer
                trainer = Trainer(weights_save_path=os.getcwd())

                # save to your custom path
                trainer = Trainer(weights_save_path='my/path')

                # if checkpoint callback used, then overrides the weights path
                # **NOTE: this saves weights to some/path NOT my/path
                checkpoint_callback = ModelCheckpoint(filepath='some/path')
                trainer = Trainer(
                    checkpoint_callback=checkpoint_callback,
                    weights_save_path='my/path'
                )

        amp_level (str): The optimization level to use (O1, O2, etc...).
            Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
            Example::

                # default used by the Trainer
                trainer = Trainer(amp_level='O1')

        num_sanity_val_steps (int): Sanity check runs n batches of val before starting the training routine.
            This catches any bugs in your validation without having to wait for the first validation check.
            The Trainer uses 5 steps by default. Turn it off or modify it here.
            Example::

                # default used by the Trainer
                trainer = Trainer(num_sanity_val_steps=5)

                # turn it off
                trainer = Trainer(num_sanity_val_steps=0)

        nb_sanity_val_steps (int):
            .. deprecated:: 0.5.0
                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

        truncated_bptt_steps (int): Truncated back prop performs backprop every k steps of
            a much longer sequence. If this is enabled, your batches will automatically get truncated
            and the trainer will apply Truncated Backprop to it. Make sure your batches have a sequence
            dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
            recurrent network trajectories."
            <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
            Example::

                # default used by the Trainer (ie: disabled)
                trainer = Trainer(truncated_bptt_steps=None)

                # backprop every 5 steps in a batch
                trainer = Trainer(truncated_bptt_steps=5)

            Using this feature requires updating your LightningModule's `training_step()` to include
            a `hiddens` arg.

        resume_from_checkpoint (str): To resume training from a specific checkpoint pass in the path here.
            Example::

                # default used by the Trainer
                trainer = Trainer(resume_from_checkpoint=None)

                # resume from a specific checkpoint
                trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

        profiler (BaseProfiler): To profile individual steps during training and assist in
            identifying bottlenecks.
            Example::

                from pytorch_lightning.profiler import Profiler, AdvancedProfiler

                # default used by the Trainer
                trainer = Trainer(profiler=None)

                # to profile standard training events
                trainer = Trainer(profiler=True)

                # equivalent to profiler=True
                profiler = Profiler()
                trainer = Trainer(profiler=profiler)

                # advanced profiler for function-level stats
                profiler = AdvancedProfiler()
                trainer = Trainer(profiler=profiler)

    .. warning:: Following arguments are deprecated and will be removed in v0.8.0:
        `gradient_clip`, `nb_gpu_nodes`, `max_nb_epochs`, `min_nb_epochs`,
        `add_row_log_interval` and `nb_sanity_val_steps`. When one of them is
        passed, its value takes precedence over the corresponding new argument.
    """
    # NOTE: every new-style argument below has a default, several of them truthy
    # (num_nodes=1, max_epochs=1000, min_epochs=1, row_log_interval=10,
    # num_sanity_val_steps=5). A "use the deprecated value only if the new one is
    # falsy" check therefore silently dropped the deprecated value after warning.
    # Whenever a deprecated alias is given, its value is now used.

    # Transfer params
    # Backward compatibility
    if nb_gpu_nodes is not None:
        warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        num_nodes = nb_gpu_nodes
    self.num_gpu_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    # Backward compatibility
    if gradient_clip is not None:
        warnings.warn("`gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        gradient_clip_val = gradient_clip
    self.gradient_clip_val = gradient_clip_val

    self.check_val_every_n_epoch = check_val_every_n_epoch
    self.track_grad_norm = track_grad_norm
    self.on_gpu = bool(gpus and torch.cuda.is_available())
    self.process_position = process_position
    self.weights_summary = weights_summary

    # Backward compatibility
    if max_nb_epochs is not None:
        warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        max_epochs = max_nb_epochs
    self.max_epochs = max_epochs

    # Backward compatibility
    if min_nb_epochs is not None:
        warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        min_epochs = min_nb_epochs
    self.min_epochs = min_epochs

    # Backward compatibility
    if nb_sanity_val_steps is not None:
        warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        num_sanity_val_steps = nb_sanity_val_steps
    self.num_sanity_val_steps = num_sanity_val_steps

    self.print_nan_grads = print_nan_grads
    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 1
        self.max_epochs = 1
        m = '''
        Running in fast_dev_run mode: will run a full train,
        val loop using a single batch
        '''
        log.info(m)

    # set default save path if user didn't provide one
    self.default_save_path = default_save_path
    if self.default_save_path is None:
        self.default_save_path = os.getcwd()

    # training bookeeping
    self.total_batch_idx = 0
    self.running_loss = []
    self.avg_loss = 0
    self.batch_idx = 0
    self.tqdm_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = 0
    self.num_training_batches = 0
    self.num_test_batches = 0
    self.get_train_dataloader = None
    self.get_test_dataloaders = None
    self.get_val_dataloaders = None
    self.is_iterable_train_dataloader = False

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.global_step = 0
    self.current_epoch = 0
    self.total_batches = 0

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = Profiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    self.reduce_lr_on_plateau_scheduler = None

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # allow int, string and gpu list
    self.data_parallel_device_ids = parse_gpu_ids(gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

    # distributed backend choice
    self.use_ddp = False
    self.use_ddp2 = False
    self.use_dp = False
    self.single_gpu = False
    self.distributed_backend = distributed_backend
    # `num_nodes` here already reflects a deprecated `nb_gpu_nodes`, if given
    self.set_distributed_mode(distributed_backend, num_nodes)

    # init flags for SLURM+ddp to work
    self.proc_rank = 0
    self.world_size = 1
    self.node_rank = 0
    self.configure_slurm_ddp(num_nodes)

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # can't init progress bar here because starting a new process
    # means the progress_bar won't survive pickling
    self.show_progress_bar = show_progress_bar

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval

    # backward compatibility
    if add_row_log_interval is not None:
        warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        row_log_interval = add_row_log_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    self.determine_data_use_amount(train_percent_check, val_percent_check,
                                   test_percent_check, overfit_pct)

    # 16 bit mixed precision training using apex
    self.amp_level = amp_level
    self.init_amp(use_amp)
def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_save_path: Optional[str] = None,
        gradient_clip_val: float = 0,
        gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
        process_position: int = 0,
        nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
        num_nodes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        num_tpu_cores: Optional[int] = None,
        log_gpu_memory: Optional[str] = None,
        show_progress_bar: bool = True,
        progress_bar_refresh_rate: int = 50,
        overfit_pct: float = 0.0,
        track_grad_norm: int = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        val_percent_check: float = 1.0,
        test_percent_check: float = 1.0,
        val_check_interval: float = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 10,
        add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
        distributed_backend: Optional[str] = None,
        use_amp=False,  # backward compatible, todo: remove in v0.9.0
        precision: int = 32,
        print_nan_grads: bool = False,
        weights_summary: str = 'full',
        weights_save_path: Optional[str] = None,
        amp_level: str = 'O1',
        nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
        num_sanity_val_steps: int = 5,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        **kwargs
):
    r"""
    Customize every aspect of training via flags

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.

        checkpoint_callback: Callback for checkpointing.

        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
            Callback for early stopping; passed to ``configure_early_stopping``.

        callbacks: Add a list of callbacks. Defaults to a new, empty list per trainer.

        default_save_path: Default path for logs and weights when no logger/ckpt_callback passed.
            Falls back to the current working directory.

        gradient_clip_val: 0 means don't clip.

        gradient_clip:
            .. warning:: .. deprecated:: 0.7.0
                Use `gradient_clip_val` instead. Will remove 0.8.0.

        process_position: orders the tqdm bar when running multiple models on same machine.

        num_nodes: number of GPU nodes for distributed training.

        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.7.0
                Use `num_nodes` instead. Will remove 0.8.0.

        gpus: Which GPUs to train on.

        num_tpu_cores: How many TPU cores to train on (1 or 8).

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance

        show_progress_bar: If true shows tqdm progress bar

        progress_bar_refresh_rate: How often to refresh progress bar (in steps)

        overfit_pct: Fraction of data to overfit on; forwarded to ``determine_data_use_amount``.

        track_grad_norm: -1 no tracking. Otherwise tracks that norm

        check_val_every_n_epoch: Check val every n train epochs.

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        max_epochs: Stop training once this number of epochs is reached.

        max_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0
                Use `max_epochs` instead. Will remove 0.8.0.

        min_epochs: Force training for at least these many epochs

        min_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0
                Use `min_epochs` instead. Will remove 0.8.0.

        max_steps: Stop training after this number of steps. Disabled by default (None).

        min_steps: Force training for at least these number of steps. Disabled by default (None).

        train_percent_check: How much of training dataset to check.

        val_percent_check: How much of validation dataset to check.

        test_percent_check: How much of test dataset to check.

        val_check_interval: How often within one training epoch to check the validation set

        log_save_interval: Writes logs to disk this often

        row_log_interval: How often to add logging rows (does not write to disk)

        add_row_log_interval:
            .. warning:: .. deprecated:: 0.7.0
                Use `row_log_interval` instead. Will remove 0.8.0.

        distributed_backend: The distributed backend to use.

        use_amp:
            .. warning:: .. deprecated:: 0.7.0
                Use `precision` instead. Will remove 0.9.0.

        precision: Full precision (32), half precision (16). ``16`` enables AMP
            unless training on TPU cores.

        print_nan_grads: Prints gradients with nan values

        weights_summary: Prints a summary of the weights when training begins.

        weights_save_path: Where to save weights if specified.

        amp_level: The optimization level to use (O1, O2, etc...).

        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

        nb_sanity_val_steps:
            .. warning:: .. deprecated:: 0.7.0
                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

        truncated_bptt_steps: Truncated back prop performs backprop every k steps of
            a longer sequence.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks. Pass ``True`` to use the default ``Profiler``;
            ``None``/``False`` disables profiling.

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

        benchmark: If true enables cudnn.benchmark.

    Raises:
        AssertionError: if ``num_tpu_cores`` is not 1, 8 or None, or if
            ``precision`` is not 16 or 32.
    """

    # Init callbacks. A fresh list per trainer: the old ``callbacks=[]`` default
    # was a single list shared by every Trainer built without explicit callbacks.
    self.callbacks = callbacks if callbacks is not None else []
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    if benchmark:
        torch.backends.cudnn.benchmark = True

    # Transfer params
    self.num_nodes = num_nodes
    # Backward compatibility, TODO: remove in v0.8.0
    if nb_gpu_nodes is not None:
        warnings.warn("Argument `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        self.num_gpu_nodes = nb_gpu_nodes
    self.log_gpu_memory = log_gpu_memory

    self.gradient_clip_val = gradient_clip_val
    # Backward compatibility, TODO: remove in v0.8.0
    if gradient_clip is not None:
        warnings.warn("Argument `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        self.gradient_clip = gradient_clip

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
    self.progress_bar_refresh_rate = progress_bar_refresh_rate
    self.check_val_every_n_epoch = check_val_every_n_epoch
    self.track_grad_norm = track_grad_norm
    self.on_gpu = bool(gpus and torch.cuda.is_available())

    # tpu config
    self.on_tpu = num_tpu_cores is not None
    self.num_tpu_cores = num_tpu_cores
    assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'

    self.process_position = process_position
    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    # Backward compatibility, TODO: remove in v0.8.0
    if max_nb_epochs is not None:
        warnings.warn("Argument `max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        self.max_nb_epochs = max_nb_epochs

    self.min_epochs = min_epochs
    # Backward compatibility, TODO: remove in v0.8.0
    if min_nb_epochs is not None:
        warnings.warn("Argument `min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        self.min_nb_epochs = min_nb_epochs

    self.max_steps = max_steps
    self.min_steps = min_steps

    self.num_sanity_val_steps = num_sanity_val_steps
    # Backward compatibility, TODO: remove in v0.8.0
    if nb_sanity_val_steps is not None:
        warnings.warn("Argument `nb_sanity_val_steps` has renamed to "
                      "`num_sanity_val_steps` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        self.nb_sanity_val_steps = nb_sanity_val_steps
    self.print_nan_grads = print_nan_grads
    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 1
        self.max_epochs = 1
        m = '''
        Running in fast_dev_run mode: will run a full train,
        val loop using a single batch
        '''
        log.info(m)

    # set default save path if user didn't provide one
    self.default_save_path = default_save_path
    if self.default_save_path is None:
        self.default_save_path = os.getcwd()

    # training bookkeeping
    self.total_batch_idx = 0
    self.running_loss = []
    self.avg_loss = 0
    self.batch_idx = 0
    self.tqdm_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = 0
    self.num_training_batches = 0
    self.num_test_batches = 0
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.global_step = 0
    self.current_epoch = 0
    self.total_batches = 0

    # configure logger
    self.configure_logger(logger)

    # configure profiler (``True`` selects the default profiler)
    if profiler is True:
        profiler = Profiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # allow int, string and gpu list
    self.gpus = gpus
    self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

    # Only pin the current CUDA device when a GPU was actually selected.
    # ``torch.cuda.set_device`` rejects a CPU device, and GPU index 0 is a valid
    # root GPU that a plain truthiness test would misclassify as "no GPU".
    if self.on_gpu and self.root_gpu is not None:
        torch.cuda.set_device(torch.device("cuda", self.root_gpu))

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.use_ddp = False
    self.use_ddp2 = False
    self.use_dp = False
    self.single_gpu = False
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend, self.num_nodes)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()
        self.current_tpu_idx = None

    # init flags for SLURM+ddp to work
    self.proc_rank = 0
    self.world_size = 1
    self.node_rank = 0
    self.configure_slurm_ddp(self.num_nodes)

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # can't init progress bar here because starting a new process
    # means the progress_bar won't survive pickling
    self.show_progress_bar = show_progress_bar

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval

    # backward compatibility
    if add_row_log_interval is not None:
        warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                      " and this method will be removed in v0.8.0", DeprecationWarning)
        if not row_log_interval:  # in case you did not set the proper value
            row_log_interval = add_row_log_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    self.overfit_pct = overfit_pct
    self.determine_data_use_amount(train_percent_check, val_percent_check,
                                   test_percent_check, overfit_pct)

    # 16 bit mixed precision training using apex
    self.amp_level = amp_level
    self.precision = precision

    assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

    # precision=16 implies AMP, except on TPU cores
    if self.precision == 16 and self.num_tpu_cores is None:
        use_amp = True
    self.init_amp(use_amp)

    # Callback system
    self.on_init_end()
def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_backend: str = 'native',
        amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
        val_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        test_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        train_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
        overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
):
    """Initialize the trainer's configuration and internal state.

    Creates the helper connectors and the train/evaluation loops, builds the
    callback list (user callbacks, then the configured early-stopping and
    checkpoint callbacks, in that order), resolves device settings (GPU ids,
    TPU cores, distributed backend, SLURM/DDP ranks), configures the progress
    bar and logger, normalizes the batch-limit arguments and initializes AMP.
    ``on_init_start`` runs once the callback list is built; ``on_init_end``
    runs last.

    Notable argument handling:
        * ``num_sanity_val_steps=-1`` is stored as ``float('inf')``.
        * ``fast_dev_run=True`` forces one batch per stage, zero sanity steps
          and one epoch.
        * ``profiler=True`` selects ``SimpleProfiler``; a falsy value selects
          ``PassThroughProfiler``.
        * The ``PL_TRAINER_GPUS`` environment variable, when set, replaces
          ``gpus``.
        * ``num_processes`` other than 1 only warns unless
          ``distributed_backend == "ddp_cpu"``.
        * The deprecated ``val_percent_check``, ``test_percent_check``,
          ``train_percent_check`` and ``overfit_pct`` emit a
          ``DeprecationWarning`` and, when not None, replace
          ``limit_val_batches``, ``limit_test_batches``,
          ``limit_train_batches`` and ``overfit_batches`` respectively.

    Raises:
        MisconfigurationException: if ``track_grad_norm`` is neither an int,
            a float, nor the string ``'inf'``.
    """
    super().__init__()

    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # init the default rank if exists
    # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
    # this way we only show it on rank 0
    if 'LOCAL_RANK' in os.environ:
        rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

    # tracks internal state for debugging
    self.dev_debugger = InternalDebugger(self)
    # helper objects that each receive the trainer and implement one concern
    self.config_validator = ConfigValidator(self)
    self.data_connector = DataConnector(self)
    self.lr_scheduler_connector = LRSchedulerConnector(self)
    self.accelerator_connector = AcceleratorConnector(self)
    self.logger_connector = LoggerConnector(self)
    self.model_connector = ModelConnector(self)
    self.initializer = Initializer(self)
    self.tuner = Tuner(self)
    self.accelerator_backend = None

    # loops
    self.evaluation_loop = EvaluationLoop(self)
    self.train_loop = TrainLoop(self)

    # training bookkeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.num_training_batches = 0
    self.num_val_batches = []
    self.num_sanity_val_batches = []
    self.num_test_batches = []
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # when true, prints test results
    self.verbose_test = True

    # when .test() is called, it sets this
    self.tested_ckpt_path = None

    # training state
    self.model = None
    self.datamodule = None
    self.testing = False
    self.prepare_data_per_node = prepare_data_per_node
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False
    self.should_stop = False
    self.running_sanity_check = False
    self._state = TrainerState.INITIALIZING

    # both directories fall back to the current working directory
    self._default_root_dir = default_root_dir or os.getcwd()
    self._weights_save_path = weights_save_path or self._default_root_dir

    # init callbacks
    self.callbacks = callbacks or []

    # configure early stop callback
    # creates a default one if none passed in
    early_stop_callback = self.configure_early_stopping(
        early_stop_callback)
    if early_stop_callback:
        self.callbacks.append(early_stop_callback)

    # configure checkpoint callback
    # it is important that this is the last callback to run
    # pass through the required args to figure out defaults
    checkpoint_callback = self.configure_checkpoint_callback(
        checkpoint_callback)
    if checkpoint_callback:
        self.callbacks.append(checkpoint_callback)

    # TODO refactor codebase (tests) to not directly reach into these callbacks
    self.checkpoint_callback = checkpoint_callback
    self.early_stop_callback = early_stop_callback

    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    # sync-bn backend
    self.sync_batchnorm = sync_batchnorm

    self.gradient_clip_val = gradient_clip_val
    self.check_val_every_n_epoch = check_val_every_n_epoch

    # only numbers or the literal 'inf' are accepted; float('inf') parses the latter
    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise MisconfigurationException(
            "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
        )
    self.track_grad_norm = float(track_grad_norm)

    self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
    self.on_tpu = self.tpu_cores is not None

    # a list of cores selects a specific core by its first entry
    self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn(
            "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
        )
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    # -1 means "run every validation batch" during the sanity check
    if num_sanity_val_steps == -1:
        self.num_sanity_val_steps = float('inf')
    else:
        self.num_sanity_val_steps = num_sanity_val_steps

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    # NOTE(review): the deprecated *_percent_check arguments are applied further
    # below and, when given, will overwrite these fast_dev_run limits of 1.
    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        limit_train_batches = 1
        limit_val_batches = 1
        limit_test_batches = 1
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        rank_zero_info(
            'Running in fast_dev_run mode: will run a full train,'
            ' val and test loop using a single batch')

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # override with environment flag
    gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = self.tuner.pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
    self.root_gpu = device_parser.determine_root_gpu_device(
        self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.distributed_backend = 'tpu'
        self.init_tpu()

    # init flags for SLURM+DDP to work
    self.world_size = 1
    self.interactive_ddp_procs = []
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()
    self.local_rank = self.determine_local_rank()
    self.global_rank = 0

    # NVIDIA setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    self._progress_bar_callback = self.configure_progress_bar(
        progress_bar_refresh_rate, process_position)

    # logging
    self.configure_logger(logger)
    self.log_save_interval = log_save_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    # TODO: remove in 0.10.0
    if overfit_pct is not None:
        rank_zero_warn(
            "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        overfit_batches = overfit_pct

    # TODO: remove in 0.10.0
    if val_percent_check is not None:
        rank_zero_warn(
            "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_val_batches = val_percent_check

    # TODO: remove in 0.10.0
    if test_percent_check is not None:
        rank_zero_warn(
            "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_test_batches = test_percent_check

    # TODO: remove in 0.10.0
    if train_percent_check is not None:
        rank_zero_warn(
            "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
            " and this argument will be removed in v0.10.0",
            DeprecationWarning,
        )
        limit_train_batches = train_percent_check

    # validate/normalize each int-or-fraction limit; the name is passed for error messages
    self.limit_train_batches = _determine_batch_limits(
        limit_train_batches, 'limit_train_batches')
    self.limit_val_batches = _determine_batch_limits(
        limit_val_batches, 'limit_val_batches')
    self.limit_test_batches = _determine_batch_limits(
        limit_test_batches, 'limit_test_batches')
    self.val_check_interval = _determine_batch_limits(
        val_check_interval, 'val_check_interval')
    self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
    self.determine_data_use_amount(self.overfit_batches)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    self.precision = precision
    self.scaler = None

    self.amp_level = amp_level
    self.initializer.init_amp(amp_backend)

    # holds the env-var string (or None), not a bool; callers presumably test truthiness
    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
        'KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
def __init__(self, data, debug=False):
    """Build the experiment configuration.

    Every size/length hyper-parameter has a small "debug" value and a full
    value; ``debug`` (tested for truthiness) selects between them. The model
    width ``d_model`` is taken from ``data.shape[1]``.

    Args:
        data: The input data; presumably a 2-D array-like whose second axis
            indexes the series (``d_model = data.shape[1]``) -- TODO confirm.
        debug: When truthy, use the small debug sizes and load data in the
            main process (``num_workers = 0``).
    """
    # Evaluate the flag once so every "debug vs. full" choice below agrees,
    # instead of comparing ``debug == True`` at each site.
    dbg = bool(debug)

    #######################################################
    #
    # Directories
    #
    self.data_set = "etf"
    self.data_dir = f"./data/{self.data_set}"
    self.data_pickle = "allData.pickle"
    self.dataset_size = 1_500 if dbg else 1_000_000_000

    # experiment directory
    self.work_dir = "./experiments/logs"

    #######################################################
    #
    # Debugging
    #
    self.debug = debug

    # If debug is on, do not debug those functions
    self.skip_debug = ['']

    # Choices: pl.PassThroughProfiler(): no profiling
    #          pl.SimpleProfiler(): just time recording
    #          pl.AdvancedProfiler(), ...
    # See pytorch_lightning.profiler
    # Simple and Advanced profilers accept output_filename
    self.profiler = PassThroughProfiler()

    #######################################################
    #
    # Dimensions
    #
    # dimensionality of the transformer_model's hidden states;
    # depth of the transformer_model = no. of series = n_series
    self.d_model = data.shape[1]
    self.adapt_inp = False

    self.n_layer = 3 if dbg else 8
    # number of attention heads for each attention layer in the Transformer encoder
    self.n_head = 4 if dbg else 16
    # dimensionality of the transformer_model's heads
    self.d_head = 8 if dbg else 32
    # dimensionality of the hidden states
    self.d_hidden = self.d_model
    # Dimensionality of the embeddings - must be EVEN
    self.d_pos_enc = 12 if dbg else self.d_model // 2
    # transformer_model dimension. Must be even.
    # NOTE(review): the debug value 13 is odd, contradicting the comment above;
    # confirm which constraint is intended before relying on it.
    self.n_model = 13 if dbg else 60
    self.d_FF_inner = 4 if dbg else 16
    # TODO: Check that n_train is actually used
    self.n_train = 12
    self.n_val = 6 if dbg else 12
    self.n_test = 9 if dbg else 12
    # batch size
    self.n_batch = 19 if dbg else 64
    self.batch_chunk = 1
    self.not_tied = False
    self.pre_lnorm = False
    self.dropout = 0.0
    self.dropout_attn = 0.0
    # When debugging, dataloaders run in the main process (num_workers=0);
    # the previous ``4 if debug else 4`` ignored the flag entirely.
    self.num_workers = 0 if dbg else 4
    # number of tokens to predict
    self.n_predict = 3 if dbg else 10
    self.eval_n_predict = 5 if dbg else 20
    # length of the extended context
    self.n_ext_ctx = 2 if dbg else 16
    self.n_mems = 2 if dbg else 64
    self.varlen = False
    self.same_length = True
    # use the same pos embeddings after n_clamp_after
    self.n_clamp_after = -1
    # parameter initializer to use.
    self.init = "normal"
    self.emb_init = "normal"
    self.init_range = 0.1
    self.emb_init_range = 0.01
    self.init_std = 0.02
    self.proj_init_std = 0.01

    #######################################################
    #
    # Running parameters
    #
    self.max_epochs = 100 if dbg else 1_000

    # Optimizer / Scheduler
    # Choices: adam, sgd, adagrad
    self.optim = "adam"
    self.lr = 0.00025
    # Choices: cosine, inv_sqrt, dev_perf, constant
    self.scheduler = "dev_perf"
    self.warmup_step = 0
    self.decay_rate = 0.5
    self.min_lr = 0.0
    self.clip = 0.25
    self.clip_nonemb = True
    self.eta_min = 0.0
    self.patience = 0
    # momentum for sgd
    self.mom = 0.0
    # random seed
    self.seed = 42
    self.max_step = 4 if dbg else 512
    self.max_eval_steps = -1
    self.log_interval = 10 if dbg else 200
    # evaluation interval
    self.eval_interval = 20 if dbg else 200

    # Restart
    self.restart_dir = ""
    self.restart = True
    self.restart_from = None
    self.finetune_v2 = True
    self.finetune_v3 = True
    self.log_first_epochs = 0
    self.reset_lr = True  # TODO reset learning schedule to start
    # TODO Add layers to transformer_model throughout training
    # choices: "repeat", "reinit", "repeat_bottom", "reinit_bottom", "duplicate"
    self.expand = None
    # choices: "freeze", "reverse_distil_full", "reverse_distil_partial"
    self.integration = ""
    self.integration_length = 0
    self.expansion_dict = {}
    # TODO widen transformer_model layers throughout training
    # choices: "reinit", "duplicate"
    self.widen = None
    self.widen_dict = {}