Пример #1
0
 def __init__(self,
              name: Optional[str] = None,
              save_dir: Optional[str] = None,
              offline: bool = False,
              id: Optional[str] = None,
              anonymous: bool = False,
              version: Optional[str] = None,
              project: Optional[str] = None,
              log_model: bool = False,
              experiment=None,
              prefix: str = '',
              **kwargs):
     """Record the run configuration on private attributes of the logger.

     ``version`` is preferred over ``id`` when it is truthy,
     ``anonymous=True`` is stored as ``'allow'`` (otherwise ``None``) and
     any extra keyword arguments are kept unchanged in ``self._kwargs``.

     Raises:
         ImportError: If the module-level ``wandb`` name is ``None``.
     """
     if wandb is None:
         raise ImportError(
             'You want to use `wandb` logger which is not installed yet,'  # pragma: no-cover
             ' install it with `pip install wandb`.')

     super().__init__()

     # run identity
     self._id = version or id
     self._name = name
     self._project = project
     self._experiment = experiment
     if anonymous:
         self._anonymous = 'allow'
     else:
         self._anonymous = None

     # storage / upload options
     self._save_dir = save_dir
     self._offline = offline
     self._log_model = log_model
     self._prefix = prefix
     self._kwargs = kwargs

     # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
     self._step_offset = 0
     self.warning_cache = WarningCache()
Пример #2
0
 def __init__(self, trainer):
     """Keep a reference to ``trainer`` and start with empty loop state."""
     self.trainer = trainer
     self.testing = False

     # per-run collections; each one is its own fresh list
     self.outputs = []
     self.step_metrics = []

     # not known until the loop is configured
     self.predictions = self.max_batches = None

     self.warning_cache = WarningCache()
Пример #3
0
 def __init__(self, trainer):
     """Keep a reference to ``trainer`` and start with empty loop state."""
     self.trainer = trainer
     self._teardown_already_run = False

     # accumulators start unset
     self.early_stopping_accumulator = None
     self.checkpoint_accumulator = None
     self.accumulated_loss = None

     self.warning_cache = WarningCache()
     self.running_loss = TensorRunningAccum(window_length=20)
Пример #4
0
 def __init__(self, trainer):
     """Keep a reference to ``trainer`` and start with empty loop state."""
     self.trainer = trainer

     # per-run collections; each one is its own fresh list
     self.outputs = []
     self.step_metrics = []

     # not known until the loop is configured
     self.predictions = None
     self.max_batches = None
     self.num_dataloaders = None

     self.warning_cache = WarningCache()
Пример #5
0
 def __init__(self, trainer):
     """Keep a reference to ``trainer`` and start with empty loop state."""
     self.trainer = trainer
     self._teardown_already_run = False
     self.automatic_optimization = True

     # accumulators and per-step caches start unset
     self.early_stopping_accumulator = None
     self.checkpoint_accumulator = None
     self.accumulated_loss = None
     self._curr_step_result = None
     self._cur_grad_norm_dict = None

     self.warning_cache = WarningCache()
     self.running_loss = TensorRunningAccum(window_length=20)
Пример #6
0
    def __init__(self,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None,
                 offline: bool = False,
                 id: Optional[str] = None,
                 anonymous: bool = False,
                 version: Optional[str] = None,
                 project: Optional[str] = None,
                 log_model: bool = False,
                 experiment=None,
                 prefix: str = '',
                 **kwargs):
        """Validate the arguments and record the run configuration.

        ``version`` is preferred over ``id`` when it is truthy,
        ``anonymous=True`` is stored as ``'allow'`` (otherwise ``None``) and
        any extra keyword arguments are kept unchanged in ``self._kwargs``.

        Raises:
            ImportError: If the module-level ``wandb`` name is ``None``.
            MisconfigurationException: If both ``offline`` and ``log_model``
                are truthy.
        """
        if wandb is None:
            raise ImportError(
                'You want to use `wandb` logger which is not installed yet,'  # pragma: no-cover
                ' install it with `pip install wandb`.')

        # the two options are mutually exclusive
        if offline and log_model:
            raise MisconfigurationException(
                f'Providing log_model={log_model} and offline={offline} is an invalid configuration'
                ' since model checkpoints cannot be uploaded in offline mode.\n'
                'Hint: Set `offline=False` to log your model.')

        super().__init__()

        # run identity
        self._id = version or id
        self._name = name
        self._project = project
        self._experiment = experiment
        if anonymous:
            self._anonymous = 'allow'
        else:
            self._anonymous = None

        # storage / upload options
        self._save_dir = save_dir
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._kwargs = kwargs

        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
        self._step_offset = 0
        self.warning_cache = WarningCache()
