Example #1
    def on_train_epoch_start(self, epoch):
        model = self.trainer.get_model()

        # set the epoch on the distributed sampler (re-seeds shuffling each epoch)
        try:
            self.trainer.train_dataloader.sampler.set_epoch(epoch)
        except AttributeError:
            # not every sampler exposes set_epoch
            pass

        # update training progress in trainer and model
        model.current_epoch = epoch
        self.trainer.current_epoch = epoch

        # update gradient accumulation according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, model)

        # stores accumulated grad fractions per batch
        self.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches
        )

        # bookkeeping
        self.should_check_val = False

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')
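
The set_epoch call above is what makes shuffling work with a DistributedSampler: the sampler derives its shuffle seed from the epoch number, so without it every epoch replays the same order. A minimal standalone sketch of that behavior (num_replicas=1 and rank=0 are passed here only so it runs without an initialized process group):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(16))
    # explicit num_replicas/rank let this run without torch.distributed
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-seeds the shuffle for this epoch
        for (batch,) in loader:
            print(epoch, batch.tolist())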

Example #2
    def on_train_epoch_start(self, epoch):

        # update training progress in trainer
        self.trainer.current_epoch = epoch

        model = self.trainer.get_model()

        # reset train dataloader
        if self.trainer.reload_dataloaders_every_epoch:
            self.trainer.reset_train_dataloader(model)

        # set the epoch on the distributed sampler (re-seeds shuffling each epoch)
        try:
            self.trainer.train_dataloader.sampler.set_epoch(epoch)
        except AttributeError:
            # not every sampler exposes set_epoch
            pass

        # update gradient accumulation according to the accumulation_scheduler
        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, model)

        # stores accumulated grad fractions per batch
        self.accumulated_loss = TensorRunningAccum(
            window_length=self.trainer.accumulate_grad_batches)

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()

        # hook
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")

Example #3
    def on_train_epoch_start(self):
        # hook
        self.trainer.call_hook('on_epoch_start')
        self.trainer.call_hook('on_train_epoch_start')

        # bookkeeping
        self.should_check_val = False

        # structured result accumulators for callbacks
        self.early_stopping_accumulator = Accumulator()
        self.checkpoint_accumulator = Accumulator()
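
Each variant also creates Accumulator objects for the 'early_stop_on' and 'checkpoint_on' values that training steps may emit; Example #5 below shows .accumulate() being called on them per batch. A minimal sketch of the contract those call sites assume, guessing that the accumulator simply averages what it is fed:

    import torch

    class Accumulator:
        """Hypothetical sketch: collect per-batch scalars, report their mean."""

        def __init__(self):
            self.num_values = 0
            self.total = 0

        def accumulate(self, x):
            with torch.no_grad():
                self.total = self.total + x
            self.num_values += 1

        def mean(self):
            return self.total / self.num_values if self.num_values else None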

Example #4
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        self.run_on_epoch_start_hook(model)

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)

        # bookkeeping
        num_optimizers = len(self._get_optimizers_iterable())
        epoch_output = [[] for _ in range(num_optimizers)]
        should_check_val = False

        # structured result accumulators for callbacks
        early_stopping_accumulator = Accumulator()
        checkpoint_accumulator = Accumulator()

        # run epoch
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
        ):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx
            model.global_step = self.global_step

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            batch_output = self.run_training_batch(batch, batch_idx)

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            epoch_end_outputs = self.process_train_step_outputs(
                batch_output.training_step_output_for_epoch_end,
                early_stopping_accumulator,
                checkpoint_accumulator
            )

            # track the outputs to reduce at the end of the epoch
            for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
                # with 1 step (no tbptt) don't use a sequence at epoch end
                if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                    opt_outputs = opt_outputs[0]
                epoch_output[opt_idx].append(opt_outputs)

            # when returning -1 from train_step, we end epoch early
            self.should_stop = batch_output.signal == -1

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val(batch_idx, is_last_batch)
            if should_check_val:
                self.run_evaluation(test_mode=False)

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_in_training_loop(batch_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)

            # update LR schedulers
            monitor_metrics = deepcopy(self.callback_metrics)
            monitor_metrics.update(batch_output.batch_log_metrics)
            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early: stop when the training step set the flag
            # or we've gone past the requested number of batches
            if self.should_stop:
                break

        # let ddp devices catch up when using horovod
        self.sync_horovod()

        # process epoch outputs
        self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers)

        # checkpoint callback
        self.check_checkpoint_callback(should_check_val)

        # epoch end hook
        self.run_on_epoch_end_hook(model)
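
The epoch loop above iterates over _with_is_last(train_dataloader), which pairs each batch with an is_last_batch flag by buffering one element of lookahead; the flag lets the loop decide whether to run validation on the final batch. A minimal generator with that contract (a sketch of the helper's behavior, not necessarily its exact source):

    def _with_is_last(iterable):
        """Yield (item, is_last) by buffering one element of lookahead."""
        it = iter(iterable)
        try:
            last = next(it)
        except StopIteration:
            return  # empty iterable: yield nothing
        for val in it:
            yield last, False
            last = val
        # the buffered element is emitted once we know nothing follows it
        yield last, True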

Example #5
    def run_training_epoch(self):

        # get model
        model = self.get_model()

        # Epoch start events
        self.run_on_epoch_start_hook(model)

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)

        # bookkeeping
        epoch_output = []
        should_check_val = False

        # structured result accumulators for callbacks
        early_stopping_accumulator = Accumulator()
        checkpoint_accumulator = Accumulator()

        # run epoch
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
        ):
            # stop epoch if we limited the number of training batches
            if batch_idx >= self.num_training_batches:
                break

            self.batch_idx = batch_idx
            model.global_step = self.global_step

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            batch_output = self.run_training_batch(batch, batch_idx)

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            step_out = batch_output.training_step_output_for_epoch_end
            should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end
            if isinstance(step_out, dict) and 'early_stop_on' in step_out:
                early_stopping_accumulator.accumulate(step_out['early_stop_on'])

            if isinstance(step_out, dict) and 'checkpoint_on' in step_out:
                checkpoint_accumulator.accumulate(step_out['checkpoint_on'])

            if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result:
                epoch_output.append(batch_output.training_step_output_for_epoch_end)

            # update LR schedulers
            self.update_train_loop_lr_schedulers()

            # when returning -1 from train_step, we end epoch early
            self.should_stop = batch_output.signal == -1

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val(batch_idx, is_last_batch)
            if should_check_val:
                self.run_evaluation(test_mode=False)

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_in_training_loop(batch_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

            # max steps reached, end training
            if self.max_steps is not None and self.max_steps == self.global_step:
                break

            # end epoch early: stop when the training step set the flag
            # or we've gone past the requested number of batches
            if self.should_stop:
                break

        # let ddp devices catch up when using horovod
        self.sync_horovod()

        # process epoch outputs
        self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator)

        # checkpoint callback
        self.check_checkpoint_callback(should_check_val)

        # epoch end hook
        self.run_on_epoch_end_hook(model)
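
The main structural difference between the two run_training_epoch versions is the epoch_output bookkeeping: Example #4 keeps one list per optimizer so training_epoch_end can reduce each optimizer's outputs separately, while Example #5 keeps a single flat list. A toy illustration of the per-optimizer reduction (the names here are illustrative, not the Lightning API):

    import torch

    num_optimizers = 2
    epoch_output = [[] for _ in range(num_optimizers)]

    # during the epoch: each batch contributes one output per optimizer
    for batch_idx in range(4):
        for opt_idx in range(num_optimizers):
            epoch_output[opt_idx].append({"loss": torch.tensor(float(batch_idx))})

    # at epoch end: reduce each optimizer's outputs independently
    epoch_means = [
        torch.stack([out["loss"] for out in outputs]).mean()
        for outputs in epoch_output
    ]
    print(epoch_means)  # [tensor(1.5000), tensor(1.5000)]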