class TrainerTrainLoopMixin(ABC):

    # this is just a summary of the variables used in this abstract class;
    #  the proper values/initialisation should be done in the child class
    max_epochs: int
    min_epochs: int
    use_ddp: bool
    use_dp: bool
    use_ddp2: bool
    single_gpu: bool
    use_tpu: bool
    data_parallel_device_ids: ...
    check_val_every_n_epoch: ...
    num_training_batches: int
    val_check_batch: ...
    num_val_batches: int
    disable_validation: bool
    fast_dev_run: ...
    main_progress_bar: ...
    accumulation_scheduler: ...
    lr_schedulers: ...
    enable_early_stop: ...
    early_stop_callback: ...
    callback_metrics: ...
    logger: Union[LightningLoggerBase, bool]
    global_step: int
    testing: bool
    log_save_interval: int
    proc_rank: int
    row_log_interval: int
    total_batches: int
    truncated_bptt_steps: ...
    optimizers: ...
    optimizer_frequencies: ...
    accumulate_grad_batches: int
    use_amp: bool
    track_grad_norm: ...
    model: LightningModule
    running_loss: ...
    training_tqdm_dict: ...
    reduce_lr_on_plateau_scheduler: ...
    profiler: ...
    batch_idx: int
    precision: ...
    train_dataloader: DataLoader
    reload_dataloaders_every_epoch: bool
    progress_bar_refresh_rate: ...
    max_steps: int
    min_steps: int
    total_batch_idx: int
    checkpoint_callback: ...

    # Callback system
    callbacks: List[Callback]
    on_train_start: Callable
    on_train_end: Callable
    on_batch_start: Callable
    on_batch_end: Callable
    on_epoch_start: Callable
    on_epoch_end: Callable
    on_validation_end: Callable

    @abstractmethod
    def get_model(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def is_function_implemented(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def run_evaluation(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def transfer_batch_to_gpu(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def transfer_batch_to_tpu(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def clip_gradients(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def detect_nan_tensors(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def is_overriden(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def add_tqdm_metrics(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def log_metrics(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def process_output(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def reset_train_dataloader(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def reset_val_dataloader(self, model):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def has_arg(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    def train(self):
        warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', RuntimeWarning)

        # get model
        model = self.get_model()

        # load data
        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
        if not self.reload_dataloaders_every_epoch:
            self.reset_train_dataloader(model)
        self.reset_val_dataloader(model)

        # Train start events
        with self.profiler.profile('on_train_start'):
            # callbacks
            self.on_train_start()
            # initialize early stop callback
            if self.early_stop_callback is not None:
                self.early_stop_callback.on_train_start(self, self.get_model())
            # model hooks
            model.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):
                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)
                # set seed for distributed sampler (enables shuffling for each epoch)
                if self.use_ddp \
                        and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                    self.train_dataloader.sampler.set_epoch(epoch)

                # update training progress in trainer and model
                model.current_epoch = epoch
                self.current_epoch = epoch

                total_val_batches = 0
                is_val_epoch = False
                if not self.disable_validation and self.num_training_batches != float('inf'):
                    # val can be checked multiple times within an epoch
                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                    total_val_batches = self.num_val_batches * val_checks_per_epoch

                # total batches includes multiple val checks
                self.total_batches = self.num_training_batches + total_val_batches
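                # e.g. 100 train batches with val_check_batch=50 and 20 val batches
                # gives total_batches = 100 + 2 * 20 = 140 (illustrative numbers)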

                # change gradient accumulation according to the accumulation_scheduler
                self.accumulation_scheduler.on_epoch_start(self, self.get_model())

                # stores accumulated grad fractions per batch
                self.batch_loss_value = TensorRunningMean(
                    window_length=self.accumulate_grad_batches
                )

                if self.fast_dev_run:
                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                    num_iterations = 2
                elif self.total_batches == float('inf'):
                    # for infinite train or val loader, the progress bar never ends
                    num_iterations = None
                else:
                    num_iterations = self.total_batches

                # reset progress bar
                # .reset() doesn't work on a disabled progress bar, so check first
                if not self.main_progress_bar.disable:
                    self.main_progress_bar.reset(num_iterations)
                desc = f'Epoch {epoch + 1}'
                self.main_progress_bar.set_description(desc)

                # -----------------
                # RUN TRAINING EPOCH
                # -----------------
                self.run_training_epoch()

                # update LR schedulers
                self.update_learning_rates(interval='epoch')

                if self.max_steps and self.max_steps == self.global_step:
                    self.run_training_teardown()
                    return

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                # TODO wrap this logic into the callback
                if self.enable_early_stop:
                    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
                        # stop training
                        stop = should_stop and met_min_epochs
                        if stop:
                            self.run_training_teardown()
                            return

            self.run_training_teardown()

        except KeyboardInterrupt:
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()

    def run_training_epoch(self):

        # Epoch start events
        with self.profiler.profile('on_epoch_start'):
            # callbacks
            self.on_epoch_start()

            # model hooks
            if self.is_function_implemented('on_epoch_start'):
                self.get_model().on_epoch_start()

        # track local dataloader so TPU can wrap each epoch
        train_dataloader = self.train_dataloader

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device()
            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
            train_dataloader = train_dataloader.per_device_loader(device)

        # run epoch
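        # _with_is_last yields (batch, is_last_batch) pairs; the final batch can
        # trigger a val check when val_check_batch is inf (see should_check_val below)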
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
        ):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx

            model = self.get_model()
            model.global_step = self.global_step

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            output = self.run_training_batch(batch, batch_idx)
            batch_result, grad_norm_dic, batch_step_metrics = output

            # when training_step returns -1, end the epoch early
            early_stop_epoch = batch_result == -1

            # update lr
            self.update_learning_rates(interval='step')

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
            can_check_val = not self.disable_validation and can_check_epoch
            should_check_val = is_val_check_batch or early_stop_epoch
            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
            should_check_val = can_check_val and should_check_val

            # fast_dev_run always forces val checking after train batch
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=self.testing)

            # when logs should be saved
            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
            if should_save_log or self.fast_dev_run:
                if self.proc_rank == 0 and self.logger is not None:
                    self.logger.save()

            # when metrics should be logged
            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
            if should_log_metrics or self.fast_dev_run:
                # logs user requested information to logger
                self.log_metrics(batch_step_metrics, grad_norm_dic)

            # ---------------
            # CHECKPOINTING, EARLY STOPPING
            # ---------------
            # save a checkpoint even when no test or val step is defined
            if self.fast_dev_run or should_check_val:
                self.call_checkpoint_callback()

                if self.enable_early_stop:
                    self.early_stop_callback.check_metrics(self.callback_metrics)

            # advance the global step according to gradient accumulation progress
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                self.global_step += 1
            self.total_batch_idx += 1

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end the epoch early when training_step requested it (returned -1)
            # or when running in fast_dev_run mode
            if early_stop_epoch or self.fast_dev_run:
                break

        # when validation_step is not overridden and no val check just ran,
        # still run the checkpoint and early-stop checks at epoch end
        if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(self.callback_metrics)

        # Epoch end events
        with self.profiler.profile('on_epoch_end'):
            # callbacks
            self.on_epoch_end()
            # model hooks
            if self.is_function_implemented('on_epoch_end'):
                self.get_model().on_epoch_end()

    def run_training_batch(self, batch, batch_idx):
        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        all_callback_metrics = []

        # track metrics to log
        all_log_metrics = []

        if batch is None:
            return 0, grad_norm_dic, {}

        # Batch start events
        with self.profiler.profile('on_batch_start'):
            # callbacks
            self.on_batch_start()
            # hooks
            if self.is_function_implemented('on_batch_start'):
                response = self.get_model().on_batch_start(batch)
                if response == -1:
                    return -1, grad_norm_dic, {}

        splits = [batch]
        if self.truncated_bptt_steps is not None:
            model_ref = self.get_model()
            with self.profiler.profile('tbptt_split_batch'):
                splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
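            # e.g. a batch of sequences of length 200 with truncated_bptt_steps=50
            # yields 4 splits of 50 steps each (illustrative numbers)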

        self.hiddens = None
        for split_idx, split_batch in enumerate(splits):
            self.split_idx = split_idx

            for opt_idx, optimizer in self._get_optimizers_iterable():
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.optimizers) > 1:
                    for param in self.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # wrap the forward step in a closure so second-order methods
                # (e.g. LBFGS, which re-evaluates the loss) work
                def optimizer_closure():
                    # forward pass
                    with self.profiler.profile('model_forward'):
                        output = self.training_forward(
                            split_batch, batch_idx, opt_idx, self.hiddens)

                    closure_loss = output[0]
                    progress_bar_metrics = output[1]
                    log_metrics = output[2]
                    callback_metrics = output[3]
                    self.hiddens = output[4]

                    # accumulate loss
                    # (if accumulate_grad_batches = 1 no effect)
                    closure_loss = closure_loss / self.accumulate_grad_batches

                    # backward pass
                    model_ref = self.get_model()
                    with self.profiler.profile('model_backward'):
                        model_ref.backward(self, closure_loss, optimizer, opt_idx)

                    # track metrics for callbacks
                    all_callback_metrics.append(callback_metrics)

                    # track progress bar metrics
                    self.add_tqdm_metrics(progress_bar_metrics)
                    all_log_metrics.append(log_metrics)

                    # call the on_after_backward hook
                    if self.is_function_implemented('on_after_backward'):
                        model_ref = self.get_model()
                        with self.profiler.profile('on_after_backward'):
                            model_ref.on_after_backward()

                    return closure_loss

                # calculate loss
                loss = optimizer_closure()

                # check if loss or model weights are nan
                self.detect_nan_tensors(loss)

                # track total loss for logging (avoid mem leaks)
                self.batch_loss_value.append(loss)

                # gradient update with accumulated gradients
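                # (e.g. with accumulate_grad_batches=4, the update fires on
                # batch_idx 3, 7, 11, ...; illustrative value)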
                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

                    # track gradient norms when requested
                    if batch_idx % self.row_log_interval == 0:
                        if self.track_grad_norm > 0:
                            model = self.get_model()
                            grad_norm_dic = model.grad_norm(
                                self.track_grad_norm)

                    # clip gradients
                    self.clip_gradients()

                    # calls .step(), .zero_grad()
                    # override function to modify this behavior
                    model = self.get_model()
                    with self.profiler.profile('optimizer_step'):
                        model.optimizer_step(self.current_epoch, batch_idx,
                                             optimizer, opt_idx, optimizer_closure)

                    # calculate running loss for display
                    self.running_loss.append(self.batch_loss_value.mean())

                    # reset for next set of accumulated grads
                    self.batch_loss_value.reset()

        # Batch end events
        with self.profiler.profile('on_batch_end'):
            # callbacks
            self.on_batch_end()
            # model hooks
            if self.is_function_implemented('on_batch_end'):
                self.get_model().on_batch_end()

        # update progress bar
        if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
            self.main_progress_bar.update(self.progress_bar_refresh_rate)
            self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

        # collapse all metrics into one dict
        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

        # track all metrics for callbacks
        self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

        return 0, grad_norm_dic, all_log_metrics

    def _get_optimizers_iterable(self):
        if not self.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.total_batch_idx % optimizers_loop_length
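        # e.g. optimizer_frequencies=[2, 1] gives cumsum [2, 3]: batches 0-1 use
        # optimizer 0, batch 2 uses optimizer 1, repeating every 3 batches
        # (illustrative frequencies)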

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [(opt_idx, self.optimizers[opt_idx])]

    def run_training_teardown(self):
        self.main_progress_bar.close()

        # Train end events
        with self.profiler.profile('on_train_end'):
            # callbacks
            self.on_train_end()
            # model hooks
            if self.is_function_implemented('on_train_end'):
                self.get_model().on_train_end()

        if self.logger is not None:
            self.logger.finalize("success")

        # summarize profile results
        self.profiler.describe()

    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
        """
        Handle forward for each training case (distributed, single gpu, etc...)
        :param batch:
        :param batch_idx:
        :return:
        """
        # ---------------
        # FORWARD
        # ---------------
        # allow training_step to omit optimizer_idx when there is a single optimizer
        args = [batch, batch_idx]
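        # with multiple optimizers, training_step must also accept the optimizer
        # index, e.g. (sketch): def training_step(self, batch, batch_idx, optimizer_idx)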

        if len(self.optimizers) > 1:
            if self.has_arg('training_step', 'optimizer_idx'):
                args.append(opt_idx)
            else:
                num_opts = len(self.optimizers)
                raise ValueError(
                    f'Your LightningModule defines {num_opts} optimizers but '
                    f'training_step is missing the "optimizer_idx" argument.'
                )

        # pass hiddens if using tbptt
        if self.truncated_bptt_steps is not None:
            args.append(hiddens)

        # distributed forward
        if self.use_ddp or self.use_ddp2 or self.use_dp:
            output = self.model(*args)

        # single GPU forward
        elif self.single_gpu:
            gpu_id = 0
            if isinstance(self.data_parallel_device_ids, list):
                gpu_id = self.data_parallel_device_ids[0]
            batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
            args[0] = batch
            output = self.model.training_step(*args)

        # TPU support
        elif self.use_tpu:
            batch = self.transfer_batch_to_tpu(copy.copy(batch))
            args[0] = batch
            output = self.model.training_step(*args)

        # CPU forward
        else:
            output = self.model.training_step(*args)

        # allow any mode to define training_step_end
        # do something with all the DP outputs (like softmax)
        if self.is_overriden('training_step_end'):
            model_ref = self.get_model()
            with self.profiler.profile('training_step_end'):
                output = model_ref.training_step_end(output)

        # allow any mode to define training_end
        # TODO: remove in 1.0.0
        if self.is_overriden('training_end'):
            model_ref = self.get_model()
            with self.profiler.profile('training_end'):
                output = model_ref.training_end(output)

            warnings.warn('`training_end` was deprecated in 0.7.0 and will be removed in 1.0.0.'
                          ' Use `training_epoch_end` instead.', DeprecationWarning)

        # format and reduce outputs accordingly
        output = self.process_output(output, train=True)

        return output

    def update_learning_rates(self, interval: str):
        """Update learning rates.

        Args:
            interval: either 'epoch' or 'step'.
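
        Note:
            Each entry of ``self.lr_schedulers`` is assumed to be a dict with
            the keys used below, e.g. (illustrative values)::

                {'scheduler': <torch lr scheduler>, 'interval': 'epoch',
                 'frequency': 1, 'reduce_on_plateau': False, 'monitor': 'val_loss'}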
        """
        if not self.lr_schedulers:
            return

        for lr_scheduler in self.lr_schedulers:
            current_idx = self.batch_idx if interval == 'step' else self.current_epoch
            current_idx += 1  # account for batch and epoch indices both starting from 0
            # Take step if call to update_learning_rates matches the interval key and
            # the current step modulo the scheduler's frequency is zero
            if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
                # If instance of ReduceLROnPlateau, we need to pass validation loss
                if lr_scheduler['reduce_on_plateau']:
                    monitor_key = lr_scheduler['monitor']
                    monitor_val = self.callback_metrics.get(monitor_key)
                    if monitor_val is None:
                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
                        raise MisconfigurationException(
                            f'ReduceLROnPlateau conditioned on metric {monitor_key},'
                            f' which is not available. Available metrics are: {avail_metrics}.'
                            ' The condition can be set using the `monitor` key in the lr scheduler dict.'
                        )
                    lr_scheduler['scheduler'].step(monitor_val)
                else:
                    lr_scheduler['scheduler'].step()

    def call_checkpoint_callback(self):
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.on_validation_end(self, self.get_model())
        self.on_validation_end()
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: List[Callback] = [],
            default_save_path: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            num_tpu_cores: Optional[int] = None,
            log_gpu_memory: Optional[str] = None,
            show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
            progress_bar_refresh_rate: int = 1,
            overfit_pct: float = 0.0,
            track_grad_norm: int = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            train_percent_check: float = 1.0,
            val_percent_check: float = 1.0,
            test_percent_check: float = 1.0,
            val_check_interval: float = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 10,
            add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
            distributed_backend: Optional[str] = None,
            precision: int = 32,
            print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
            weights_summary: Optional[str] = 'full',
            weights_save_path: Optional[str] = None,
            amp_level: str = 'O1',
            num_sanity_val_steps: int = 5,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[BaseProfiler] = None,
            benchmark: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
            nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
            max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
            use_amp=False,  # backward compatible, todo: remove in v0.9.0
            nb_sanity_val_steps=None,  # backward compatible, todo: remove in v0.8.0
            **kwargs
    ):
        r"""

        Customize every aspect of training via flags

        Args:
            logger: Logger (or iterable collection of loggers) for experiment tracking.

            checkpoint_callback: Callback for checkpointing.

            early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): Callback for early stopping.

            callbacks: Add a list of callbacks.

            default_save_path: Default path for logs and weights when no logger or checkpoint callback is passed.

            gradient_clip_val: Value for gradient clipping; 0 means don't clip.

            gradient_clip:
                .. warning:: .. deprecated:: 0.7.0

                    Use `gradient_clip_val` instead. Will remove 0.9.0.

            process_position: Orders the tqdm bar when running multiple models on the same machine.

            num_nodes: number of GPU nodes for distributed training.

            nb_gpu_nodes:
                .. warning:: .. deprecated:: 0.7.0

                    Use `num_nodes` instead. Will remove 0.9.0.

            gpus: Which GPUs to train on.

            num_tpu_cores: How many TPU cores to train on (1 or 8).

            log_gpu_memory: None, 'min_max', or 'all'. Might slow performance.

            show_progress_bar:
                .. warning:: .. deprecated:: 0.7.2

                    Set `progress_bar_refresh_rate` to a positive integer to enable. Will remove 0.9.0.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.

            overfit_pct: How much of the training, validation, and test datasets to check.

            track_grad_norm: -1 means no tracking; otherwise tracks the p-norm of the gradients (e.g. 2 for the 2-norm).

            check_val_every_n_epoch: Check val every n train epochs.

            fast_dev_run: Runs 1 batch of train, val, and test to find any bugs (i.e. a sort of unit test).

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
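
                Example (illustrative values, not an exhaustive spec)::

                    # accumulate every 4 batches
                    trainer = Trainer(accumulate_grad_batches=4)

                    # from epoch 5 on, accumulate 8 batches
                    # (the dict maps epoch -> accumulation factor)
                    trainer = Trainer(accumulate_grad_batches={5: 8})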

            max_epochs: Stop training once this number of epochs is reached.

            max_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `max_epochs` instead. Will remove 0.9.0.

            min_epochs: Force training for at least this many epochs.

            min_nb_epochs:
                .. warning:: .. deprecated:: 0.7.0

                    Use `min_epochs` instead. Will remove 0.9.0.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least this number of steps. Disabled by default (None).

            train_percent_check: How much of training dataset to check.

            val_percent_check: How much of validation dataset to check.

            test_percent_check: How much of test dataset to check.

            val_check_interval: How often within one training epoch to check the validation set.

            log_save_interval: Writes logs to disk this often.

            row_log_interval: How often to add logging rows (does not write to disk).

            add_row_log_interval:
                .. warning:: .. deprecated:: 0.7.0

                    Use `row_log_interval` instead. Will remove 0.9.0.

            distributed_backend: The distributed backend to use.

            use_amp:
                .. warning:: .. deprecated:: 0.7.0

                    Use `precision` instead. Will remove 0.9.0.

            precision: Full precision (32), half precision (16).

            print_nan_grads:
                .. warning:: .. deprecated:: 0.7.2

                    Has no effect. When detected, NaN grads will be printed automatically.
                    Will remove 0.9.0.

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified.

            amp_level: The optimization level to use (O1, O2, etc...).

            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

            nb_sanity_val_steps:
                .. warning:: .. deprecated:: 0.7.0

                    Use `num_sanity_val_steps` instead. Will remove 0.8.0.

            truncated_bptt_steps: Truncated backpropagation through time performs backprop every k steps of a much longer sequence.

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.

            profiler: To profile individual steps during training and assist in identifying bottlenecks.

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            benchmark: If true enables cudnn.benchmark.
        """

        # Init callbacks
        self.callbacks = callbacks
        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        if benchmark:
            torch.backends.cudnn.benchmark = True

        # Transfer params
        self.num_nodes = num_nodes
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_gpu_nodes is not None:
            warnings.warn("Argument `nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            self.num_gpu_nodes = nb_gpu_nodes
        self.log_gpu_memory = log_gpu_memory

        self.gradient_clip_val = gradient_clip_val
        # Backward compatibility, TODO: remove in v0.8.0
        if gradient_clip is not None:
            warnings.warn("Argument `gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            self.gradient_clip = gradient_clip

        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        self.on_gpu = bool(gpus and torch.cuda.is_available())

        # tpu config
        self.on_tpu = num_tpu_cores is not None
        self.num_tpu_cores = num_tpu_cores
        assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1, 8, or None'

        self.process_position = process_position
        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if max_nb_epochs is not None:
            warnings.warn("Argument `max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            self.max_nb_epochs = max_nb_epochs

        self.min_epochs = min_epochs
        # Backward compatibility, TODO: remove in v0.8.0
        if min_nb_epochs is not None:
            warnings.warn("Argument `min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            self.min_nb_epochs = min_nb_epochs

        self.max_steps = max_steps
        self.min_steps = min_steps

        self.num_sanity_val_steps = num_sanity_val_steps
        # Backward compatibility, TODO: remove in v0.8.0
        if nb_sanity_val_steps is not None:
            warnings.warn("Argument `nb_sanity_val_steps` has renamed to "
                          "`num_sanity_val_steps` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            self.nb_sanity_val_steps = nb_sanity_val_steps

        # Backward compatibility, TODO: remove in v0.9.0
        if print_nan_grads:
            warnings.warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                          " NaN grads will be printed automatically when detected.",
                          DeprecationWarning)

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            self.num_sanity_val_steps = 1
            self.max_epochs = 1
            log.info('Running in fast_dev_run mode: will run a full train,'
                     ' val loop using a single batch')

        # set default save path if user didn't provide one
        self.default_save_path = default_save_path
        if self.default_save_path is None:
            self.default_save_path = os.getcwd()

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningMean(window_length=20)
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # allow int, string and gpu list
        self.gpus = gpus
        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_dp = False
        self.single_gpu = False
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend, self.num_nodes)

        # override dist backend when using tpus
        if self.on_tpu:
            self.init_tpu()
            self.current_tpu_idx = None

        # init flags for SLURM+ddp to work
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        self.configure_slurm_ddp(self.num_nodes)

        # nvidia setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        # backward compatibility
        if show_progress_bar is not None:
            self.show_progress_bar = show_progress_bar

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval

        # backward compatibility
        if add_row_log_interval is not None:
            warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
                          " and this method will be removed in v0.8.0", DeprecationWarning)
            if not row_log_interval:  # in case you did not set the proper value
                row_log_interval = add_row_log_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.overfit_pct = overfit_pct
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)

        # 16 bit mixed precision training using apex
        self.amp_level = amp_level
        self.precision = precision

        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

        if self.precision == 16 and self.num_tpu_cores is None:
            use_amp = True
        self.init_amp(use_amp)

        # Callback system
        self.on_init_end()
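
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; assumes a LightningModule subclass named
# `MyModel` and the usual `Trainer.fit` entry point -- not part of this file):
#
#   from pytorch_lightning import Trainer
#
#   model = MyModel()
#   trainer = Trainer(max_epochs=10, gpus=1, precision=16)
#   trainer.fit(model)
# ---------------------------------------------------------------------------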