Пример #7
0
class TrainLoop:
    def __init__(self, trainer):
        """Keep a reference to ``trainer`` and start with empty loop state."""
        self.trainer = trainer
        self._teardown_already_run = False
        self.automatic_optimization = True

        # accumulators and per-step caches start unset
        self.early_stopping_accumulator = None
        self.checkpoint_accumulator = None
        self.accumulated_loss = None
        self._curr_step_result = None
        self._cur_grad_norm_dict = None

        self.warning_cache = WarningCache()
        self.running_loss = TensorRunningAccum(window_length=20)

    def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps,
                        num_sanity_val_steps, automatic_optimization):
        """Reset the trainer's progress state and store the run limits.

        Counters and flags on ``self.trainer`` are zeroed / cleared, the
        epoch and step limits are copied over, and
        ``automatic_optimization`` is stored on this loop. A
        ``num_sanity_val_steps`` of ``-1`` is stored as ``float("inf")``.
        """
        trainer = self.trainer

        # progress counters and run flags
        trainer.global_step = 0
        trainer.current_epoch = 0
        trainer.interrupted = False
        trainer.should_stop = False
        trainer._state = TrainerState.INITIALIZING

        # batch bookkeeping
        trainer.total_batch_idx = 0
        trainer.batch_idx = 0
        trainer.num_training_batches = 0
        trainer.train_dataloader = None
        self.automatic_optimization = automatic_optimization

        # run limits
        trainer.max_epochs = max_epochs
        trainer.min_epochs = min_epochs
        trainer.max_steps = max_steps
        trainer.min_steps = min_steps

        run_all_sanity_batches = num_sanity_val_steps == -1
        trainer.num_sanity_val_steps = (float("inf") if run_all_sanity_batches
                                        else num_sanity_val_steps)

    @property
    def num_optimizers(self):
        num_optimizers = len(self.get_optimizers_iterable())
        return num_optimizers

    def should_skip_training(self):
        if self.trainer.current_epoch >= self.trainer.max_epochs:
            return True

        if self.trainer.limit_train_batches == 0:
            return True

        return False

    def on_train_start(self):
        # clear cache before training
        if self.trainer.on_gpu and self.trainer.root_gpu is not None:
            # use context because of:
            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
            with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"):
                torch.cuda.empty_cache()

        # hook
        self.trainer.call_hook("on_train_start")

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # bind logger and other properties
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.trainer.data_connector.attach_data(model, train_dataloader,
                                                val_dataloaders, datamodule)

        # check that model is configured correctly
        self.trainer.config_validator.verify_loop_configurations(model)

    def setup_training(self, model: LightningModule):
        """Prepare the trainer and the model right before training starts.

        Binds the trainer to the underlying module, logs the initial
        hyper-parameters, waits on the accelerator barrier, registers the
        SLURM signal handlers, runs the pretrain-routine hooks, optionally
        prints the model summary and restores weights through the checkpoint
        connector.

        Args:
            model: The model handed to the trainer. When
                ``trainer.data_parallel`` is set, the underlying module is
                read from ``model.module``.

        Raises:
            MisconfigurationException: On the global-zero process outside of
                testing, when ``trainer.weights_summary`` is set to a value
                that is not in ``ModelSummary.MODES``.
        """
        # --------------------------
        # Setup
        # --------------------------
        # ref_model is the module that hooks and properties are applied to;
        # for data-parallel runs that is the wrapped ``model.module``
        ref_model = model
        if self.trainer.data_parallel:
            ref_model = model.module

        # set the ranks and devices
        self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank
        self.trainer.accelerator_backend.dist.device = ref_model.device

        # give model convenience properties
        ref_model.trainer = self.trainer

        # set local properties on the model
        self.trainer.model_connector.copy_trainer_model_properties(ref_model)

        # init amp. Must be done here instead of __init__ to allow ddp to work
        if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu:
            self.trainer.scaler = self.trainer.precision_connector.backend.scaler

        # log hyper-parameters
        if self.trainer.logger is not None:
            # save exp to get started (this is where the first experiment logs are written)
            self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
            self.trainer.logger.log_graph(ref_model)
            self.trainer.logger.save()

        # wait for all to join if on distributed
        self.trainer.accelerator_backend.barrier("setup_training")

        # register auto-resubmit when on SLURM
        self.trainer.slurm_connector.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start: trainer-level call first, then the
        # module's own hook when the trainer reports it as implemented
        self.trainer.on_pretrain_routine_start(ref_model)
        if self.trainer.is_function_implemented("on_pretrain_routine_start"):
            ref_model.on_pretrain_routine_start()

        # print model summary (global-zero process only, never while testing)
        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
            if self.trainer.weights_summary in ModelSummary.MODES:
                ref_model.summarize(mode=self.trainer.weights_summary)
            else:
                raise MisconfigurationException(
                    "weights_summary can be None, " +
                    ", ".join(ModelSummary.MODES))

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        # NOTE: the original ``model`` (not ``ref_model``) is stored here
        self.trainer.model = model

        # restore training and model before hpc is called
        self.trainer.checkpoint_connector.restore_weights(model)

        # on pretrain routine end: same two-step pattern as the start hook
        self.trainer.on_pretrain_routine_end(ref_model)
        if self.trainer.is_function_implemented("on_pretrain_routine_end"):
            ref_model.on_pretrain_routine_end()

    def on_train_end(self):
        if self._teardown_already_run:
            return

        self._teardown_already_run = True

        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
        # when a checkpoint was saved at the last step
        self.trainer.global_step -= 1
        self.check_checkpoint_callback(should_save=True, is_last=True)
        self.trainer.global_step += 1

        # hook
        self.trainer.call_hook("on_train_end")

        # kill loggers
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
        if self.trainer.global_rank == 0:
            self.trainer.profiler.describe()

        # give accelerators a chance to finish
        self.trainer.accelerator_backend.on_train_end()

        # clear mem
        if self.trainer.on_gpu:
            model = self.trainer.get_model()
            model.cpu()
            torch.cuda.empty_cache()

    def check_checkpoint_callback(self, should_save, is_last=False):
        # TODO bake this logic into the checkpoint callback
        if should_save and self.trainer.checkpoint_connector.has_trained:
            checkpoint_callbacks = [
                c for c in self.trainer.callbacks
                if isinstance(c, ModelCheckpoint)
            ]

            if is_last and any(c.save_last for c in checkpoint_callbacks):
                rank_zero_info("Saving latest checkpoint...")

            model = self.trainer.get_model()

            for callback in checkpoint_callbacks:
                callback.on_validation_end(self.trainer, model)

    def on_train_epoch_start(self, epoch):
        """Prepare trainer and loop state for a new training epoch.

        Updates ``trainer.current_epoch``, optionally reloads the train
        dataloader, forwards the epoch to the dataloader's sampler when the
        sampler exposes a callable ``set_epoch``, notifies the accumulation
        scheduler, creates fresh accumulators and fires the
        ``on_epoch_start`` / ``on_train_epoch_start`` hooks.

        Args:
            epoch: Index of the epoch that is about to start.
        """
        # update training progress in trainer
        self.trainer.current_epoch = epoch

        model = self.trainer.get_model()

        # reset train dataloader
        if self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        # set seed for distributed sampler (enables shuffling for each epoch).
        # Only samplers that actually expose ``set_epoch`` are touched; a missing
        # dataloader / sampler / method is skipped, while an error raised *inside*
        # ``set_epoch`` is no longer silently swallowed.
        sampler = getattr(self.trainer.train_dataloader, "sampler", None)
        set_epoch = getattr(sampler, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, model)

        # stores accumulated grad fractions per batch
        self.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches)

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch,
                           batch_idx, dataloader_idx):
        # hook
        self.trainer.call_hook('on_batch_end')
        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch,
                               batch_idx, dataloader_idx)

        # figure out what to track for epoch end
        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)

        # reset batch logger internals
        self.trainer.logger_connector.on_train_batch_end()

    def reset_train_val_dataloaders(self, model):
        if not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_val_dataloader(model)

    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs,
                          list) and len(opt_outputs) == 1 and not isinstance(
                              opt_outputs[0], Result):
                opt_outputs = opt_outputs[0]
            epoch_output[opt_idx].append(opt_outputs)

    def get_optimizers_iterable(self):
        """
        Generates an iterable with (idx, optimizer) for each optimizer.
        """
        if not self.trainer.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.trainer.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length

        # find optimzier index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [[opt_idx, self.trainer.optimizers[opt_idx]]]

    def on_after_backward(self, training_step_output, batch_idx,
                          untouched_loss):
        """Detach the step output after backward, then fire the hook.

        A ``Result`` is detached through its own ``detach()``; for any other
        output only ``batch_loss`` is replaced by its detached version. The
        detached ``untouched_loss`` is handed to the dev debugger.
        """
        if isinstance(training_step_output, Result):
            training_step_output.detach()
        else:
            detached_loss = training_step_output.batch_loss.detach()
            training_step_output.batch_loss = detached_loss

        # insert after step hook
        self.trainer.call_hook("on_after_backward")

        # when in dev debugging track the losses
        loss_for_history = untouched_loss.detach()
        self.trainer.dev_debugger.track_train_loss_history(
            batch_idx, loss_for_history)

    def _check_training_step_output(self, training_step_output):
        if isinstance(training_step_output,
                      torch.Tensor) and not self.automatic_optimization:
            if training_step_output.grad_fn is None:
                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
                raise MisconfigurationException(
                    "In manual optimization, `training_step` should not return a Tensor"
                )

    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        """Run one training step through the accelerator and package the result.

        Args:
            split_batch: The (possibly tbptt-split) batch; forwarded to
                ``build_train_args`` and to the output processing.
            batch_idx: Forwarded to ``build_train_args``.
            opt_idx: Forwarded to ``build_train_args``.
            hiddens: Forwarded to ``build_train_args``.

        Returns:
            ``None`` when the processed step output is ``None``; otherwise an
            ``AttributeDict`` with the keys ``closure_loss``, ``loss``,
            ``training_step_output``, ``training_step_output_for_epoch_end``
            and ``hiddens``. ``closure_loss`` and ``loss`` stay ``None`` when
            automatic optimization is off.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.get_model()

        with self.trainer.profiler.profile("model_forward"):
            args = self.build_train_args(split_batch, batch_idx, opt_idx,
                                         hiddens)

            # manually capture logged metrics: a fresh Result is installed on
            # the module before the step runs
            model_ref._current_fx_name = 'training_step'
            model_ref._results = Result()
            training_step_output = self.trainer.accelerator_backend.training_step(
                args)
            self.trainer.logger_connector.cache_logged_metrics()

            # validate the raw output before any hook can replace it
            self._check_training_step_output(training_step_output)

            # the value returned by the hook replaces the raw step output
            training_step_output = self.trainer.call_hook(
                "training_step_end", training_step_output)

            training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
                training_step_output, split_batch)
            is_result_obj = isinstance(training_step_output, Result)

            # nothing to optimize for this step
            if training_step_output_for_epoch_end is None:
                return None

        # enable empty loss when using manual opt
        closure_loss = None
        untouched_loss = None

        if self.trainer.train_loop.automatic_optimization:
            # accumulate loss
            # (if accumulate_grad_batches = 1 no effect)
            if is_result_obj:
                closure_loss = training_step_output.minimize
            else:
                closure_loss = training_step_output.batch_loss

            closure_loss = closure_loss / self.trainer.accumulate_grad_batches

            # the loss will get scaled for amp. avoid any modifications to it
            untouched_loss = closure_loss.detach().clone()

        # result
        result = AttributeDict(
            closure_loss=closure_loss,
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=
            training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result

    def _process_training_step_output(self, training_step_output, split_batch):
        """Normalise whatever ``training_step`` returned.

        Three output shapes are handled, in this order: a ``Result`` object,
        a "1.0 style" output (a tensor, or a container holding neither a
        ``"log"`` nor a ``"progress_bar"`` key), and the old dict format.

        Returns:
            A ``(training_step_output_for_epoch_end, training_step_output)``
            tuple, or ``(None, None)`` when the step returned ``None``.
        """
        training_step_output_for_epoch_end = training_step_output

        # enable training_step to return None
        if training_step_output_for_epoch_end is None:
            return None, None

        # -----------------------------------------
        # process result return (DEPRECATE in 1.0)
        # -----------------------------------------
        if isinstance(training_step_output, Result):
            training_step_output_for_epoch_end = self._process_result(
                training_step_output, split_batch)
            return training_step_output_for_epoch_end, training_step_output

        # -----------------------------------------
        # process hybrid (1.0)
        # -----------------------------------------
        # no need for these checks in 1.0.0
        # TODO: remove checks in 1.0.0
        is_tensor = isinstance(training_step_output_for_epoch_end,
                               torch.Tensor)
        # the membership tests are only evaluated for non-tensor outputs
        is_1_0_output = is_tensor or ("log" not in training_step_output
                                      and "progress_bar"
                                      not in training_step_output)
        if is_1_0_output:
            return self._process_training_step_output_1_0(
                training_step_output, split_batch)

        # -----------------------------------------
        # process old dict (deprecate 1.0)
        # -----------------------------------------
        training_step_output = self.trainer.process_dict_result(
            training_step_output, train=True)

        # name the five positional entries returned by process_dict_result
        training_step_output = AttributeDict(
            batch_loss=training_step_output[0],
            pbar_on_batch_end=training_step_output[1],
            log_metrics=training_step_output[2],
            callback_metrics=training_step_output[3],
            hiddens=training_step_output[4],
        )
        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
        # NOTE(review): tensor outputs already returned through the 1.0 path above,
        # so the tensor branch below looks unreachable from here
        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach(
            )
        else:
            training_step_output_for_epoch_end = recursive_detach(
                training_step_output_for_epoch_end)

        return training_step_output_for_epoch_end, training_step_output

    def _process_training_step_output_1_0(self, training_step_output,
                                          split_batch):
        result = self.trainer.get_model()._results

        loss = None
        hiddens = None

        # handle dict return
        if isinstance(training_step_output, dict):
            loss = training_step_output.pop("loss", None)
            hiddens = training_step_output.pop("hiddens", None)
            result["extra"] = training_step_output

        # handle scalar return
        elif isinstance(training_step_output, torch.Tensor):
            loss = training_step_output
            result["extra"] = {}

        # map to results under the hood
        result.minimize = loss
        result.hiddens = hiddens

        # track batch for manual reduction with result
        result.track_batch_size(len(split_batch))

        # track metrics without grads for epoch reduction
        training_step_output_for_epoch_end = copy(result)
        training_step_output_for_epoch_end.detach()
        if self.trainer.move_metrics_to_cpu:
            training_step_output_for_epoch_end.cpu()

        # what flows back into the system
        training_step_output = result

        return training_step_output_for_epoch_end, training_step_output

    def _process_result(self, training_step_output, split_batch):
        """Handle a deprecated ``Result`` object returned by ``training_step``.

        Records the batch size on the result, emits the deprecation warning
        and returns a detached shallow copy for epoch-end processing.

        Raises:
            MisconfigurationException: If the output is an ``EvalResult``.
                The batch size is tracked and the warning emitted before this
                check runs.
        """
        training_step_output.track_batch_size(len(split_batch))
        m = """
        TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0.
        Use self.log and .write from the LightningModule to log metrics and write predictions.
        training_step can now only return a scalar (for the loss) or a dictionary with anything you want.

        Option 1:
        return loss

        Option 2:
        return {'loss': loss, 'anything_else': ...}

        Option 3:
        return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...}
            """
        rank_zero_warn(m)

        # don't allow EvalResult in the training_step
        if isinstance(training_step_output, EvalResult):
            raise MisconfigurationException(
                "training_step cannot return EvalResult, "
                "use a dict or TrainResult instead")

        # shallow copy first, so detach() is called on the copy rather than
        # on the object that is handed back to the caller
        training_step_output_for_epoch_end = copy(training_step_output)
        training_step_output_for_epoch_end.detach()

        return training_step_output_for_epoch_end

    def optimizer_step(self, optimizer, opt_idx, batch_idx,
                       train_step_and_backward_closure):
        """Delegate the optimizer step to the module's ``optimizer_step`` hook.

        Raises:
            MisconfigurationException: If native AMP is the AMP backend and
                ``optimizer`` is a ``torch.optim.LBFGS``.
        """
        model_ref = self.trainer.get_model()

        using_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and using_lbfgs:
            raise MisconfigurationException(
                'native PyTorch amp and lbfgs are not compatible.'
                ' To request, please file a Github issue in PyTorch and tag @mcarilli'
            )

        # model hook
        current_epoch = self.trainer.current_epoch
        on_tpu = self.trainer.use_tpu and TPU_AVAILABLE
        model_ref.optimizer_step(
            current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=on_tpu,
            using_native_amp=using_native_amp,
            using_lbfgs=using_lbfgs,
        )

    def on_before_zero_grad(self, optimizer):
        self.trainer.call_hook('on_before_zero_grad', optimizer)

    def track_and_norm_grad(self, optimizer):
        # track gradient norms
        grad_norm_dic = self._track_gradient_norm()

        # clip gradients
        self.trainer.accelerator_backend.clip_gradients(optimizer)
        self._cur_grad_norm_dict = grad_norm_dic

    def _track_gradient_norm(self):
        grad_norm_dict = {}
        if (self.trainer.global_step +
                1) % self.trainer.log_every_n_steps == 0:
            if float(self.trainer.track_grad_norm) > 0:
                model = self.trainer.get_model()
                grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
        return grad_norm_dict

    def process_hiddens(self, opt_closure_result):
        """Return the closure result's ``hiddens``.

        When the step output is a ``Result``, ``drop_hiddens()`` is also
        called on the epoch-end copy of the output.
        """
        # read the value first, then drop it from the epoch-end output
        hiddens = opt_closure_result.hiddens
        step_output = opt_closure_result.training_step_output
        if isinstance(step_output, Result):
            epoch_end_output = opt_closure_result.training_step_output_for_epoch_end
            epoch_end_output.drop_hiddens()
        return hiddens

    def tbptt_split_batch(self, batch):
        splits = [batch]
        if self.trainer.truncated_bptt_steps is not None:
            model_ref = self.trainer.get_model()
            with self.trainer.profiler.profile("tbptt_split_batch"):
                splits = model_ref.tbptt_split_batch(
                    batch, self.trainer.truncated_bptt_steps)
        return splits

    def run_training_epoch(self):
        """Run one full pass over the training dataloader.

        For every batch: run the training batch, collect its outputs for the
        epoch end, log the step metrics, run validation when due, save the
        loggers, update the LR schedulers and evaluate the stop conditions.
        After the loop the epoch-end hook runs, the epoch metrics are logged
        and the checkpoint callbacks are triggered when no validation took
        care of it.
        """

        # get model
        model = self.trainer.get_model()

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.trainer.accelerator_backend.process_dataloader(
            self.trainer.train_dataloader)

        # track epoch output: one list per optimizer
        epoch_output = [[] for _ in range(self.num_optimizers)]

        # enable profiling for the dataloader
        # the profiled dataloader yields (batch_idx, (batch, is_last_batch))
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(
            train_dataloader)
        dataloader_idx = 0
        # initialised here so it is defined after the loop even with zero batches
        should_check_val = False
        for batch_idx, (batch, is_last_batch) in train_dataloader:

            self.trainer.batch_idx = batch_idx

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.run_training_batch(batch, batch_idx,
                                                       dataloader_idx)

            # run_training_batch reports signal -1 when a batch-start hook
            # returned -1; the epoch ends early in that case
            if batch_output.signal == -1:
                break

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            epoch_end_outputs = self.process_train_step_outputs(
                batch_output.training_step_output_for_epoch_end,
                self.early_stopping_accumulator,
                self.checkpoint_accumulator,
            )

            # hook
            # TODO: add outputs to batches
            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch,
                                    batch_idx, dataloader_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.trainer.logger_connector.log_train_step_metrics(batch_output)

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val_fx(
                batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.run_evaluation(test_mode=False)
                # reset stage to train
                self.trainer.logger_connector.set_stage("train")

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_on_train_batch_end()

            # update LR schedulers
            # the schedulers receive a deep copy of the callback metrics
            monitor_metrics = deepcopy(
                self.trainer.logger_connector.callback_metrics)
            self.update_train_loop_lr_schedulers(
                monitor_metrics=monitor_metrics)
            self.trainer.checkpoint_connector.has_trained = True

            # max steps reached, end training
            # (global_step is incremented at the end of the iteration, hence "+ 1")
            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
                accumulation_done = self._accumulated_batches_reached()
                # Ensure accumulation across batches has completed before breaking loop
                if accumulation_done:
                    break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if self.trainer.should_stop:
                break

            # NOTE: not reached when one of the breaks above fires
            self.trainer.total_batch_idx += 1

            # stop epoch if we limited the number of training batches
            if (batch_idx + 1) >= self.trainer.num_training_batches:
                break

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

        # epoch end hook
        self.run_on_epoch_end_hook(epoch_output)

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator,
            self.early_stopping_accumulator, self.num_optimizers)

        # when no val loop is present or fast-dev-run still need to call checkpoints
        # (skipped when validation ran on the last processed batch or the model
        # overrides validation_step)
        self.check_checkpoint_callback(not (
            should_check_val or is_overridden('validation_step', model)))

        # increment the global step once
        # progress global step according to grads progress
        # (every break above skips the in-loop increment)
        self.increment_accumulated_grad_global_step()

    def run_training_batch(self, batch, batch_idx, dataloader_idx):
        """Run the training step(s) for one batch over all splits and optimizers.

        Args:
            batch: The batch to train on; ``None`` short-circuits with
                ``signal=0``.
            batch_idx: Index of the batch; passed to the hooks and the
                training step.
            dataloader_idx: Passed to the ``on_train_batch_start`` hook.

        Returns:
            An ``AttributeDict`` with ``signal`` (``0``, or ``-1`` when one of
            the batch-start hooks returned ``-1``) and ``grad_norm_dic``. On
            the regular path it also carries
            ``training_step_output_for_epoch_end``; the early returns do not.
        """
        # track grad norms
        grad_norm_dic = {}

        # bookkeeping
        # NOTE(review): using_results_obj is assigned but never read in this method
        using_results_obj = False
        self.trainer.hiddens = None

        # track all outputs across time and num of optimizers
        batch_outputs = [[]
                         for _ in range(len(self.get_optimizers_iterable()))]

        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook("on_batch_start")
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch,
                                          batch_idx, dataloader_idx)
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # lightning module hook
        splits = self.tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):

            # create an iterable for optimizers and loop over them
            for opt_idx, optimizer in self.prepare_optimizers():

                # toggle model params + set info to logger_connector
                self.run_train_split_start(split_idx, split_batch, opt_idx,
                                           optimizer)

                if self.should_accumulate():
                    # For gradient accumulation

                    # -------------------
                    # calculate loss (train step + train step end)
                    # -------------------

                    # automatic_optimization=True: perform dpp sync only when performing optimizer_step
                    # automatic_optimization=False: don't block synchronization here
                    with self.block_ddp_sync_behaviour():
                        self.training_step_and_backward(
                            split_batch, batch_idx, opt_idx, optimizer,
                            self.trainer.hiddens)

                    batch_outputs = self._process_closure_result(
                        batch_outputs=batch_outputs,
                        opt_idx=opt_idx,
                    )

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients

                else:
                    if self.automatic_optimization:

                        # closure handed to the optimizer step; it returns the
                        # step's loss, or None when the step produced no result
                        def train_step_and_backward_closure():
                            result = self.training_step_and_backward(
                                split_batch, batch_idx, opt_idx, optimizer,
                                self.trainer.hiddens)
                            return None if result is None else result.loss

                        # optimizer step
                        self.optimizer_step(optimizer, opt_idx, batch_idx,
                                            train_step_and_backward_closure)

                    else:
                        # manual optimization: only the training step is run here
                        self._curr_step_result = self.training_step(
                            split_batch, batch_idx, opt_idx,
                            self.trainer.hiddens)

                    # NOTE(review): in the automatic branch _curr_step_result is
                    # presumably set inside training_step_and_backward - confirm
                    if self._curr_step_result is None:
                        # user decided to skip optimization
                        # make sure to zero grad.
                        continue

                    batch_outputs = self._process_closure_result(
                        batch_outputs=batch_outputs,
                        opt_idx=opt_idx,
                    )

                    # todo: Properly aggregate grad_norm across opt_idx and split_idx
                    # for now the value from the most recent optimizer/split wins
                    grad_norm_dic = self._cur_grad_norm_dict
                    self._cur_grad_norm_dict = None

                    # update running loss + reset accumulated loss
                    self.update_running_loss()

        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            training_step_output_for_epoch_end=batch_outputs,
        )
        return result

    @contextmanager
    def block_ddp_sync_behaviour(self):
        """
        Context manager wrapped around the forward/backward call while gradients accumulate.

        automatic_optimization = True (and an accelerator backend is set)
        yields whatever ``accelerator_backend.block_ddp_plugin_sync_behaviour()`` returns.
        The intent is to skip ddp gradient sync while accumulating, reducing communication overhead.

        automatic_optimization = False (or no accelerator backend)
        yields ``None``: gradients are needed within the training step, so sync is not blocked.

        NOTE(review): the backend object is only yielded here, it is never entered with ``with``.
        Confirm that calling the backend method is by itself enough to switch the sync off.

        Returns: context manager with sync behaviour off

        """
        backend = self.trainer.accelerator_backend
        if backend is None or not self.automatic_optimization:
            yield None
            return

        yield backend.block_ddp_plugin_sync_behaviour()

    def _process_closure_result(self, batch_outputs: list,
                                opt_idx: int) -> list:
        opt_closure_result = self._curr_step_result

        if opt_closure_result is not None:

            # cache metrics
            self.trainer.logger_connector.cache_training_step_metrics(
                opt_closure_result)

            # track hiddens
            self.trainer.hiddens = self.process_hiddens(opt_closure_result)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
                self.trainer.detect_nan_tensors(opt_closure_result.loss)

            # track all the outputs across all steps
            batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
            batch_outputs[batch_opt_idx].append(
                opt_closure_result.training_step_output_for_epoch_end)

            if self.automatic_optimization:
                # track total loss for logging (avoid mem leaks)
                self.accumulated_loss.append(opt_closure_result.loss)

        self._curr_step_result = None

        return batch_outputs

    def training_step_and_backward(self, split_batch, batch_idx, opt_idx,
                                   optimizer, hiddens):
        """
        Run ``training_step`` and, under automatic optimization, the backward pass.

        Bundled into one call so it can be wrapped in a closure (second order methods).
        The step result is also stored on ``self._curr_step_result``.

        Returns: the ``training_step`` result, or ``None`` when the step returned ``None``
        """
        trainer = self.trainer

        with trainer.profiler.profile("training_step_and_backward"):
            # lightning module hook
            step_output = self.training_step(split_batch, batch_idx, opt_idx,
                                             hiddens)
            self._curr_step_result = step_output

            if step_output is None:
                self.warning_cache.warn(
                    "training_step returned None if it was on purpose, ignore this warning..."
                )
                return None

            if trainer.train_loop.automatic_optimization:
                # backward pass
                with trainer.profiler.profile("model_backward"):
                    self.backward(step_output, optimizer, opt_idx)

                # the hook only fires once the gradients have finished accumulating
                accumulating = self.should_accumulate()
                if not accumulating:
                    self.on_after_backward(step_output.training_step_output,
                                           batch_idx, step_output.loss)

                # check if loss or model weights are nan
                if trainer.terminate_on_nan:
                    trainer.detect_nan_tensors(step_output.loss)

        return step_output

    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
        """
        Delegate the backward pass to the accelerator backend.

        ``result`` is either a raw ``torch.Tensor`` (backward called manually in the training loop)
        or an object carrying ``closure_loss``; in the latter case ``closure_loss`` is replaced
        by whatever the backend returns. Extra ``*args`` / ``**kwargs`` are forwarded as-is.
        """
        self.trainer.dev_debugger.track_event("backward_call")

        run_backward = self.trainer.accelerator_backend.backward

        if not isinstance(result, torch.Tensor):
            result.closure_loss = run_backward(result.closure_loss, optimizer,
                                               opt_idx, *args, **kwargs)
        else:
            # backward can be called manually in the training loop
            run_backward(result, optimizer, opt_idx, *args, **kwargs)

        if self.should_accumulate():
            return

        # track gradients
        self.track_and_norm_grad(optimizer=optimizer)

    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        """
        Update the "step"-interval learning rates.

        Only acts when the accumulation window is complete or the last training batch was reached;
        ``monitor_metrics`` is passed through to the optimizer connector.
        """
        window_complete = self._accumulated_batches_reached()
        last_batch_reached = self._num_training_batches_reached()

        if not (window_complete or last_batch_reached):
            return

        # update lr
        self.trainer.optimizer_connector.update_learning_rates(
            interval="step", monitor_metrics=monitor_metrics)

    def run_on_epoch_end_hook(self, epoch_output):
        """
        Fire the end-of-epoch notifications, in order: the logger connector first,
        then the ``on_epoch_end`` hook, then ``on_train_epoch_end`` with ``epoch_output``.
        """
        trainer = self.trainer

        # inform logger the batch loop has finished
        trainer.logger_connector.on_train_epoch_end()

        trainer.call_hook('on_epoch_end')
        trainer.call_hook('on_train_epoch_end', epoch_output)

    def increment_accumulated_grad_global_step(self):
        """
        Advance ``trainer.global_step`` by one when the accumulation window is complete
        or the last training batch was reached; otherwise leave it untouched.
        """
        window_complete = self._accumulated_batches_reached()
        last_batch_reached = self._num_training_batches_reached()

        if not (window_complete or last_batch_reached):
            return

        # progress global step according to grads progress
        self.trainer.global_step += 1

    def _accumulated_batches_reached(self):
        return (self.trainer.batch_idx +
                1) % self.trainer.accumulate_grad_batches == 0

    def _num_training_batches_reached(self):
        return (self.trainer.batch_idx +
                1) == self.trainer.num_training_batches

    def should_accumulate(self):
        """
        Tell whether the current batch only accumulates gradients (backward only)
        instead of backward + optimizer step (via closure).

        Returns: ``False`` once the accumulation window is complete or the last
        training batch was reached, ``True`` otherwise
        """
        window_complete = self._accumulated_batches_reached()
        last_batch_reached = self._num_training_batches_reached()
        return not window_complete and not last_batch_reached

    def should_check_val_fx(self, batch_idx, is_last_batch):
        """
        Decide if validation should run after the batch at ``batch_idx``.

        Validation must be enabled and the epoch must fall on ``check_val_every_n_epoch``;
        on top of that one of these has to hold:
        the batch falls on ``val_check_batch``, the trainer asked to stop, or this is the
        last batch while ``val_check_batch`` is ``inf``.
        """
        trainer = self.trainer

        batch_interval_hit = (batch_idx + 1) % trainer.val_check_batch == 0
        epoch_interval_hit = (
            trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0

        validation_allowed = trainer.enable_validation and epoch_interval_hit
        batch_hit_or_stopping = batch_interval_hit or trainer.should_stop
        last_batch_of_unbounded_data = (
            is_last_batch and trainer.val_check_batch == float("inf"))

        return validation_allowed and (
            batch_hit_or_stopping or last_batch_of_unbounded_data)

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        """
        Assemble the positional arguments for ``training_step``.

        Always ``[batch, batch_idx]``; ``opt_idx`` is added when more than one optimizer is
        configured, ``hiddens`` when ``truncated_bptt_steps`` is set (tbptt).

        Raises:
            ValueError: several optimizers are configured but ``training_step``
                has no ``optimizer_idx`` argument
        """
        step_args = [batch, batch_idx]

        # enable not needing to add opt_idx to training_step
        optimizer_count = len(self.trainer.optimizers)
        if optimizer_count > 1:
            if not self.trainer.has_arg("training_step", "optimizer_idx"):
                raise ValueError(
                    f"Your LightningModule defines {optimizer_count} optimizers but "
                    f'training_step is missing the "optimizer_idx" argument.')
            step_args.append(opt_idx)

        # pass hiddens if using tbptt
        if self.trainer.truncated_bptt_steps is not None:
            step_args.append(hiddens)

        return step_args

    def save_loggers_on_train_batch_end(self):
        """
        Save the logger to disk when a flush is due (or ``fast_dev_run`` is exactly ``True``).

        Only happens on the global-zero process and only if a logger is configured.
        """
        trainer = self.trainer

        # when loggers should save to disk
        flush_due = trainer.logger_connector.should_flush_logs
        if not (flush_due or trainer.fast_dev_run is True):
            return

        if trainer.is_global_zero and trainer.logger is not None:
            trainer.logger.save()

    def process_train_step_outputs(self, all_train_step_outputs,
                                   early_stopping_accumulator,
                                   checkpoint_accumulator):
        """
        Figure out what needs to be tracked/logged at the end of the epoch

        ``all_train_step_outputs`` holds one list per optimizer; each list holds the outputs of
        the time steps (1 item per batch without TBPTT, n items - 1 per time step - with TBPTT).
        The LAST item of every non-empty list is used as the representative sample.

        Returns: the per-optimizer lists that must be kept for epoch end
        """
        kept_outputs = []

        for per_optimizer_outputs in all_train_step_outputs:
            if len(per_optimizer_outputs) == 0:
                continue

            latest = per_optimizer_outputs[-1]

            # pull out callback info if available
            if isinstance(latest, dict):
                if "early_stop_on" in latest:
                    early_stopping_accumulator.accumulate(
                        latest["early_stop_on"])
                if "checkpoint_on" in latest:
                    checkpoint_accumulator.accumulate(
                        latest["checkpoint_on"])

            # decide if we need to reduce at the end of the epoch automatically
            needs_auto_reduce = isinstance(
                latest, Result) and latest.should_reduce_on_epoch_end

            # the user wants to manually reduce on epoch end
            user_reduces = is_overridden("training_epoch_end",
                                         model=self.trainer.get_model())

            # only track when one of the two applies
            if user_reduces or needs_auto_reduce:
                kept_outputs.append(per_optimizer_outputs)

        return kept_outputs

    def prepare_optimizers(self):
        """
        Return the optimizers to loop over for the current batch.

        Automatic optimization: everything ``get_optimizers_iterable()`` returns.
        Manual optimization: a one-element list holding only its first entry.
        """
        all_optimizers = self.get_optimizers_iterable()
        if self.automatic_optimization:
            return all_optimizers
        return [all_optimizers[0]]

    def run_train_split_start(self, split_idx, split_batch, opt_idx,
                              optimizer):
        """
        Bookkeeping done before a batch split is processed for one optimizer:
        record ``split_idx`` on the trainer, toggle the optimizer on the model when needed
        and notify the logger connector.
        """
        trainer = self.trainer

        # set split_idx to trainer for tracking
        trainer.split_idx = split_idx

        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        needs_toggle = self.automatic_optimization and len(
            trainer.optimizers) > 1
        if needs_toggle:
            trainer.get_model().toggle_optimizer(optimizer, opt_idx)

        # use to track metrics internally
        trainer.logger_connector.on_train_split_start(split_idx, opt_idx,
                                                      split_batch)

    def update_running_loss(self):
        """
        Push the mean of the accumulated loss onto ``running_loss`` and reset the accumulator.

        The mean is multiplied by ``trainer.accumulate_grad_batches`` before it is appended.
        Nothing is appended when the accumulator's mean is ``None``;
        the accumulator is reset either way.
        """
        # computed once and reused below instead of calling mean() a second time
        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display
            self.running_loss.append(accumulated_loss *
                                     self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()
Пример #8
0
def get_a_var(obj):  # pragma: no-cover
    """
    Return the first ``torch.Tensor`` found in ``obj``, searching depth-first.

    ``obj`` may be a tensor itself, a list/tuple, or a dict (whose ``(key, value)`` pairs
    are searched); anything else - or a container without tensors - gives ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        # each item is a (key, value) tuple, handled by the list/tuple branch on recursion
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None


