Example #1
    def compare_metric(self, best_metric: Metric, new_metric: Metric) -> bool:
        if best_metric.is_new:
            return True

        best_loss = round(best_metric.get_current_loss(), 2)
        new_loss = round(new_metric.get_current_loss(), 2)

        if best_loss == new_loss:
            best_levenshtein = best_metric.get_accuracy_metric(
                MetricType.LevenshteinDistance)
            new_levenshtein = new_metric.get_accuracy_metric(
                MetricType.LevenshteinDistance)
            return new_levenshtein <= best_levenshtein
        else:
            return best_loss > new_loss
    def compare_metric(self, best_metric: Metric, new_metric: Metric) -> bool:
        if best_metric.is_new:
            return True

        current_best = 0
        new_result = 0

        if self.metric_log_key is not None:
            current_best = best_metric.get_accuracy_metric(self.metric_log_key)
            new_result = new_metric.get_accuracy_metric(self.metric_log_key)

        if current_best == new_result:
            result = best_metric.get_current_loss(
            ) >= new_metric.get_current_loss()
        else:
            result = current_best < new_result

        return result
    def compare_metric(self, best_metric: Metric, new_metric: Metric) -> bool:
        if best_metric.is_new:
            return True

        best_jaccard = round(
            best_metric.get_accuracy_metric(MetricType.JaccardSimilarity), 2)
        new_jaccard = round(
            new_metric.get_accuracy_metric(MetricType.JaccardSimilarity), 2)

        if best_jaccard == new_jaccard:
            best_levenshtein = best_metric.get_accuracy_metric(
                MetricType.LevenshteinDistance)
            new_levenshtein = new_metric.get_accuracy_metric(
                MetricType.LevenshteinDistance)
            new_is_better = new_levenshtein < best_levenshtein
        else:
            new_is_better = new_jaccard > best_jaccard

        return new_is_better
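
The three compare_metric variants above follow the same pattern: a freshly initialised best metric always loses, a primary value (loss, a configured accuracy key, or Jaccard similarity) is compared first, and a secondary metric breaks ties. The sketch below exercises that pattern with a minimal stand-in; StubMetric, its fields, and this MetricType enum are hypothetical approximations of the Metric interface used above, not the project's own classes.

# Minimal, self-contained sketch of the "primary value, then tie-breaker"
# comparison used by the compare_metric variants above. StubMetric and
# MetricType are hypothetical stand-ins for the project's Metric API.
from dataclasses import dataclass, field
from enum import Enum, auto


class MetricType(Enum):
    JaccardSimilarity = auto()
    LevenshteinDistance = auto()


@dataclass
class StubMetric:
    is_new: bool = False
    loss: float = 0.0
    accuracies: dict = field(default_factory=dict)

    def get_current_loss(self) -> float:
        return self.loss

    def get_accuracy_metric(self, metric_type: MetricType) -> float:
        return self.accuracies.get(metric_type, 0.0)


def compare_metric(best_metric: StubMetric, new_metric: StubMetric) -> bool:
    # A freshly initialised best metric is always replaced by the first result.
    if best_metric.is_new:
        return True

    # Primary criterion: rounded loss, lower is better.
    best_loss = round(best_metric.get_current_loss(), 2)
    new_loss = round(new_metric.get_current_loss(), 2)
    if best_loss != new_loss:
        return new_loss < best_loss

    # Tie-breaker: Levenshtein distance, lower is better.
    best_lev = best_metric.get_accuracy_metric(MetricType.LevenshteinDistance)
    new_lev = new_metric.get_accuracy_metric(MetricType.LevenshteinDistance)
    return new_lev <= best_lev


best = StubMetric(loss=0.42, accuracies={MetricType.LevenshteinDistance: 3.0})
new = StubMetric(loss=0.42, accuracies={MetricType.LevenshteinDistance: 2.0})
print(compare_metric(best, new))  # True: equal loss, lower edit distance wins
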
    def _evaluate(self) -> Metric:
        metric = Metric(amount_limit=None)
        data_loader_length = len(self.data_loader_validation)
        full_output_log = DataOutputLog()

        for i, batch in enumerate(self.data_loader_validation):
            if not batch:
                continue

            self._log_service.log_progress(i,
                                           data_loader_length,
                                           evaluation=True)

            loss_batch, metrics_batch, current_output_log = self._perform_batch_iteration(
                batch,
                train_mode=False,
                output_characters=(len(full_output_log) < 100))

            if math.isnan(loss_batch):
                raise Exception(
                    f'loss is NaN during evaluation at iteration {i}')

            if current_output_log is not None:
                full_output_log.extend(current_output_log)

            metric.add_accuracies(metrics_batch)
            metric.add_loss(loss_batch)

        final_metric = self._model.calculate_evaluation_metrics()
        metric.add_accuracies(final_metric)
        self._log_service.log_batch_results(full_output_log)

        assert not math.isnan(
            metric.get_current_loss()
        ), f'combined loss is NaN during evaluation at iteration {i}; losses are - {metric._losses}'

        return metric
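
The _evaluate method above accumulates per-batch losses and accuracies into a Metric and guards against NaN both per batch and on the combined loss. Below is a stripped-down sketch of that accumulate-and-guard pattern, assuming only plain float losses; evaluate_batches is a hypothetical helper, not part of the project.

# Sketch of the evaluation loop's accumulation and NaN guards.
import math
from typing import Iterable, List


def evaluate_batches(batch_losses: Iterable[float]) -> float:
    accumulated: List[float] = []

    for i, loss_batch in enumerate(batch_losses):
        # Fail fast on a NaN batch loss instead of letting it poison the
        # averaged metric, mirroring the per-batch check in _evaluate above.
        if math.isnan(loss_batch):
            raise Exception(f'loss is NaN during evaluation at iteration {i}')

        accumulated.append(loss_batch)

    mean_loss = sum(accumulated) / max(len(accumulated), 1)

    # Mirror the final combined-loss assertion from _evaluate.
    assert not math.isnan(mean_loss), f'combined loss is NaN; losses are - {accumulated}'
    return mean_loss


print(evaluate_batches([0.31, 0.27, 0.29]))  # roughly 0.29
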
Example #5
    def compare_metric(self, best_metric: Metric, new_metrics: Metric) -> bool:
        if best_metric.is_new or best_metric.get_current_loss(
        ) >= new_metrics.get_current_loss():
            return True

        return False
    def log_evaluation(self,
                       train_metric: Metric,
                       validation_metric: Metric,
                       batches_done: int,
                       epoch: int,
                       iteration: int,
                       iterations: int,
                       new_best: bool,
                       metric_log_key: str = None):
        """
        logs progress to user through tensorboard and terminal
        """

        self._current_epoch = epoch
        self._current_iteration = iteration
        self._all_iterations = iterations

        time_passed = self.get_time_passed()
        train_loss = train_metric.get_current_loss()
        train_accuracies = train_metric.get_current_accuracies()
        validation_loss = validation_metric.get_current_loss()
        validation_accuracies = validation_metric.get_current_accuracies()
        if train_accuracies and len(train_accuracies) > 0:
            if metric_log_key is not None and train_metric.contains_accuracy_metric(
                    metric_log_key):
                train_accuracy = train_metric.get_accuracy_metric(
                    metric_log_key)
            else:
                train_accuracy = list(train_accuracies.values())[0]
        else:
            train_accuracy = 0

        if validation_accuracies and len(validation_accuracies) > 0:
            if metric_log_key is not None and validation_metric.contains_accuracy_metric(
                    metric_log_key):
                validation_accuracy = validation_metric.get_accuracy_metric(
                    metric_log_key)
            else:
                validation_accuracy = list(validation_accuracies.values())[0]
        else:
            validation_accuracy = 0

        print(
            colored(
                self._log_template.format(time_passed.total_seconds(), epoch,
                                          iteration, 1 + iteration, iterations,
                                          100. * (1 + iteration) / iterations,
                                          train_loss, train_accuracy,
                                          validation_loss, validation_accuracy,
                                          "BEST" if new_best else ""),
                self._evaluation_color))

        if self._external_logging_enabled:
            current_step = self._get_current_step()
            wandb.log({'Train loss': train_loss}, step=current_step)

            for key, value in train_accuracies.items():
                wandb.log({f'Train - {key}': value}, step=current_step)

            for key, value in validation_accuracies.items():
                wandb.log({f'Validation - {key}': value}, step=current_step)

            wandb.log({'Validation loss': validation_loss}, step=current_step)

            if current_step == 0:
                seconds_per_iteration = time_passed.total_seconds()
            else:
                seconds_per_iteration = time_passed.total_seconds(
                ) / current_step

            self.log_summary('Seconds per iteration', seconds_per_iteration)
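
log_evaluation above applies the same fallback twice when choosing which accuracy to display: use the configured metric_log_key if that metric was logged, otherwise the first logged accuracy, otherwise 0. The sketch below isolates that fallback over a plain dict; select_display_accuracy is a hypothetical helper and simplifies the Metric API to a dictionary.

# Sketch of the accuracy-selection fallback used for both the train and
# validation metrics in log_evaluation above.
from typing import Dict, Optional


def select_display_accuracy(accuracies: Dict[str, float],
                            metric_log_key: Optional[str] = None) -> float:
    if not accuracies:
        return 0

    # Prefer the explicitly configured metric when it was actually logged.
    if metric_log_key is not None and metric_log_key in accuracies:
        return accuracies[metric_log_key]

    # Otherwise fall back to whichever accuracy was logged first.
    return list(accuracies.values())[0]


accuracies = {'jaccard-similarity': 0.83, 'levenshtein-distance': 2.1}
print(select_display_accuracy(accuracies, 'jaccard-similarity'))  # 0.83
print(select_display_accuracy(accuracies, 'f1-score'))            # 0.83 (fallback)
print(select_display_accuracy({}))                                # 0
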
    def train(self) -> bool:
        """
         main training function
        """
        epoch = 0

        try:
            self._log_service.initialize_evaluation()

            best_metrics = Metric(amount_limit=None)
            patience = self._initial_patience

            metric = Metric(amount_limit=self._arguments_service.eval_freq)

            start_epoch = 0
            start_iteration = 0
            resets_left = self._arguments_service.resets_limit
            reset_epoch_limit = self._arguments_service.training_reset_epoch_limit

            if self._arguments_service.resume_training:
                model_checkpoint = self._load_model()
                if model_checkpoint and not self._arguments_service.skip_best_metrics_on_resume:
                    best_metrics = model_checkpoint.best_metrics
                    start_epoch = model_checkpoint.epoch
                    start_iteration = model_checkpoint.iteration
                    resets_left = model_checkpoint.resets_left
                    metric.initialize(best_metrics)

            self.data_loader_train, self.data_loader_validation = self._dataloader_service.get_train_dataloaders(
            )
            self._optimizer = self._optimizer_base.get_optimizer()
            self._log_service.start_logging_model(
                self._model, self._loss_function.criterion)

            # run
            epoch = start_epoch
            model_has_converged = False
            while epoch < self._arguments_service.epochs:
                self._log_service.log_summary('Epoch', epoch)

                best_metrics, patience = self._perform_epoch_iteration(
                    epoch, best_metrics, patience, metric, resets_left,
                    start_iteration)

                start_iteration = 0  # reset the starting iteration

                # flush prints
                sys.stdout.flush()

                if patience == 0:
                    # we only prompt the model for changes on convergence once
                    should_start_again = not model_has_converged and self._model.on_convergence(
                    )
                    if should_start_again:
                        model_has_converged = True
                        model_checkpoint = self._load_model()
                        if model_checkpoint is not None:
                            best_metrics = model_checkpoint.best_metrics
                            start_epoch = model_checkpoint.epoch
                            start_iteration = model_checkpoint.iteration
                            resets_left = model_checkpoint.resets_left
                            metric.initialize(best_metrics)

                        self._initial_patience = self._arguments_service.patience
                        patience = self._initial_patience
                        epoch += 1
                    elif (self._arguments_service.reset_training_on_early_stop
                          and resets_left > 0 and reset_epoch_limit > epoch):
                        patience = self._initial_patience
                        resets_left -= 1
                        self._log_service.log_summary(key='Resets left',
                                                      value=resets_left)

                        print(
                            f'Resetting training: early stopping was triggered. Resets left: {resets_left}'
                        )
                    else:
                        print('Stopping training due to depleted patience')
                        break
                else:
                    epoch += 1

            if epoch >= self._arguments_service.epochs:
                print('Stopping training due to depleted epochs')

        except KeyboardInterrupt as e:
            print(f"Killed by user: {e}")
            if self._arguments_service.save_checkpoint_on_crash:
                self._model.save(self._model_path,
                                 epoch,
                                 0,
                                 best_metrics,
                                 resets_left,
                                 name_prefix=f'KILLED_at_epoch_{epoch}')

            return False
        except Exception as e:
            print(e)
            if self._arguments_service.save_checkpoint_on_crash:
                self._model.save(self._model_path,
                                 epoch,
                                 0,
                                 best_metrics,
                                 resets_left,
                                 name_prefix=f'CRASH_at_epoch_{epoch}')
            raise e

        # flush prints
        sys.stdout.flush()

        if self._arguments_service.save_checkpoint_on_finish:
            self._model.save(self._model_path,
                             epoch,
                             0,
                             best_metrics,
                             resets_left,
                             name_prefix=f'FINISHED_at_epoch_{epoch}')

        return True
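
train() above wraps the epoch loop in checkpoint and resume handling, but its core control flow is a patience counter plus a bounded number of training resets. The sketch below compresses that flow and omits convergence prompts, checkpointing, and resume logic; run_epoch and all numeric arguments are hypothetical placeholders, not the project's API.

# Compressed sketch of the patience/reset control flow in train() above.
from typing import Callable


def training_loop(run_epoch: Callable[[int, int], int],
                  epochs: int,
                  initial_patience: int,
                  resets_limit: int,
                  reset_epoch_limit: int) -> str:
    epoch = 0
    patience = initial_patience
    resets_left = resets_limit

    while epoch < epochs:
        # run_epoch returns the remaining patience after its evaluations.
        patience = run_epoch(epoch, patience)

        if patience == 0:
            if resets_left > 0 and reset_epoch_limit > epoch:
                # Early stop triggered, but a reset is still allowed.
                patience = initial_patience
                resets_left -= 1
                continue
            return 'stopped early: depleted patience'

        epoch += 1

    return 'stopped: depleted epochs'


# Pretend every epoch fails to improve, so patience shrinks by one each time.
print(training_loop(lambda epoch, patience: patience - 1,
                    epochs=50, initial_patience=3,
                    resets_limit=1, reset_epoch_limit=10))
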
    def _perform_epoch_iteration(
            self,
            epoch_num: int,
            best_metrics: Metric,
            patience: int,
            metric: Metric,
            resets_left: int,
            start_iteration: int = 0) -> Tuple[Metric, int]:
        """
        one epoch implementation
        """
        data_loader_length = len(self.data_loader_train)

        for i, batch in enumerate(self.data_loader_train):
            if i < start_iteration:
                continue

            self._log_service.log_progress(i, data_loader_length, epoch_num)

            loss_batch, accuracies_batch, _ = self._perform_batch_iteration(
                batch)
            assert not math.isnan(
                loss_batch), f'loss is NaN during training at iteration {i}'

            metric.add_loss(loss_batch)
            metric.add_accuracies(accuracies_batch)

            # calculate amount of batches and walltime passed
            batches_passed = i + (epoch_num * data_loader_length)

            # run on validation set and print progress to terminal
            # if we have eval_frequency or if we have finished the epoch
            if self._should_evaluate(batches_passed, i, data_loader_length):
                if not self._arguments_service.skip_validation:
                    validation_metric = self._evaluate()
                else:
                    validation_metric = Metric(metric=metric)

                assert not math.isnan(
                    metric.get_current_loss()
                ), f'combined loss is NaN during training at iteration {i}; losses are - {metric._losses}'

                new_best = self._model.compare_metric(best_metrics,
                                                      validation_metric)
                if new_best:
                    best_metrics, patience = self._save_current_best_result(
                        validation_metric, epoch_num, i, resets_left)
                else:
                    patience -= 1

                self._log_service.log_evaluation(
                    metric,
                    validation_metric,
                    batches_passed,
                    epoch_num,
                    i,
                    data_loader_length,
                    new_best,
                    metric_log_key=self._model.metric_log_key)

                self._log_service.log_summary(key='Patience left',
                                              value=patience)

                self._model.finalize_batch_evaluation(is_new_best=new_best)

            # check if runtime is expired
            self._validate_time_passed()

            if patience == 0:
                break

        return best_metrics, patience
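
_should_evaluate is not shown in these examples. Below is a plausible sketch, assuming it fires every eval_freq batches and at the end of each epoch, matching the comment in _perform_epoch_iteration above; the real implementation may differ.

# Hypothetical sketch of the evaluation trigger used in the epoch loop above.
def should_evaluate(batches_passed: int,
                    iteration: int,
                    data_loader_length: int,
                    eval_freq: int = 100) -> bool:
    reached_eval_frequency = batches_passed % eval_freq == 0 and batches_passed > 0
    finished_epoch = iteration == data_loader_length - 1
    return reached_eval_frequency or finished_epoch


print(should_evaluate(batches_passed=200, iteration=57, data_loader_length=143))   # True
print(should_evaluate(batches_passed=201, iteration=58, data_loader_length=143))   # False
print(should_evaluate(batches_passed=285, iteration=142, data_loader_length=143))  # True
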