Esempio n. 1
0
def report_timing(level=Logging.VERBOSE):
    """Log all accumulated timing records in key order, then reset them.

    Emits a "Time consumed:" header, one "> name: value" line per entry of
    ``_TimingHelperClass.time_records`` (values rendered with the ``f``
    format spec), and a "------" footer, each through ``Logging(level)``.
    Finally the class-level record store is replaced by an empty dict.
    """
    Logging(level).log("Time consumed:")
    records = _TimingHelperClass.time_records
    # Iterating a dict yields its keys, so sorted(records) == sorted(records.keys()).
    for name in sorted(records):
        Logging(level).log(f"> {name}: {records[name]:f}")
    Logging(level).log("------")
    # Rebind (rather than .clear()) so any previously held reference keeps its data.
    _TimingHelperClass.time_records = {}
Esempio n. 2
0
    def train(self) -> None:
        """Run the training loop until the epoch budget is spent or early stopping fires.

        Each epoch obtains a fresh iterator from ``self.create_data_iterator()``.
        For every batch:

        * ``self.iterations`` is incremented;
        * the ``before_iteration_hooks`` run;
        * ``self.train_step(batch)`` returns a mapping of record name -> value,
          and each value is added to ``self.records[name].record``. A tuple value
          is spread over ``add``'s positional arguments.
        * the ``after_iteration_hooks`` run.

        Every ``self.log_iters`` iterations the 'log' summary is printed.

        Every ``self.valid_iters`` iterations the 'validate' summary is printed and
        ``self.validate()`` is called. The returned metric is compared against
        ``self.validation_history``, respecting ``self.metric_higher_better``:

        * if it is the first validation or the metric improves on the history,
          ``bad_counter`` is reset and ``save_model()`` is called;
        * otherwise ``bad_counter`` grows. Once it reaches ``decay_threshold``,
          ``decay_times`` is incremented. If ``decay_times`` now exceeds
          ``patience``, training returns immediately, without appending the
          metric to the history. Otherwise the counter is reset and ``decay()``
          is called.

        ``self.max_epochs == -1`` means there is no epoch limit.
        """
        # TODO: Incorporate `torchutils.prevent_oom`
        Logging(1).log("Training start.", timestamp=self.timestamp)
        while self.max_epochs == -1 or self.epoch < self.max_epochs:
            batches = self.create_data_iterator()

            for batch in batches:
                self.iterations += 1

                for before_hook in self.before_iteration_hooks:
                    before_hook(self)

                step_records = self.train_step(batch)
                for record_name, record_value in step_records.items():
                    # A tuple result is unpacked into add(); anything else is passed as-is.
                    add_args = record_value if isinstance(record_value, tuple) else (record_value,)
                    self.records[record_name].record.add(*add_args)

                for after_hook in self.after_iteration_hooks:
                    after_hook(self)

                if self.iterations % self.log_iters == 0:
                    self._print_summary(period='log')

                # Validation is the last step of an iteration, so skipping it
                # is equivalent to moving on to the next batch.
                if self.iterations % self.valid_iters != 0:
                    continue

                self._print_summary(period='validate')
                metric = self.validate()

                # Compare only against previous validations (current metric not yet appended).
                history = self.validation_history
                if not history:
                    improved = True
                elif self.metric_higher_better:
                    improved = metric > max(history)
                else:
                    improved = metric < min(history)

                if improved:
                    self.bad_counter = 0
                    self.save_model()
                else:
                    self.bad_counter += 1
                    Logging(1).log(
                        f"{utils.ordinal(self.bad_counter)} time degradation "
                        f"(threshold={self.decay_threshold}).")
                    if self.bad_counter >= self.decay_threshold:
                        self.decay_times += 1
                        if self.decay_times > self.patience:
                            Logging(1).log("Early stop!", color='red')
                            return
                        self.bad_counter = 0
                        self.decay()

                # Re-read the attribute: save_model()/decay() run above and may rebind it.
                self.validation_history.append(metric)

            self.epoch += 1
            self._print_summary(period='epoch')
Esempio n. 3
0
 def _print_summary(self, period: str = 'log') -> None:
     """Log, then reset, every record whose ``period`` equals the given one.

     For each matching record, in ``self.records`` order:

     * its accumulated value is read and passed through ``post_compute``
       (when set);
     * a ``float`` result is rendered with ``record.precision`` decimals;
     * the ``(display, value)`` pair is kept, and the record is cleared.

     If nothing matched, nothing is logged. Otherwise the pairs are joined as
     ``name=value`` with ', ' and substituted into ``self.LOG_MESSAGE``
     together with the current epoch and iteration count.
     """
     entries = []
     matching = (rec for rec in self.records.values() if rec.period == period)
     for rec in matching:
         shown = rec.record.value()
         if rec.post_compute is not None:
             shown = rec.post_compute(shown)
         if isinstance(shown, float):
             shown = f'{shown:.{rec.precision}f}'
         entries.append((rec.display, shown))
         rec.record.clear()

     if not entries:
         return

     # Pairs are stringified only now, after all records were cleared (as before).
     joined = ', '.join(f'{label}={shown}' for label, shown in entries)
     message = self.LOG_MESSAGE.format(epoch=self.epoch,
                                       iter=self.iterations,
                                       records=joined)
     Logging(1).log(message, timestamp=self.timestamp)