# Module-level WarningCache instance.
# NOTE(review): presumably shared by the code below to avoid repeating identical warnings - confirm.
warning_cache = WarningCache()


class LightningDataParallel(DataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
Пример #9
0
class EvaluationLoop(object):
    """
    Helper that drives the validation / test loop on behalf of a trainer.

    ``self.testing`` selects between the ``test`` and the ``validation`` flavour of every
    hook and dataloader used below; the other attributes are bookkeeping for the run in progress.
    """

    def __init__(self, trainer):
        self.trainer = trainer
        # True -> test hooks/dataloaders are used, False -> validation ones
        self.testing = False
        # read as outputs[dataloader][step] by the epoch-end helpers
        self.outputs = []
        self.step_metrics = []
        # replaced by a PredictionCollection in setup()
        self.predictions = None
        self.max_batches = None
        self.warning_cache = WarningCache()
        # filled in by setup()
        self.num_dataloaders = None

    def on_trainer_init(self):
        """Initialise the evaluation related attributes on the trainer."""
        self.trainer.num_val_batches = []
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None
        self.trainer.running_sanity_check = False
        self.trainer.testing = False

        # when .test() is called, it sets this
        self.trainer.tested_ckpt_path = None

        # when true, prints test results
        self.trainer.verbose_test = True

    def get_evaluation_dataloaders(self, max_batches):
        """
        Pick the test or validation dataloaders, (re)loading them when needed.

        Returns: ``(dataloaders, max_batches)``; when ``max_batches`` is ``None`` it is replaced
        by the trainer's batch counts for the selected stage
        """
        model = self.trainer.get_model()

        # select dataloaders
        if self.testing:
            self.trainer.reset_test_dataloader(model)

            dataloaders = self.trainer.test_dataloaders
            new_max_batches = self.trainer.num_test_batches
        else:
            # val
            in_sanity_check = self.trainer.running_sanity_check
            should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
            # never reload during the sanity check
            if (self.trainer.val_dataloaders is None
                    or should_reload_every_epoch) and not in_sanity_check:
                self.trainer.reset_val_dataloader(model)

            dataloaders = self.trainer.val_dataloaders
            new_max_batches = self.trainer.num_val_batches

        if max_batches is None:
            max_batches = new_max_batches

        return dataloaders, max_batches

    def should_skip_evaluation(self, dataloaders, max_batches):
        """
        Returns: ``True`` when there are no dataloaders or the batch limits sum to zero
        """
        # skip when dataloaders aren't defined
        if dataloaders is None:
            return True

        # enable disabling validation step with limit_val_batches = 0
        return sum(max_batches) == 0

    def on_evaluation_start(self, *args, **kwargs):
        """Call ``on_test_start`` or ``on_validation_start`` through the trainer."""
        if self.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_model_eval(self, *args, **kwargs):
        """Call ``on_test_model_eval`` or ``on_validation_model_eval`` on the model (args are ignored)."""
        model_ref = self.trainer.get_model()
        if self.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self, *args, **kwargs):
        """Call ``on_test_model_train`` or ``on_validation_model_train`` on the model (args are ignored)."""
        model_ref = self.trainer.get_model()
        if self.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args, **kwargs):
        """Call ``on_test_end`` or ``on_validation_end``, then switch the logger stage to ``"train"``."""
        if self.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

        # reset stage to train
        self.trainer.logger_connector.set_stage("train")

    def reload_evaluation_dataloaders(self):
        """Reset the test or validation dataloader on the trainer."""
        model = self.trainer.get_model()
        if self.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def is_using_eval_results(self):
        """
        Returns: ``True`` when the first stored output of the first dataloader is an ``EvalResult``
        """
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(
            outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        """Reset the per-run bookkeeping (outputs, predictions, batch limits, dataloader count)."""
        # copy properties for forward overrides
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank,
                                                self.trainer.world_size)

        # convert max_batches to list: an int limit is repeated once per dataloader
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)

    def on_evaluation_epoch_start(self, *args, **kwargs):
        """Call ``on_test_epoch_start`` or ``on_validation_epoch_start`` through the trainer."""
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args,
                                   **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        """
        Assemble the positional arguments for the evaluation step.

        Returns: ``[batch, batch_idx]``, plus ``dataloader_idx`` when the stage selected by
        ``test_mode`` has more than one dataloader
        """
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (
            not test_mode
            and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (
            test_mode
            and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders):
        """
        Returns: ``len(dataloaders)``, or the length of its first element when that
        element is itself a list/tuple
        """
        # case where user does:
        # return dl1, dl2
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        """
        Run one test / validation step through the accelerator backend.

        Note that the step is chosen by ``self.testing``; ``test_mode`` only affects the arguments.

        Raises:
            MisconfigurationException: the step returned a ``Result`` that is not an ``EvalResult``
        """
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(batch)

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        """Call ``test_step_end`` or ``validation_step_end`` and return the hook's output."""
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args,
                                            **kwargs)
        return output

    def evaluation_epoch_end(self, num_dataloaders):
        """
        Run the model's epoch-end logic.

        Returns: ``(deprecated_results, epoch_logs)`` where ``epoch_logs`` is the model's
        ``_results`` and every entry of ``deprecated_results`` that is not a dict, ``Result``
        or tensor has been replaced by ``[]``
        """
        using_eval_result = self.is_using_eval_results()

        # call the model epoch end
        deprecated_results = self.__run_eval_epoch_end(num_dataloaders,
                                                       using_eval_result)

        # 1.0
        epoch_logs = self.trainer.get_model()._results

        # enable returning anything
        for i, r in enumerate(deprecated_results):
            if not isinstance(r, (dict, Result, torch.Tensor)):
                deprecated_results[i] = []

        return deprecated_results, epoch_logs

    def log_epoch_metrics(self, deprecated_eval_results, epoch_logs,
                          test_mode):
        """Hand the epoch results to the logger connector and return what it produces."""
        using_eval_result = self.is_using_eval_results()
        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
            deprecated_eval_results, epoch_logs, using_eval_result, test_mode)
        return eval_loop_results

    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        """
        Call the user's ``test_epoch_end`` / ``validation_epoch_end`` when overridden,
        otherwise auto-reduce ``EvalResult`` outputs.

        Returns: the results, always wrapped in a list
        """
        model = self.trainer.get_model()

        # reset results
        model._results = Result()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)

                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)

                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # depre warning
        if eval_results is not None and user_reduced:
            # name the hook exactly as it is looked up and called above
            step = 'test_epoch_end' if self.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        """
        Gather the step results of every dataloader into one result per dataloader.

        ``checkpoint_on`` / ``early_stop_on`` are replaced by their mean when present.

        Returns: a list of results, or the single result itself when there is only one
        """
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()

            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        """
        Reduce the step results of every dataloader with ``reduce_on_epoch_end``.

        ``checkpoint_on`` / ``early_stop_on`` are replaced by their mean when present.

        Returns: a list with one reduced result per dataloader
        """
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        return eval_results

    def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
        """Reset the model's results, notify the logger connector and call the batch-start hook."""
        # reset the result of the PL module
        model = self.trainer.get_model()
        model._results = Result()
        model._current_fx_name = 'evaluation_step'

        # set dataloader_idx and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(
            self.testing, batch, dataloader_idx, self.num_dataloaders)

        if self.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx,
                                   dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch,
                                   batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, *args, **kwargs):
        """Call ``on_test_batch_end`` or ``on_validation_batch_end`` through the trainer."""
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        """Collect test predictions from ``output`` and track the eval loss history."""
        # Add step predictions to prediction collection to write later
        if output is not None:
            # predictions are only collected from Result objects while testing
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(
            self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        """Call ``on_test_epoch_end`` or ``on_validation_epoch_end`` through the trainer."""
        # call the callback hook
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_evaluation_step_metrics(self, batch, batch_idx):
        """
        Log the step metrics stored in the model's ``_results``.

        Returns: the results, or ``None`` when they hold exactly one entry
        """
        results = self.trainer.get_model()._results
        # NOTE(review): a length of 1 presumably means nothing was logged by the user - confirm
        if len(results) == 1:
            return None

        results.track_batch_size(batch)
        self.__log_result_step_metrics(results, batch_idx)

        return results

    # TODO: deprecate at 1.0
    def log_evaluation_step_metrics_legacy(self, output, batch_idx):
        """Log the step metrics of an ``EvalResult`` output; does nothing during the sanity check."""
        if self.trainer.running_sanity_check:
            return

        if isinstance(output, EvalResult):
            self.__log_result_step_metrics(output, batch_idx)

    def __log_result_step_metrics(self, output, batch_idx):
        """Send the batch log metrics to the logger and the batch pbar metrics to the progress bar."""
        step_log_metrics = output.get_batch_log_metrics(
            include_forked_originals=False)
        step_pbar_metrics = output.get_batch_pbar_metrics(
            include_forked_originals=False)

        # NOTE(review): the value is never used below; the call is kept in case it has
        # side effects on the cached results - confirm and drop it if it has none
        cached_batch_log_metrics = \
            self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()

        if len(step_log_metrics) > 0:
            # make the metrics appear as a different line in the same graph
            metrics_by_epoch = {}
            for k, v in step_log_metrics.items():
                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {},
                                                      step=batch_idx)

        if len(step_pbar_metrics) > 0:
            self.trainer.logger_connector.add_progress_bar_metrics(
                step_pbar_metrics)
Пример #10
0
class TrainLoop:
    def __init__(self, trainer):
        """Create the training-loop helper bound to ``trainer``.

        Args:
            trainer: the owning trainer; stored as ``self.trainer``.
        """
        self.trainer = trainer

        # one-shot guard consulted by ``on_train_end``
        self._teardown_already_run = False
        self.warning_cache = WarningCache()

        # per-epoch state; (re)created in ``on_train_epoch_start``
        self.accumulated_loss = None
        self.early_stopping_accumulator = None
        self.checkpoint_accumulator = None

        # running loss accumulator with a window length of 20
        self.running_loss = TensorRunningAccum(window_length=20)

    def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps,
                        num_sanity_val_steps):
        """Seed the trainer with its initial loop state and epoch/step limits.

        Args:
            max_epochs: stored as ``trainer.max_epochs``.
            min_epochs: stored as ``trainer.min_epochs``.
            max_steps: stored as ``trainer.max_steps``.
            min_steps: stored as ``trainer.min_steps``.
            num_sanity_val_steps: stored as ``trainer.num_sanity_val_steps``;
                the value ``-1`` is translated to ``float('inf')``.
        """
        trainer = self.trainer

        # progress counters and status flags
        trainer.global_step = 0
        trainer.current_epoch = 0
        trainer.interrupted = False
        trainer.should_stop = False
        trainer._state = TrainerState.INITIALIZING

        # batch bookkeeping
        trainer.total_batch_idx = 0
        trainer.batch_idx = 0
        trainer.num_training_batches = 0
        trainer.train_dataloader = None

        # user supplied limits
        trainer.max_epochs = max_epochs
        trainer.min_epochs = min_epochs
        trainer.max_steps = max_steps
        trainer.min_steps = min_steps

        # -1 means "no limit"
        trainer.num_sanity_val_steps = (
            float('inf') if num_sanity_val_steps == -1 else num_sanity_val_steps)

    @property
    def num_optimizers(self):
        num_optimizers = len(self.get_optimizers_iterable())
        return num_optimizers

    def on_train_start(self):
        # clear cache before training
        if self.trainer.on_gpu and self.trainer.root_gpu is not None:
            # use context because of:
            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
            with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'):
                torch.cuda.empty_cache()

        # hook
        self.trainer.call_hook('on_train_start')

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # bind logger and other properties
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.trainer.data_connector.attach_data(model, train_dataloader,
                                                val_dataloaders, datamodule)

        # check that model is configured correctly
        self.trainer.config_validator.verify_loop_configurations(model)

    def setup_training(self, model: LightningModule):
        """Prepare trainer, accelerator backend and model right before training starts.

        The steps run in a fixed order: bind the underlying module to the
        trainer, create the AMP grad scaler when required, log hyper-parameters
        and the graph, synchronise through the accelerator barrier, register
        SLURM signal handlers, fire the ``on_pretrain_routine_start`` hooks,
        print the model summary, restore weights and finally fire the
        ``on_pretrain_routine_end`` hooks.

        Args:
            model: The model handed to the trainer. When ``trainer.data_parallel``
                is set, the wrapped ``model.module`` is used for hooks and
                properties, while ``model`` itself is stored on the trainer.

        Raises:
            MisconfigurationException: if ``trainer.weights_summary`` is set to
                a value that is not in ``ModelSummary.MODES``.
        """
        # --------------------------
        # Setup
        # --------------------------
        # unwrap the data-parallel wrapper to reach the underlying module
        ref_model = model
        if self.trainer.data_parallel:
            ref_model = model.module

        # tell the accelerator's dist helper which rank/device it runs on
        self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank
        self.trainer.accelerator_backend.dist.device = ref_model.device

        # give model convenience properties
        ref_model.trainer = self.trainer

        # set local properties on the model
        self.trainer.model_connector.copy_trainer_model_properties(ref_model)

        # init amp. Must be done here instead of __init__ to allow ddp to work
        if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu:
            self.trainer.scaler = torch.cuda.amp.GradScaler()

        # log hyper-parameters
        if self.trainer.logger is not None:
            # save exp to get started
            self.trainer.logger.log_hyperparams(ref_model.hparams)
            self.trainer.logger.log_graph(ref_model)
            self.trainer.logger.save()

        # wait for all to join if on distributed
        self.trainer.accelerator_backend.barrier('setup_training')

        # register auto-resubmit when on SLURM
        self.trainer.slurm_connector.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start: trainer-level hook first, then the
        # model's own hook if the trainer reports it as implemented
        self.trainer.on_pretrain_routine_start(ref_model)
        if self.trainer.is_function_implemented('on_pretrain_routine_start'):
            ref_model.on_pretrain_routine_start()

        # print model summary (global-zero process only, never while testing)
        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
            if self.trainer.weights_summary in ModelSummary.MODES:
                ref_model.summarize(mode=self.trainer.weights_summary)
            else:
                # the message lists the accepted values: None plus every known mode
                raise MisconfigurationException(
                    "weights_summary can be None, " +
                    ", ".join(ModelSummary.MODES))

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        # (note: the possibly wrapped ``model`` is stored, not ``ref_model``)
        self.trainer.model = model

        # restore training and model before hpc is called
        self.trainer.checkpoint_connector.restore_weights(model)

        # on pretrain routine end: same trainer-then-model order as above
        self.trainer.on_pretrain_routine_end(ref_model)
        if self.trainer.is_function_implemented('on_pretrain_routine_end'):
            ref_model.on_pretrain_routine_end()

    def on_train_end(self):
        if self._teardown_already_run:
            return

        self._teardown_already_run = True

        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
        # when a checkpoint was saved at the last step
        self.trainer.global_step -= 1
        self.check_checkpoint_callback(should_save=True, is_last=True)
        self.trainer.global_step += 1

        # hook
        self.trainer.call_hook('on_train_end')

        # kill loggers
        if self.trainer.logger is not None:
            self.trainer.logger.finalize("success")

        # summarize profile results
        if self.trainer.global_rank == 0:
            self.trainer.profiler.describe()

        # give accelerators a chance to finish
        self.trainer.accelerator_backend.on_train_end()

        # clear mem
        if self.trainer.on_gpu:
            model = self.trainer.get_model()
            model.cpu()
            torch.cuda.empty_cache()

    def check_checkpoint_callback(self, should_save, is_last=False):
        # TODO bake this logic into the checkpoint callback
        if should_save:
            checkpoint_callbacks = [
                c for c in self.trainer.callbacks
                if isinstance(c, ModelCheckpoint)
            ]
            if is_last and any(c.save_last for c in checkpoint_callbacks):
                rank_zero_info('Saving latest checkpoint...')
            model = self.trainer.get_model()
            [
                c.on_validation_end(self.trainer, model)
                for c in checkpoint_callbacks
            ]

    def on_train_epoch_start(self, epoch):
        """Prepare trainer and loop state for a new training epoch.

        Args:
            epoch: index of the epoch that is about to start; stored as
                ``trainer.current_epoch``.
        """
        model = self.trainer.get_model()

        # set seed for distributed sampler (enables shuffling for each epoch).
        # Deliberately best-effort: presumably the dataloader may be missing or
        # its sampler may not provide ``set_epoch``, so any failure is ignored.
        try:
            self.trainer.train_dataloader.sampler.set_epoch(epoch)
        except Exception:
            pass

        # update training progress in trainer
        self.trainer.current_epoch = epoch

        # changing gradient according accumulation_scheduler
        # (reuses the model fetched above instead of looking it up again)
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, model)

        # stores accumulated grad fractions per batch
        self.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches)

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch,
                           batch_idx, dataloader_idx):
        # figure out what to track for epoch end
        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)

        # hook
        self.trainer.call_hook('on_batch_end')
        self.trainer.call_hook('on_train_batch_end', batch, batch_idx,
                               dataloader_idx)

    def reset_train_val_dataloaders(self, model):
        if not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_val_dataloader(model)

    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs,
                          list) and len(opt_outputs) == 1 and not isinstance(
                              opt_outputs[0], Result):
                opt_outputs = opt_outputs[0]
            epoch_output[opt_idx].append(opt_outputs)

    def get_optimizers_iterable(self):
        """
        Generates an iterable with (idx, optimizer) for each optimizer.
        """
        if not self.trainer.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.trainer.optimizers))

        optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        optimizers_loop_length = optimizer_freq_cumsum[-1]
        current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length

        # find optimzier index by looking for the first {item > current_place} in the cumsum list
        opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
        return [(opt_idx, self.trainer.optimizers[opt_idx])]

    def backward(self, result, optimizer, opt_idx):
        # backward pass
        with self.trainer.profiler.profile('model_backward'):
            result.closure_loss = self.trainer.accelerator_backend.backward(
                result.closure_loss, optimizer, opt_idx)

    def on_after_backward(self, training_step_output, batch_idx,
                          untouched_loss):
        """Detach the step output, fire ``on_after_backward`` and record the loss.

        Args:
            training_step_output: a ``Result`` (detached as a whole) or an
                object with a ``batch_loss`` attribute (only that is detached).
            batch_idx: index of the current batch.
            untouched_loss: loss handed to the dev debugger after detaching.
        """
        if isinstance(training_step_output, Result):
            training_step_output.detach()
        else:
            detached_loss = training_step_output.batch_loss.detach()
            training_step_output.batch_loss = detached_loss

        # insert after step hook
        self.trainer.call_hook('on_after_backward')

        # when in dev debugging track the losses
        loss_for_history = untouched_loss.detach()
        self.trainer.dev_debugger.track_train_loss_history(
            batch_idx, loss_for_history)

    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        """Run ``training_step`` and ``training_step_end`` for one batch split.

        Args:
            split_batch: the (possibly tbptt-split) batch to train on.
            batch_idx: index of the current batch.
            opt_idx: index of the optimizer in use.
            hiddens: hidden state forwarded to ``build_train_args``.

        Returns:
            ``None`` when the step produced no output; otherwise an
            ``AttributeDict`` with ``closure_loss`` (loss divided by
            ``accumulate_grad_batches``), ``loss`` (a detached clone of it),
            ``training_step_output``, ``training_step_output_for_epoch_end``
            and ``hiddens``.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.get_model()
        model_ref._results = Result()
        model_ref._current_fx_name = 'training_step'

        with self.trainer.profiler.profile('model_forward'):
            step_args = self.build_train_args(split_batch, batch_idx, opt_idx,
                                              hiddens)
            raw_output = self.trainer.accelerator_backend.training_step(
                step_args)
            raw_output = self.trainer.call_hook('training_step_end',
                                                raw_output)

            epoch_end_output, step_output = self._process_training_step_output(
                raw_output, split_batch)

            if epoch_end_output is None:
                return None

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        if isinstance(step_output, Result):
            scaled_loss = step_output.minimize
        else:
            scaled_loss = step_output.batch_loss
        scaled_loss = scaled_loss / self.trainer.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        loss_snapshot = scaled_loss.detach().clone()

        return AttributeDict(
            closure_loss=scaled_loss,
            loss=loss_snapshot,
            training_step_output=step_output,
            training_step_output_for_epoch_end=epoch_end_output,
            hiddens=step_output.hiddens,
        )

    def _process_training_step_output(self, training_step_output, split_batch):
        training_step_output_for_epoch_end = training_step_output

        # enable validation_step return None
        if training_step_output_for_epoch_end is None:
            return None, None

        # -----------------------------------------
        # process result return (DEPRECATE in 1.0)
        # -----------------------------------------
        if isinstance(training_step_output, Result):
            training_step_output_for_epoch_end = self._process_result(
                training_step_output, split_batch)
            return training_step_output_for_epoch_end, training_step_output

        # -----------------------------------------
        # process hybrid (1.0)
        # -----------------------------------------
        # no need for these checks in 1.0.0
        # TODO: remove checks in 1.0.0
        is_tensor = isinstance(training_step_output_for_epoch_end,
                               torch.Tensor)
        is_1_0_output = is_tensor or ('log' not in training_step_output
                                      and 'progress_bar'
                                      not in training_step_output)
        if is_1_0_output:
            return self._process_training_step_output_1_0(
                training_step_output, split_batch)

        # -----------------------------------------
        # process old dict (deprecate 1.0)
        # -----------------------------------------
        training_step_output = self.trainer.process_dict_result(
            training_step_output, train=True)

        training_step_output = AttributeDict(
            batch_loss=training_step_output[0],
            pbar_on_batch_end=training_step_output[1],
            log_metrics=training_step_output[2],
            callback_metrics=training_step_output[3],
            hiddens=training_step_output[4],
        )
        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach(
            )
        else:
            training_step_output_for_epoch_end = recursive_detach(
                training_step_output_for_epoch_end)

        return training_step_output_for_epoch_end, training_step_output

    def _process_training_step_output_1_0(self, training_step_output,
                                          split_batch):
        result = self.trainer.get_model()._results

        loss = None
        hiddens = None

        # handle dict return
        if isinstance(training_step_output, dict):
            loss = training_step_output.pop('loss', None)
            hiddens = training_step_output.pop('hiddens', None)
            result['extra'] = training_step_output

        # handle scalar return
        elif isinstance(training_step_output, torch.Tensor):
            loss = training_step_output
            result['extra'] = {}

        # map to results under the hood
        result.minimize = loss
        result.hiddens = hiddens

        # track batch for manual reduction with result
        result.track_batch_size(len(split_batch))

        # track metrics without grads for epoch reduction
        training_step_output_for_epoch_end = copy(result)
        training_step_output_for_epoch_end.detach()

        # what flows back into the system
        training_step_output = result

        return training_step_output_for_epoch_end, training_step_output

    def _process_result(self, training_step_output, split_batch):
        """Handle a deprecated ``Result`` object returned from the training step.

        Tracks the batch size, emits the deprecation warning and returns a
        detached copy of the result for epoch end.

        Raises:
            MisconfigurationException: if the training step returned an
                ``EvalResult``.
        """
        training_step_output.track_batch_size(len(split_batch))

        deprecation_notice = """
        TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0.
        Use self.log and .write from the LightningModule to log metrics and write predictions.
        training_step can now only return a scalar (for the loss) or a dictionary with anything you want.

        Option 1:
        return loss

        Option 2:
        return {'loss': loss, 'anything_else': ...}

        Option 3:
        return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...}
            """
        rank_zero_warn(deprecation_notice)

        # don't allow EvalResult in the training_step
        returned_eval_result = isinstance(training_step_output, EvalResult)
        if returned_eval_result:
            raise MisconfigurationException(
                'training_step cannot return EvalResult, '
                'use a dict or TrainResult instead')

        # shallow copy whose contents are detached for epoch end
        epoch_end_copy = copy(training_step_output)
        epoch_end_copy.detach()
        return epoch_end_copy

    def optimizer_step(self, optimizer, opt_idx, batch_idx,
                       train_step_and_backward_closure):
        with self.trainer.profiler.profile('optimizer_step'):
            # optimizer step lightningModule hook
            self.trainer.accelerator_backend.optimizer_step(
                optimizer, batch_idx, opt_idx, train_step_and_backward_closure)

    def on_before_zero_grad(self, optimizer):
        model = self.trainer.get_model()
        model.on_before_zero_grad(optimizer)

    def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
        self.trainer.accelerator_backend.optimizer_zero_grad(
            batch_idx, optimizer, opt_idx)

    def on_before_backward(self, batch_idx, optimizer):
        # track gradient norms
        grad_norm_dic = self._track_gradient_norm()

        # clip gradients
        self.trainer.accelerator_backend.clip_gradients(optimizer)
        return grad_norm_dic

    def _track_gradient_norm(self):
        grad_norm_dict = {}
        if (self.trainer.global_step + 1) % self.trainer.row_log_interval == 0:
            if float(self.trainer.track_grad_norm) > 0:
                model = self.trainer.get_model()
                grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm)
        return grad_norm_dict

    def log_training_step_metrics(self, opt_closure_result,
                                  batch_callback_metrics, batch_log_metrics):
        """Collect the metrics of one optimizer closure result.

        Args:
            opt_closure_result: result whose ``training_step_output`` holds
                the metrics.
            batch_callback_metrics: list that the callback metrics are
                appended to (mutated in place).
            batch_log_metrics: list that the metrics to log are appended to
                (mutated in place).
        """
        step_output = opt_closure_result.training_step_output

        # track callback metrics
        batch_callback_metrics.append(step_output.callback_metrics)

        # decide which metrics to log (results vs dict return)
        if isinstance(step_output, Result):
            metrics_to_log = step_output.batch_log_metrics
            pbar_metrics = step_output.batch_pbar_metrics
        else:
            metrics_to_log = step_output.log_metrics
            pbar_metrics = step_output.pbar_on_batch_end

        # track batch log metrics
        batch_log_metrics.append(metrics_to_log)

        # track progress bar metrics; they also become callback metrics
        if len(pbar_metrics) > 0:
            logger_connector = self.trainer.logger_connector
            logger_connector.add_progress_bar_metrics(pbar_metrics)
            logger_connector.callback_metrics.update(pbar_metrics)

    def process_hiddens(self, opt_closure_result):
        """Return the hiddens of a closure result.

        When the step output is a ``Result``, the hiddens are also dropped
        from the copy kept for epoch end.
        """
        # read first, then drop from the epoch-end copy
        current_hiddens = opt_closure_result.hiddens

        step_output = opt_closure_result.training_step_output
        if isinstance(step_output, Result):
            epoch_end_output = opt_closure_result.training_step_output_for_epoch_end
            epoch_end_output.drop_hiddens()

        return current_hiddens

    def tbptt_split_batch(self, batch):
        splits = [batch]
        if self.trainer.truncated_bptt_steps is not None:
            model_ref = self.trainer.get_model()
            with self.trainer.profiler.profile('tbptt_split_batch'):
                splits = model_ref.tbptt_split_batch(
                    batch, self.trainer.truncated_bptt_steps)
        return splits

    def run_training_epoch(self):
        """Run one full training epoch over the trainer's train dataloader.

        For every batch this runs the training batch, tracks its outputs for
        epoch end, logs step metrics, optionally runs validation, saves the
        loggers and updates step-interval LR schedulers. The loop stops early
        when a batch signals ``-1``, when ``max_steps`` is reached, when
        ``trainer.should_stop`` is set or when ``num_training_batches`` batches
        were processed. Afterwards the epoch-end metrics are logged, the
        checkpoint callbacks are checked and the epoch-end hooks are fired.
        """

        # get model
        model = self.trainer.get_model()

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.trainer.accelerator_backend.process_dataloader(
            self.trainer.train_dataloader)

        # track epoch output: one list per optimizer
        epoch_output = [[] for _ in range(self.num_optimizers)]

        # enable profiling for the dataloader
        # (the profiled loader yields ``(batch_idx, (batch, is_last_batch))``)
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(
            train_dataloader)
        dataloader_idx = 0
        # stays False when the loop body never reaches the validation check
        should_check_val = False
        for batch_idx, (batch, is_last_batch) in train_dataloader:
            self.trainer.batch_idx = batch_idx

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            batch_output = self.run_training_batch(batch, batch_idx,
                                                   dataloader_idx)

            # when returning -1 from train_step, we end epoch early
            if batch_output.signal == -1:
                break

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            epoch_end_outputs = self.process_train_step_outputs(
                batch_output.training_step_output_for_epoch_end,
                self.early_stopping_accumulator, self.checkpoint_accumulator)

            # hook
            # TODO: add outputs to batches
            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch,
                                    batch_idx, dataloader_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.trainer.logger_connector.log_train_step_metrics(batch_output)

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val_fx(
                batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.run_evaluation(test_mode=False)

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_on_train_batch_end()

            # update LR schedulers
            # (work on a deep copy so the batch metrics do not leak into the
            # connector's callback metrics)
            monitor_metrics = deepcopy(
                self.trainer.logger_connector.callback_metrics)
            monitor_metrics.update(batch_output.batch_log_metrics)
            self.update_train_loop_lr_schedulers(
                monitor_metrics=monitor_metrics)

            # max steps reached, end training
            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
                break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if self.trainer.should_stop:
                break

            # note: not reached on any of the early ``break`` paths above
            self.trainer.total_batch_idx += 1

            # stop epoch if we limited the number of training batches
            if batch_idx + 1 >= self.trainer.num_training_batches:
                break

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator,
            self.early_stopping_accumulator, self.num_optimizers)

        # hook
        self.trainer.logger_connector.on_train_epoch_end(epoch_output)

        # when no val loop is present or fast-dev-run still need to call checkpoints
        # (skipped when validation just ran or the model overrides validation_step)
        self.check_checkpoint_callback(not (
            should_check_val or is_overridden('validation_step', model)))

        # epoch end hook
        self.run_on_epoch_end_hook()

        # increment the global step once
        # progress global step according to grads progress
        # (covers the iteration that left the loop through ``break``)
        self.increment_accumulated_grad_global_step()

    def run_training_batch(self, batch, batch_idx, dataloader_idx):
        """Train on a single batch: every tbptt split times every active optimizer.

        Args:
            batch: the batch to train on; ``None`` short-circuits the method.
            batch_idx: index of the batch within the epoch.
            dataloader_idx: index of the dataloader, forwarded to the
                ``on_train_batch_start`` hook.

        Returns:
            An ``AttributeDict``. The early returns carry only ``signal``
            (``0`` for a ``None`` batch, ``-1`` when a batch-start hook
            returned ``-1``) and ``grad_norm_dic``. The regular return carries
            ``signal=0``, ``grad_norm_dic``, ``batch_log_metrics`` (a single
            merged dict) and ``training_step_output_for_epoch_end`` (one list
            of outputs per optimizer slot).
        """
        # track grad norms
        grad_norm_dic = {}

        # track all metrics for callbacks
        batch_callback_metrics = []

        # track metrics to log
        batch_log_metrics = []

        # bookkeeping
        using_results_obj = False
        self.trainer.hiddens = None

        # track all outputs across time and num of optimizers
        batch_outputs = [[]
                         for _ in range(len(self.get_optimizers_iterable()))]

        # NOTE(review): this early return has no ``batch_log_metrics`` or
        # ``training_step_output_for_epoch_end`` entry, yet run_training_epoch
        # reads both when signal != -1 -- confirm a None batch is handled there
        if batch is None:
            return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook('on_batch_start')
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # hook
        response = self.trainer.call_hook('on_train_batch_start', batch,
                                          batch_idx, dataloader_idx)
        if response == -1:
            return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

        # lightning module hook
        splits = self.tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):
            self.trainer.split_idx = split_idx

            # loop over optimizers
            for opt_idx, optimizer in self.get_optimizers_iterable():
                # make sure only the gradients of the current optimizer's parameters are calculated
                # in the training step to prevent dangling gradients in multiple-optimizer setup.
                if len(self.trainer.optimizers) > 1:
                    for param in self.trainer.get_model().parameters():
                        param.requires_grad = False
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.requires_grad = True

                # -------------------
                # calculate loss (train step + train step end)
                # -------------------
                opt_closure_result = self.training_step_and_backward(
                    split_batch, batch_idx, opt_idx, optimizer,
                    self.trainer.hiddens)

                # the training step returned None: skip this optimizer
                if opt_closure_result is None:
                    continue

                using_results_obj = isinstance(
                    opt_closure_result.training_step_output, Result)

                # log metrics
                self.log_training_step_metrics(opt_closure_result,
                                               batch_callback_metrics,
                                               batch_log_metrics)

                # track hiddens
                self.trainer.hiddens = self.process_hiddens(opt_closure_result)

                # check if loss or model weights are nan
                if self.trainer.terminate_on_nan:
                    self.trainer.detect_nan_tensors(opt_closure_result.loss)

                # track total loss for logging (avoid mem leaks)
                self.accumulated_loss.append(opt_closure_result.loss)

                # track all the outputs across all steps
                # (with a single output slot everything goes into slot 0)
                batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
                batch_outputs[batch_opt_idx].append(
                    opt_closure_result.training_step_output_for_epoch_end)

                # ------------------------------
                # BACKWARD PASS
                # ------------------------------
                # gradient update with accumulated gradients
                accumulation_done = (
                    self.trainer.batch_idx +
                    1) % self.trainer.accumulate_grad_batches == 0
                is_final_batch = (self.trainer.batch_idx +
                                  1) == self.trainer.num_training_batches
                if accumulation_done or is_final_batch:
                    # hook
                    grad_norm_dic = self.on_before_backward(
                        batch_idx, optimizer)

                    # wrap forward + backward pass in closure for 2nd order optimizers
                    # (the closure is handed to optimizer_step right below)
                    train_step_and_backward_closure = lambda: self.training_step_and_backward(
                        split_batch,
                        batch_idx,
                        opt_idx,
                        optimizer,
                        self.trainer.hiddens,
                    ).loss

                    # optimizer step
                    self.optimizer_step(optimizer, opt_idx, batch_idx,
                                        train_step_and_backward_closure)

                    # hook
                    self.on_before_zero_grad(optimizer)

                    # clear gradients
                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

                    # calculate running loss for display
                    # (multiplied by accumulate_grad_batches because
                    # training_step divides the loss by that factor)
                    self.running_loss.append(
                        self.accumulated_loss.mean() *
                        self.trainer.accumulate_grad_batches)

                    # reset for next set of accumulated grads
                    self.accumulated_loss.reset()

        # collapse all metrics into one dict (later entries win on key clashes)
        batch_log_metrics = {
            k: v
            for d in batch_log_metrics for k, v in d.items()
        }

        # track all metrics for callbacks
        self.trainer.logger_connector.callback_metrics.update(
            batch_log_metrics)
        # callback metrics whose value is None are left out
        self.trainer.logger_connector.callback_metrics.update({
            k: v
            for d in batch_callback_metrics for k, v in d.items()
            if v is not None
        })

        result = AttributeDict(
            signal=0,
            grad_norm_dic=grad_norm_dic,
            batch_log_metrics=batch_log_metrics,
            training_step_output_for_epoch_end=batch_outputs)
        return result

    def training_step_and_backward(self, split_batch, batch_idx, opt_idx,
                                   optimizer, hiddens):
        """
        Run the training step followed by the backward pass.

        Kept as a single callable so it can be wrapped in a closure for
        second order methods.

        Returns:
            The result of ``training_step``, or ``None`` (after a cached
            warning) when the training step returned ``None``.
        """
        # lightning module hook
        step_result = self.training_step(split_batch, batch_idx, opt_idx,
                                         hiddens)

        if step_result is not None:
            # backward pass
            self.backward(step_result, optimizer, opt_idx)

            # hook
            self.on_after_backward(step_result.training_step_output,
                                   batch_idx, step_result.loss)
            return step_result

        self.warning_cache.warn(
            'training_step returned None if it was on purpose, ignore this warning...'
        )
        return None

    def update_train_loop_lr_schedulers(self, monitor_metrics=None):
        num_accumulated_batches_reached = (
            self.trainer.batch_idx +
            1) % self.trainer.accumulate_grad_batches == 0
        num_training_batches_reached = (self.trainer.batch_idx +
                                        1) == self.trainer.num_training_batches

        if num_accumulated_batches_reached or num_training_batches_reached:
            # update lr
            self.trainer.optimizer_connector.update_learning_rates(
                interval='step', monitor_metrics=monitor_metrics)

    def run_on_epoch_end_hook(self):
        self.trainer.call_hook('on_epoch_end')
        self.trainer.call_hook('on_train_epoch_end')

    def increment_accumulated_grad_global_step(self):
        """
        Advance ``trainer.global_step`` by one when the current batch finishes a
        gradient accumulation window or is the final training batch.
        """
        trainer = self.trainer
        batches_done = trainer.batch_idx + 1
        accumulation_window_closed = (
            batches_done % trainer.accumulate_grad_batches == 0)
        all_batches_consumed = batches_done == trainer.num_training_batches

        # progress global step according to grads progress
        if accumulation_window_closed or all_batches_consumed:
            trainer.global_step += 1

    def should_check_val_fx(self, batch_idx, is_last_batch):
        """
        Decide whether validation should run after the given training batch.

        Validation is only considered when it is enabled on the trainer and the
        current epoch matches ``check_val_every_n_epoch``. Within such an epoch
        it is triggered when the batch matches ``val_check_batch``, when the
        trainer wants to stop, or on the last batch when ``val_check_batch`` is
        infinite.

        Args:
            batch_idx: index of the batch that just finished.
            is_last_batch: whether that batch was the last one of the epoch.
        """
        trainer = self.trainer

        on_val_batch = (batch_idx + 1) % trainer.val_check_batch == 0
        on_val_epoch = (
            (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0)
        validation_allowed = trainer.enable_validation and on_val_epoch

        triggered = on_val_batch or trainer.should_stop
        last_batch_of_unbounded_loader = (
            is_last_batch and trainer.val_check_batch == float('inf'))

        if not validation_allowed:
            return validation_allowed
        return triggered or last_batch_of_unbounded_loader

    def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
        """
        Assemble the positional arguments passed to ``training_step``.

        The list always starts with ``batch`` and ``batch_idx``. ``opt_idx`` is
        appended when more than one optimizer is configured, and ``hiddens`` when
        ``trainer.truncated_bptt_steps`` is set.

        Raises:
            ValueError: several optimizers are configured but ``training_step``
                has no ``optimizer_idx`` argument.
        """
        trainer = self.trainer
        step_args = [batch, batch_idx]

        # enable not needing to add opt_idx to training_step
        optimizer_count = len(trainer.optimizers)
        if optimizer_count > 1:
            if not trainer.has_arg('training_step', 'optimizer_idx'):
                raise ValueError(
                    f'Your LightningModule defines {optimizer_count} optimizers but '
                    f'training_step is missing the "optimizer_idx" argument.')
            step_args.append(opt_idx)

        # pass hiddens if using tbptt
        if trainer.truncated_bptt_steps is not None:
            step_args.append(hiddens)

        return step_args

    def save_loggers_on_train_batch_end(self):
        """
        Flush the trainer's logger to disk when a save is due.

        A save is due every ``log_save_interval`` global steps, when the trainer
        is about to stop, or in ``fast_dev_run`` mode. It is only performed on
        the global-zero process and only if a logger is configured.
        """
        trainer = self.trainer

        # when loggers should save to disk
        interval_reached = (
            (trainer.global_step + 1) % trainer.log_save_interval == 0)
        save_is_due = (interval_reached or trainer.should_stop
                       or trainer.fast_dev_run)

        if save_is_due and trainer.is_global_zero and trainer.logger is not None:
            trainer.logger.save()

    def process_train_step_outputs(self, all_train_step_outputs,
                                   early_stopping_accumulator,
                                   checkpoint_accumulator):
        """
        Select which training-step outputs must be kept for the end of the epoch.

        ``all_train_step_outputs`` holds one list per optimizer; each list holds
        the outputs produced at each time step (a single item per batch without
        TBPTT, one item per time step with TBPTT). The last item of every
        non-empty list is inspected as the representative sample.

        Args:
            all_train_step_outputs: per-optimizer lists of step outputs.
            early_stopping_accumulator: receives the sample's ``'early_stop_on'``
                value when the sample is a dict carrying that key.
            checkpoint_accumulator: receives the sample's ``'checkpoint_on'``
                value when the sample is a dict carrying that key.

        Returns:
            The per-optimizer lists that need tracking: those whose sample is a
            ``Result`` flagged ``should_reduce_on_epoch_end``, or all non-empty
            ones when ``training_epoch_end`` is overridden on the model.
        """
        tracked_outputs = []

        for per_optimizer_outputs in all_train_step_outputs:
            # nothing was produced for this optimizer
            if len(per_optimizer_outputs) == 0:
                continue

            last_output = per_optimizer_outputs[-1]

            # pull out callback info if available (ie: Results object)
            if isinstance(last_output, dict):
                if 'early_stop_on' in last_output:
                    early_stopping_accumulator.accumulate(
                        last_output['early_stop_on'])
                if 'checkpoint_on' in last_output:
                    checkpoint_accumulator.accumulate(
                        last_output['checkpoint_on'])

            # decide if we need to reduce at the end of the epoch automatically
            needs_auto_reduce = (isinstance(last_output, Result)
                                 and last_output.should_reduce_on_epoch_end)

            # keep the outputs when the user reduces manually in training_epoch_end
            # or when they have to be auto-reduced
            user_reduces = is_overridden('training_epoch_end',
                                         model=self.trainer.get_model())
            if user_reduces or needs_auto_reduce:
                tracked_outputs.append(per_optimizer_outputs)

        return tracked_outputs
Пример #11
0
class WandbLogger(LightningLoggerBase):
    r"""
    Log using `Weights and Biases <https://www.wandb.com/>`_.

    Install it with pip:

    .. code-block:: bash

        pip install wandb

    Args:
        name: Display name for the run.
        save_dir: Path where data is saved.
        offline: Run offline (data can be streamed later to wandb servers).
        id: Sets the version, mainly used to resume a previous run.
        anonymous: Enables or explicitly disables anonymous logging.
        version: Sets the version, mainly used to resume a previous run.
        project: The name of the project to which this run will belong.
        log_model: Save checkpoints in wandb dir to upload on W&B servers.
        experiment: WandB experiment object.
        prefix: A string to put at the beginning of metric keys.
        \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
            :func:`wandb.init` can be passed as keyword arguments in this logger.

    Example:

    .. code-block:: python

        from pytorch_lightning.loggers import WandbLogger
        from pytorch_lightning import Trainer
        wandb_logger = WandbLogger()
        trainer = Trainer(logger=wandb_logger)

    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
    make sure to use `commit=False` so the logging step does not increase.

    See Also:
        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
          on how to use W&B with Pytorch Lightning.

    """

    LOGGER_JOIN_CHAR = '-'

    def __init__(self,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None,
                 offline: bool = False,
                 id: Optional[str] = None,
                 anonymous: bool = False,
                 version: Optional[str] = None,
                 project: Optional[str] = None,
                 log_model: bool = False,
                 experiment=None,
                 prefix: str = '',
                 **kwargs):
        """Store the configuration; the W&B run itself is created lazily by :attr:`experiment`.

        Raises:
            ImportError: if the ``wandb`` package could not be imported.
        """
        if wandb is None:
            raise ImportError(
                'You want to use `wandb` logger which is not installed yet,'  # pragma: no-cover
                ' install it with `pip install wandb`.')
        super().__init__()
        self._name = name
        self._save_dir = save_dir
        self._anonymous = 'allow' if anonymous else None
        # `version` takes precedence over `id` when both are given
        self._id = version or id
        self._project = project
        self._experiment = experiment
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._kwargs = kwargs
        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
        self._step_offset = 0
        self.warning_cache = WarningCache()

    def __getstate__(self):
        """Return a picklable copy of the logger state.

        The live run object is dropped; only the id needed to reload the
        correct experiment is kept.
        """
        state = self.__dict__.copy()
        # args needed to reload correct experiment: use the id of the live run when
        # there is one, otherwise keep the id/version given by the user so that it is
        # not lost when the logger is pickled before the run has been created
        state['_id'] = (self._experiment.id
                        if self._experiment is not None else self._id)

        # cannot be pickled
        state['_experiment'] = None
        return state

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        r"""

        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_wandb_function()

        """
        if self._experiment is None:
            if self._offline:
                os.environ['WANDB_MODE'] = 'dryrun'
            # reuse an already active global run, otherwise start (or resume) one
            self._experiment = wandb.init(
                name=self._name,
                dir=self._save_dir,
                project=self._project,
                anonymous=self._anonymous,
                id=self._id,
                resume='allow',
                **self._kwargs) if wandb.run is None else wandb.run
            # offset logging step when resuming a run
            self._step_offset = self._experiment.step
            # save checkpoints in wandb dir to upload on W&B servers
            if self._log_model:
                self._save_dir = self._experiment.dir
        return self._experiment

    def watch(self,
              model: nn.Module,
              log: str = 'gradients',
              log_freq: int = 100):
        """Forward ``model``, ``log`` and ``log_freq`` to the run's ``watch`` method."""
        self.experiment.watch(model, log=log, log_freq=log_freq)

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        """Convert, flatten and sanitize ``params`` and merge them into the run config."""
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        params = self._sanitize_callable_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        """Log ``metrics`` to the run, shifting ``step`` by the stored step offset.

        A warning is emitted once through ``self.warning_cache`` when the
        shifted step lies before the run's current step.
        """
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

        metrics = self._add_prefix(metrics)
        if step is not None and step + self._step_offset < self.experiment.step:
            self.warning_cache.warn(
                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
            )
        self.experiment.log(
            metrics,
            step=(step + self._step_offset) if step is not None else None)

    @property
    def save_dir(self) -> Optional[str]:
        """Directory given at construction, or the run's directory once created with ``log_model``."""
        return self._save_dir

    @property
    def name(self) -> Optional[str]:
        """Project name of the run if it exists, otherwise the configured display name."""
        # don't create an experiment if we don't have one
        return self._experiment.project_name(
        ) if self._experiment else self._name

    @property
    def version(self) -> Optional[str]:
        """Id of the run if it exists, otherwise the configured id/version."""
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else self._id

    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Record the run's step as offset for later training and upload checkpoints if requested.

        ``status`` is accepted for interface compatibility and is not used here.
        """
        # offset future training logged on same W&B run
        if self._experiment is not None:
            self._step_offset = self._experiment.step

        # upload all checkpoints from saving dir
        if self._log_model:
            wandb.save(os.path.join(self.save_dir, "*.ckpt"))
Пример #12
0
class WandbLogger(LightningLoggerBase):
    r"""
    Logger that writes to a Weights & Biases run and, at the same time, to a
    TensorBoard ``SummaryWriter`` placed in the W&B run directory.

    Args:
        name: Display name for the run, passed to :func:`wandb.init`.
        save_dir: Directory passed to :func:`wandb.init` as ``dir``.
        offline: When true, ``WANDB_MODE`` is set to ``dryrun`` before the run is created.
        id: Run id passed to :func:`wandb.init`; overridden by ``version`` if given.
        anonymous: When true, ``anonymous='allow'`` is passed to :func:`wandb.init`.
        version: Run id; takes precedence over ``id``.
        project: Project passed to :func:`wandb.init`.
        log_model: Use the W&B run directory as ``save_dir`` and upload ``*.ckpt``
            files from it in :meth:`finalize`.
        log_graph: Enable :meth:`log_graph`.
        default_hp_metric: Log a placeholder ``hp_metric`` of ``-1`` when
            :meth:`log_hyperparams` is called without metrics.
        experiment: Pre-built experiment object to use instead of creating one.
        prefix: A string to put at the beginning of metric keys.
        \**kwargs: Extra keyword arguments forwarded when the experiment is created.
    """

    LOGGER_JOIN_CHAR = '-'
    NAME_HPARAMS_FILE = 'hparams.yaml'

    def __init__(self,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None,
                 offline: bool = False,
                 id: Optional[str] = None,
                 anonymous: bool = False,
                 version: Optional[str] = None,
                 project: Optional[str] = None,
                 log_model: bool = False,
                 log_graph: bool = False,
                 default_hp_metric: bool = True,
                 experiment=None,
                 prefix: str = '',
                 **kwargs):
        """Store the configuration; the experiment is created lazily by :attr:`experiment`.

        Raises:
            ImportError: if the ``wandb`` package could not be imported.
        """
        if wandb is None:
            raise ImportError(
                'You want to use `wandb` logger which is not installed yet,'  # pragma: no-cover
                ' install it with `pip install wandb`.')
        super().__init__()
        self._name = name
        self._save_dir = save_dir
        self._anonymous = 'allow' if anonymous else None
        # `version` takes precedence over `id` when both are given
        self._id = version or id
        self._project = project
        self._experiment = experiment
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._log_graph = log_graph
        self._default_hp_metric = default_hp_metric
        self._kwargs = kwargs
        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
        self._step_offset = 0
        # hyperparameters collected by log_hyperparams and written out by save()
        self.hparams = {}
        self.warning_cache = WarningCache()

    def __getstate__(self):
        """Return a picklable copy of the logger state.

        The live experiment is dropped; only the W&B run id needed to reload
        the correct experiment is kept.
        """
        state = self.__dict__.copy()
        # args needed to reload correct experiment: use the id of the live run when
        # there is one, otherwise keep the id/version given by the user so that it is
        # not lost when the logger is pickled before the run has been created
        state['_id'] = (self._experiment.wandb_experiment.id
                        if self._experiment is not None else self._id)

        # cannot be pickled
        state['_experiment'] = None
        return state

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        r"""
        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
        Example::
            self.logger.experiment.some_wandb_function()

        The value built here is an ``ExperimentTuple`` holding the W&B run and
        the TensorBoard writer, in that order.
        """
        if self._experiment is None:
            if self._offline:
                os.environ['WANDB_MODE'] = 'dryrun'
            # reuse an already active global run, otherwise start (or resume) one
            wandb_experiment = wandb.init(
                name=self._name,
                dir=self._save_dir,
                project=self._project,
                anonymous=self._anonymous,
                id=self._id,
                resume='allow',
                **self._kwargs) if wandb.run is None else wandb.run

            # offset logging step when resuming a run
            self._step_offset = wandb_experiment.step
            # save checkpoints in wandb dir to upload on W&B servers
            if self._log_model:
                self._save_dir = wandb_experiment.dir
            self._fs = get_filesystem(self.save_dir)

            # NOTE(review): the same **kwargs are handed to both wandb.init and
            # SummaryWriter - confirm that every expected key is accepted by both
            tensorboard_experiment = SummaryWriter(
                log_dir=wandb_experiment.dir, **self._kwargs)
            self._experiment = ExperimentTuple(wandb_experiment,
                                               tensorboard_experiment)
            self._experiment._wandb_offset = self._step_offset
        return self._experiment

    @rank_zero_only
    def watch(self,
              model: nn.Module,
              log: str = 'gradients',
              log_freq: int = 100):
        """Forward ``model``, ``log`` and ``log_freq`` to the W&B run's ``watch`` method."""
        wandb_experiment, tensorboard_experiment = self.experiment
        wandb_experiment.watch(model, log=log, log_freq=log_freq)

    @rank_zero_only
    def log_hyperparams(self,
                        params: Union[Dict[str, Any], Namespace],
                        metrics: Optional[Dict[str, Any]] = None) -> None:
        """Record hyperparameters in ``self.hparams``, TensorBoard and the W&B config.

        Args:
            params: hyperparameters to record.
            metrics: metrics to associate with the hyperparameters in TensorBoard.
                A non-dict value is logged under ``"hp_metric"``; when ``None`` and
                ``default_hp_metric`` is enabled, ``{"hp_metric": -1}`` is used.
        """
        params = self._convert_params(params)
        wandb_experiment, tensorboard_experiment = self.experiment

        # store params to output
        if OMEGACONF_AVAILABLE and isinstance(params, Container):
            self.hparams = OmegaConf.merge(self.hparams, params)
        else:
            self.hparams.update(params)
        params = self._flatten_dict(params)
        params = self._sanitize_callable_params(params)
        if metrics is None:
            if self._default_hp_metric:
                metrics = {"hp_metric": -1}
        elif not isinstance(metrics, dict):
            metrics = {"hp_metric": metrics}

        # TensorBoard
        if metrics:
            metrics = self._add_prefix(metrics)
            for k, v in metrics.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                tensorboard_experiment.add_scalar(k, v, 0)
            exp, ssi, sei = hparams(params, metrics)
            writer = tensorboard_experiment._get_file_writer()
            writer.add_summary(exp)
            writer.add_summary(ssi)
            writer.add_summary(sei)

        # Wandb
        wandb_experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        """Log ``metrics`` to TensorBoard and then to the W&B run.

        For W&B the ``step`` is shifted by the stored step offset, and a warning
        is emitted once when the shifted step lies before the run's current step.

        Raises:
            ValueError: a value could not be written as a TensorBoard scalar.
        """
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
        metrics = self._add_prefix(metrics)
        wandb_experiment, tensorboard_experiment = self.experiment

        # TensorBoard
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            if isinstance(v, dict):
                tensorboard_experiment.add_scalars(k, v, step)
            else:
                try:
                    tensorboard_experiment.add_scalar(k, v, step)
                except Exception as e:
                    m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.'
                    # actually raise, chaining the original error for context
                    raise ValueError(m) from e

        # Wandb
        if step is not None and step + self._step_offset < wandb_experiment.step:
            self.warning_cache.warn(
                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
            )
        wandb_experiment.log(
            metrics,
            step=(step + self._step_offset) if step is not None else None)

    @rank_zero_only
    def log_graph(self, model: LightningModule, input_array=None):
        """Add the model graph to TensorBoard when graph logging was enabled.

        Args:
            model: module whose graph is recorded.
            input_array: example input; falls back to ``model.example_input_array``.
                A warning is issued when neither is available.
        """
        if self._log_graph:
            wandb_experiment, tensorboard_experiment = self.experiment
            if input_array is None:
                input_array = model.example_input_array

            if input_array is not None:
                input_array = model.transfer_batch_to_device(
                    input_array, model.device)
                tensorboard_experiment.add_graph(model, input_array)
            else:
                rank_zero_warn(
                    'Could not log computational graph since the'
                    ' `model.example_input_array` attribute is not set'
                    ' or `input_array` was not given', UserWarning)

    @property
    def save_dir(self) -> Optional[str]:
        """Directory given at construction, or the W&B run directory once created with ``log_model``."""
        return self._save_dir

    @property
    def name(self) -> Optional[str]:
        """Project name of the W&B run if it exists, otherwise the configured display name."""
        # don't create an experiment if we don't have one
        return self._experiment.wandb_experiment.project_name(
        ) if self._experiment else self._name

    @property
    def version(self) -> Optional[str]:
        """Id of the W&B run if it exists, otherwise the configured id/version."""
        # don't create an experiment if we don't have one
        return self._experiment.wandb_experiment.id if self._experiment else self._id

    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Flush and save the logger, then upload checkpoints (if requested) and the hparams file.

        ``status`` is accepted for interface compatibility and is not used here.
        """
        # NOTE(review): relies on the experiment object exposing `flush()` -
        # confirm against the ExperimentTuple definition
        self.experiment.flush()
        self.save()

        # offset future training logged on same W&B run
        if self._experiment is not None:
            self._step_offset = self._experiment.wandb_experiment.step

        # upload all checkpoints from saving dir
        if self._log_model:
            wandb.save(os.path.join(self.save_dir, "*.ckpt"))
        wandb.save(os.path.join(self.save_dir, self.NAME_HPARAMS_FILE))

    @rank_zero_only
    def save(self) -> None:
        """Run the base-class save and write ``self.hparams`` to the hparams YAML file if it is missing."""
        # Initialize experiment
        _ = self.experiment

        super().save()

        # prepare the file path
        hparams_file = os.path.join(self.save_dir, self.NAME_HPARAMS_FILE)

        # save the metatags file if it doesn't exist
        if not os.path.isfile(hparams_file):
            save_hparams_to_yaml(hparams_file, self.hparams)