Example #1
    def _make_prediction_on_test_set(self):
        """Makes prediction on test set."""
        self.logger.info('\n****** Make prediction on test set ******\n')

        self.classifier = ImageClassifier(
            base_model_name=self.base_model_name,
            n_classes=self.n_classes,
            weights=None,
            dropout_rate=None,
            learning_rate=None,
            loss=None,
        )
        self.classifier.model = self.best_model

        self.data_generator = ValDataGenerator(
            samples=self.samples_test,
            image_dir=self.image_dir,
            batch_size=self.batch_size,
            n_classes=self.n_classes,
            basenet_preprocess=self.classifier.get_preprocess_input(),
        )

        predictions_dist = self.classifier.predict_generator(
            data_generator=self.data_generator,
            workers=WORKERS,
            use_multiprocessing=USE_MULTIPROCESSING,
            verbose=1,
        )

        self.y_pred = np.argmax(predictions_dist, axis=1)
        self.y_pred_prob = self._get_probabilities_prediction(
            predictions_dist=predictions_dist)
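
For reference, a minimal numpy sketch (hypothetical values) of what the last two statements compute: the predicted label is the argmax of each per-class distribution, and the predicted probability is the distribution value at that index.

import numpy as np

# Hypothetical softmax outputs over 3 classes for 2 test images.
predictions_dist = np.array([[0.1, 0.7, 0.2],
                             [0.6, 0.3, 0.1]])

y_pred = np.argmax(predictions_dist, axis=1)  # array([1, 0])
# Probability of the predicted class for each image.
y_pred_prob = [pred[idx] for pred, idx in zip(predictions_dist, y_pred)]  # [0.7, 0.6]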
Example #2
 def _build_model(self):
     self.classifier = ImageClassifier(
         self.base_model_name,
         self.n_classes,
         self.learning_rate_dense,
         self.dropout_rate,
         self.loss,
     )
     self.classifier.build()
Example #3
    def test__init(self):
        global classifier
        classifier = ImageClassifier(
            TEST_CONFIG['base_model_name'],
            TEST_CONFIG['n_classes'],
            TEST_CONFIG['learning_rate_dense'],
            TEST_CONFIG['dropout_rate'],
            TEST_CONFIG['loss'],
        )

        assert classifier.weights == 'imagenet'
        assert classifier.base_module is importlib.import_module(
            'keras.applications.mobilenet')
Example #4
class Evaluation:
    """Calculates performance metrics for trained models.

    Loads the best model (by validation accuracy) from the *models* directory in the job directory.
    All metrics and graphs are based on *test_samples.json* in the job directory.

    Attributes:
        image_dir: Path of image directory.
        job_dir: Path to job directory with samples.
        batch_size: Number of images per batch (default 64).
        base_model_name: Name of pretrained CNN (default MobileNet).
    """
    def __init__(self,
                 image_dir: str,
                 job_dir: str,
                 batch_size: int = BATCH_SIZE,
                 base_model_name: str = BASE_MODEL_NAME,
                 **kwargs) -> None:
        """Inits evaluation component.

        Loads the best model from job directory.
        Creates evaluation directory if the app was started from the command line.
        """
        self.image_dir = Path(image_dir).resolve()
        self.job_dir = Path(job_dir).resolve()
        self.batch_size = batch_size
        self.base_model_name = base_model_name

        self.logger = get_logger(__name__, self.job_dir)
        self.samples_test: list = load_json(
            self.job_dir / 'test_samples.json')  # type: ignore
        self.class_mapping: dict = load_json(
            self.job_dir / 'class_mapping.json')  # type: ignore
        self.n_classes = len(self.class_mapping)
        self.classes = [
            str(self.class_mapping[str(i)]) for i in range(self.n_classes)
        ]
        self.y_true = np.array([i['label'] for i in self.samples_test])

        self._determine_plot_params()
        self._load_best_model()
        self._create_evaluation_dir()

    def _determine_plot_params(self):
        """Checks whether ipython kernel is present.

        Plots will only be shown if in ipython, otherwise saved as files.
        """
        try:
            __IPYTHON__
            self.show_plots = True
            self.save_plots = False
        except NameError:
            # Suppress figure window in terminal
            # https://matplotlib.org/faq/howto_faq.html#generate-images-without-having-a-window-appear
            import matplotlib

            matplotlib.use('Agg')
            self.show_plots = False
            self.save_plots = True

    def _load_best_model(self):
        """Loads best performing model from job_dir."""
        self.logger.info('\n****** Load model ******\n')

        job_path = self.job_dir / 'models'
        model_files = list(job_path.glob('**/*.hdf5'))
        # Model files are named model_<basenet>_<epoch>_<val_accuracy>.hdf5,
        # so the fourth underscore-separated field holds the validation accuracy.
        max_acc_idx = np.argmax(
            [float(m.name.split('_')[3][:5]) for m in model_files])
        self.best_model_file = str(model_files[max_acc_idx])
        self.best_model = load_model(self.best_model_file)

        self.logger.info('loaded {}\n'.format(self.best_model_file))

    def _create_evaluation_dir(self):
        """Creates evaluation dir for reporting."""
        if self.save_plots:
            evaluation_dir_name = os.path.basename(
                self.best_model_file).split('.hdf5')[0]
            self.evaluation_dir = self.job_dir / 'evaluation_{}'.format(
                evaluation_dir_name)

            if not self.evaluation_dir.exists():
                os.makedirs(self.evaluation_dir)

    def _plot_test_set_distribution(self):
        """Plots bars with number of samples for each label in test set."""
        self.logger.info(
            '\n****** Calculate distribution on test set ******\n')

        counts = np.bincount(self.y_true)
        title = 'Number of images in test set: {}'.format(
            len(self.samples_test))
        index = np.arange(self.n_classes)
        title_fontsize = 16 if self.n_classes < 4 else 18
        text_fontsize = 12 if self.n_classes < 4 else 14

        plt.bar(index, counts)
        plt.xlabel('Label', fontsize=text_fontsize)
        plt.ylabel('Number of images', fontsize=text_fontsize)
        plt.xticks(index, self.classes, fontsize=text_fontsize, rotation=30)
        plt.title(title, fontsize=title_fontsize)

        # figsize = [min(15, self.n_classes * 2), 5]
        # plt.figure(figsize=figsize)
        plt.tight_layout()

        if self.save_plots:
            target_file = self.evaluation_dir / 'test_set_distribution.pdf'
            plt.savefig(target_file)
            self.logger.info('saved under {}'.format(target_file))

        if self.show_plots:
            plt.show()

    @staticmethod
    def _get_probabilities_prediction(
            predictions_dist: List[List[float]]) -> List[float]:
        indices = np.argmax(predictions_dist, axis=1)
        prob = [pred[idx] for pred, idx in zip(predictions_dist, indices)]
        return prob

    def _make_prediction_on_test_set(self):
        """Makes prediction on test set."""
        self.logger.info('\n****** Make prediction on test set ******\n')

        self.classifier = ImageClassifier(
            base_model_name=self.base_model_name,
            n_classes=self.n_classes,
            weights=None,
            dropout_rate=None,
            learning_rate=None,
            loss=None,
        )
        self.classifier.model = self.best_model

        self.data_generator = ValDataGenerator(
            samples=self.samples_test,
            image_dir=self.image_dir,
            batch_size=self.batch_size,
            n_classes=self.n_classes,
            basenet_preprocess=self.classifier.get_preprocess_input(),
        )

        predictions_dist = self.classifier.predict_generator(
            data_generator=self.data_generator,
            workers=WORKERS,
            use_multiprocessing=USE_MULTIPROCESSING,
            verbose=1,
        )

        self.y_pred = np.argmax(predictions_dist, axis=1)
        self.y_pred_prob = self._get_probabilities_prediction(
            predictions_dist=predictions_dist)

    def _calc_classification_report(self):
        """Calculates classification report on prediction on test set."""
        self.logger.info('\n****** Calculate classification report ******\n')

        self.accuracy = accuracy_score(y_true=self.y_true, y_pred=self.y_pred)
        self.logger.info('\nModel achieves {}% accuracy on test set\n'.format(
            round(self.accuracy * 100, 2)))

        cr = classification_report(y_true=self.y_true,
                                   y_pred=self.y_pred,
                                   target_names=self.classes)
        self.logger.info(cr)

    def _plot_confusion_matrix(self):
        """Plots normalized confusion matrix."""
        self.logger.info('\n****** Plot confusion matrix ******\n')

        cm = confusion_matrix(y_true=self.y_true, y_pred=self.y_pred)
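        # Row-normalize so each true-label row sums to 1; the diagonal then
        # shows per-class recall.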
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        figsize = [
            min(15, self.n_classes * 3.5),
            min(15, self.n_classes * 3.5)
        ]
        title_fontsize = 16 if self.n_classes < 4 else 18
        text_fontsize = 12 if self.n_classes < 4 else 14

        plt.figure(figsize=figsize)
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion matrix', fontsize=title_fontsize)
        plt.colorbar()
        tick_marks = np.arange(self.n_classes)
        plt.xticks(tick_marks,
                   self.classes,
                   rotation=45,
                   fontsize=text_fontsize)
        plt.yticks(tick_marks, self.classes, fontsize=text_fontsize)
        plt.ylabel('True label', fontsize=text_fontsize)
        plt.xlabel('Predicted label', fontsize=text_fontsize)

        thresh = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(
                j,
                i,
                '{:.2f}'.format(cm[i, j]),
                horizontalalignment='center',
                color='white' if cm[i, j] > thresh else 'black',
                fontsize=text_fontsize,
            )

        plt.tight_layout()

        if self.save_plots:
            target_file = self.evaluation_dir / 'confusion_matrix.pdf'
            plt.savefig(target_file)
            self.logger.info('saved under {}'.format(target_file))

        if self.show_plots:
            plt.show()

    def run(self):
        """Runs evaluation pipeline on the best model found in job directory for the specific test set:

            - Plots test set distribution
            - Makes prediction on test set
            - Calculates classification report (accuracy, precision, recall)
            - Plots confusion matrix
        """

        self._plot_test_set_distribution()
        self._make_prediction_on_test_set()
        self._calc_classification_report()
        self._plot_confusion_matrix()

    # TO-DO: Enforce string or integer but not both at the same time
    def get_correct_wrong_examples(
            self,
            label: Union[int, str]) -> Tuple[TYPE_IMAGE_LIST, TYPE_IMAGE_LIST]:
        """Gets correctly and wrongly predicted samples for a given label.

        Args:
            label: int or str (label for which the predictions should be considered).

        Returns:
            (correct, wrong): Tuple of two image lists.
        """
        correct = []
        wrong = []

        if isinstance(label, str):
            class_mapping_inv = {v: k for k, v in self.class_mapping.items()}
            label = int(class_mapping_inv[label])

        for i, sample in enumerate(self.samples_test):
            if self.y_true[i] == label:
                image_file = self.image_dir / sample['image_id']
                if self.y_true[i] == self.y_pred[i]:
                    correct.append([
                        i,
                        load_image(image_file, target_size=(224, 224)), sample
                    ])
                else:
                    wrong.append([
                        i,
                        load_image(image_file, target_size=(224, 224)), sample
                    ])

        return correct, wrong

    # visualize misclassified images:
    def visualize_images(self,
                         image_list: TYPE_IMAGE_LIST,
                         show_heatmap: bool = False,
                         n_plot: int = 20):
        """Visualizes images in a sample list.

        Args:
            image_list: sample list.
            show_heatmap: boolean (generates a gradient based class activation map (grad-CAM), default False).
            n_plot: maximum number of plots to be shown (default 20).
        """
        if len(image_list) == 0:
            print('Empty list.')
            return
        else:
            n_rows = min(n_plot, len(image_list))
            n_cols = 2 if show_heatmap else 1

            figsize = [5 * n_cols, 5 * n_rows]
            plt.figure(figsize=figsize)

            plot_count = 1
            for (i, img, sample) in image_list[:n_rows]:
                plt.subplot(n_rows, n_cols, plot_count)
                plt.imshow(img)
                plt.axis('off')
                plt.title('true: {}, predicted: {} ({})'.format(
                    self.class_mapping[str(self.y_true[i])],
                    self.class_mapping[str(self.y_pred[i])],
                    str(round(self.y_pred_prob[i], 2)),
                ))
                plot_count += 1

                if show_heatmap is True:
                    heatmap = visualize_cam(
                        model=self.classifier.model,
                        layer_idx=89,
                        filter_indices=[self.y_pred[i]],
                        seed_input=self.classifier.get_preprocess_input()(
                            np.array(img).astype(np.float32)),
                    )
                    plt.subplot(n_rows, n_cols, plot_count)
                    plt.imshow(img)
                    plt.imshow(heatmap, alpha=0.7)
                    plt.axis('off')
                    plot_count += 1

        if self.save_plots:
            # TODO: pass name as argument
            target_file = self.evaluation_dir / 'misclassified_images.pdf'
            plt.savefig(target_file)
            self.logger.info('saved under {}'.format(target_file))

        if self.show_plots:
            plt.show()
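
A minimal usage sketch for the Evaluation component above; the directory paths are hypothetical, and the job directory is assumed to contain test_samples.json, class_mapping.json, and a models/ directory with *.hdf5 checkpoints.

# Hypothetical paths.
evaluation = Evaluation(image_dir='data/images', job_dir='jobs/job_1')
evaluation.run()

# Inspect predictions for a single label (e.g. in a notebook session).
correct, wrong = evaluation.get_correct_wrong_examples(label=0)
evaluation.visualize_images(wrong, show_heatmap=True, n_plot=5)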
Example #5
class Training:
    """Builds model and runs training.

    The following pretrained CNNs from Keras can be used for transfer learning:

    - Xception
    - VGG16
    - VGG19
    - ResNet50, ResNet101, ResNet152
    - ResNet50V2, ResNet101V2, ResNet152V2
    - ResNeXt50, ResNeXt101
    - InceptionV3
    - InceptionResNetV2
    - MobileNet
    - MobileNetV2
    - DenseNet121, DenseNet169, DenseNet201
    - NASNetLarge, NASNetMobile

    Training is split into two phases, at first only the last dense layer gets
    trained, and then all layers are trained. The maximum number of epochs for
    each phase is set by *epochs_train_dense* (default: 100) and
    *epochs_train_all* (default: 100), respectively. Similarly,
    *learning_rate_dense* (default: 0.001) and *learning_rate_all*
    (default: 0.0003) can be set.

    For each phase the learning rate is reduced after a patience period if no
    improvement in validation accuracy has been observed. The patience period
    depends on the average number of samples per class:

    - if n_per_class < 200: patience = 5 epochs
    - if n_per_class >= 200 and < 500: patience = 4 epochs
    - if n_per_class >= 500: patience = 2 epochs

    The training is stopped early after a patience period that is three times
    the learning rate patience to allow for two learning rate adjustments
    with no validation accuracy improvement before stopping training.

    Attributes:
        image_dir: Directory with images used for training.
        job_dir: Directory with train_samples.json, val_samples.json,
                 and class_mapping.json.
        epochs_train_dense: Maximum number of epochs to train dense layers (default 100).
        epochs_train_all: Maximum number of epochs to train all layers (default 100).
        learning_rate_dense: Learning rate for dense training phase (default 0.001).
        learning_rate_all: Learning rate for all training phase (default 0.0003).
        batch_size: Number of images per batch (default 64).
        dropout_rate: Fraction of nodes before the output layer that are randomly dropped during training (default 0.75).
        base_model_name: Name of pretrained CNN (default MobileNet).
    """
    def __init__(
        self,
        image_dir: str,
        job_dir: str,
        epochs_train_dense: typing.Union[int, str] = EPOCHS_TRAIN_DENSE,
        epochs_train_all: typing.Union[int, str] = EPOCHS_TRAIN_ALL,
        learning_rate_dense: typing.Union[float, str] = LEARNING_RATE_DENSE,
        learning_rate_all: typing.Union[float, str] = LEARNING_RATE_ALL,
        batch_size: typing.Union[int, str] = BATCH_SIZE,
        dropout_rate: typing.Union[float, str] = DROPOUT_RATE,
        base_model_name: str = BASE_MODEL_NAME,
        loss: str = LOSS,
        **kwargs,
    ) -> None:
        """Inits training component.

        Checks whether multiprocessing is available and sets number of workers for training.
        """
        self.image_dir = Path(image_dir).resolve()
        self.job_dir = Path(job_dir).resolve()

        self.logger = get_logger(__name__, self.job_dir)
        self.samples_train = load_json(self.job_dir / 'train_samples.json')
        self.samples_val = load_json(self.job_dir / 'val_samples.json')
        self.class_mapping = load_json(self.job_dir / 'class_mapping.json')
        self.n_classes = len(self.class_mapping)

        self.epochs_train_dense = int(epochs_train_dense)
        self.epochs_train_all = int(epochs_train_all)
        self.learning_rate_dense = float(learning_rate_dense)
        self.learning_rate_all = float(learning_rate_all)
        self.batch_size = int(batch_size)
        self.dropout_rate = float(dropout_rate)
        self.base_model_name = base_model_name
        self.loss = loss
        self.use_multiprocessing, self.workers = use_multiprocessing()

    def _set_patience(self):
        """Adjust patience for early stopping and learning rate schedule
        based on training set size.
        """
        n_per_class = int(len(self.samples_train) / self.n_classes)

        self.patience_learning_rate = 5
        if n_per_class >= 200:
            self.patience_learning_rate = 4

        if n_per_class >= 500:
            self.patience_learning_rate = 2

        self.patience_early_stopping = 3 * self.patience_learning_rate
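        # Worked example (hypothetical numbers): 1000 training samples and
        # 4 classes give n_per_class = 250, so patience_learning_rate = 4
        # and patience_early_stopping = 12.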

        self.logger.info('Early stopping patience: {}'.format(
            self.patience_early_stopping))
        self.logger.info('Learning rate patience: {}'.format(
            self.patience_learning_rate))

    def _build_model(self):
        self.classifier = ImageClassifier(
            self.base_model_name,
            self.n_classes,
            self.learning_rate_dense,
            self.dropout_rate,
            self.loss,
        )
        self.classifier.build()

    def _fit_model(self):
        training_generator = TrainDataGenerator(
            self.samples_train,
            self.image_dir,
            self.batch_size,
            self.n_classes,
            self.classifier.get_preprocess_input(),
        )

        validation_generator = ValDataGenerator(
            self.samples_val,
            self.image_dir,
            self.batch_size,
            self.n_classes,
            self.classifier.get_preprocess_input(),
        )

        # TODO: initialize callbacks TensorBoardBatch
        # tensorboard = TensorBoardBatch(log_dir=Path(job_dir).resolve() / 'logs')

        model_save_name = ('model_' + self.base_model_name.lower() +
                           '_{epoch:02d}_{val_accuracy:.3f}.hdf5')
        model_dir = self.job_dir / 'models'
        model_dir.mkdir(parents=True, exist_ok=True)

        logging_metrics = LoggingMetrics(logger=self.logger)
        logging_models = LoggingModels(
            logger=self.logger,
            filepath=str(model_dir / model_save_name),
            monitor='val_accuracy',
            verbose=1,
            save_best_only=True,
            save_weights_only=False,
        )

        def _train_dense_layers():
            if self.epochs_train_dense > 0:
                self.logger.info('\n****** Train dense layers ******\n')

                min_lr = self.learning_rate_dense / 10
                reduce_lr = ReduceLROnPlateau(
                    monitor='val_accuracy',
                    factor=0.3162,
                    patience=self.patience_learning_rate,
                    min_lr=min_lr,
                    verbose=1,
                )

                early_stopping = EarlyStopping(
                    monitor='val_accuracy',
                    min_delta=0.002,
                    patience=self.patience_early_stopping,
                    verbose=1,
                    mode='auto',
                    baseline=None,
                    restore_best_weights=True,
                )

                # freeze convolutional layers in base net
                for layer in self.classifier.get_base_layers():
                    layer.trainable = False

                self.classifier.compile()
                # self.classifier.summary()

                self.hist_dense = self.classifier.fit_generator(
                    generator=training_generator,
                    validation_data=validation_generator,
                    epochs=self.epochs_train_dense,
                    verbose=1,
                    use_multiprocessing=self.use_multiprocessing,
                    workers=self.workers,
                    max_queue_size=30,
                    callbacks=[
                        logging_metrics, logging_models, reduce_lr,
                        early_stopping
                    ],
                )

        def _train_all_layers():
            if self.epochs_train_all > 0:
                self.logger.info('\n****** Train all layers ******\n')

                min_lr = self.learning_rate_all / 10
                reduce_lr = ReduceLROnPlateau(
                    monitor='val_accuracy',
                    factor=0.3162,
                    patience=self.patience_learning_rate,
                    min_lr=min_lr,
                    verbose=1,
                )

                early_stopping = EarlyStopping(
                    monitor='val_accuracy',
                    min_delta=0.002,
                    patience=self.patience_early_stopping,
                    verbose=1,
                    mode='auto',
                    baseline=None,
                    restore_best_weights=False,
                )

                # unfreeze all layers
                for layer in self.classifier.get_base_layers():
                    layer.trainable = True

                self.classifier.set_learning_rate(self.learning_rate_all)

                self.classifier.compile()
                # self.classifier.summary()

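                # Continue the global epoch count after the dense phase so
                # epoch numbers in checkpoints and logs run consecutively.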
                self.hist_all = self.classifier.fit_generator(
                    generator=training_generator,
                    validation_data=validation_generator,
                    epochs=self.epochs_train_dense + self.epochs_train_all,
                    initial_epoch=self.epochs_train_dense,
                    verbose=1,
                    use_multiprocessing=self.use_multiprocessing,
                    workers=self.workers,
                    max_queue_size=30,
                    callbacks=[
                        logging_metrics, logging_models, reduce_lr,
                        early_stopping
                    ],
                )

        self._set_patience()
        _train_dense_layers()
        _train_all_layers()

        K.clear_session()

    def run(self):
        """Builds the model and runs training.
        """
        self._build_model()
        self._fit_model()
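
A minimal sketch of driving the Training component above; the paths are hypothetical, and the job directory is assumed to contain train_samples.json, val_samples.json, and class_mapping.json.

training = Training(
    image_dir='data/images',  # hypothetical path
    job_dir='jobs/job_1',
    epochs_train_dense=50,
    epochs_train_all=50,
)
training.run()  # builds the classifier, then runs both training phases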
Example #6
class Evaluation:
    """Calculates performance metrics for trained models.

    Loads the best model (by validation accuracy) from the *models* directory in the job directory.
    All metrics and graphs are based on *test_samples.json* in the job directory.
    Plots are only shown if the number of classes is 20 or less.

    Attributes:
        image_dir: Path of image directory.
        job_dir: Path to job directory with samples.
        batch_size: Number of images per batch (default 64).
        base_model_name: Name of pretrained CNN (default MobileNet).
    """
    def __init__(self,
                 image_dir: str,
                 job_dir: str,
                 batch_size: int = BATCH_SIZE,
                 base_model_name: str = BASE_MODEL_NAME,
                 **kwargs) -> None:
        """Inits evaluation component.

        Loads the best model from job directory.
        Creates evaluation directory if the app was started from the command line.
        """
        self.image_dir = Path(image_dir).resolve()
        self.job_dir = Path(job_dir).resolve()
        self.batch_size = batch_size
        self.base_model_name = base_model_name

        self.logger = get_logger(__name__, self.job_dir)
        self.samples_test: list = load_json(
            self.job_dir / 'test_samples.json')  # type: ignore
        self.class_mapping: dict = load_json(
            self.job_dir / 'class_mapping.json')  # type: ignore
        self.n_classes = len(self.class_mapping)
        self.classes = [
            str(self.class_mapping[str(i)]) for i in range(self.n_classes)
        ]
        self.y_true = np.array([i['label'] for i in self.samples_test])
        self.figures = []

        self._determine_plot_params()
        self._load_best_model()
        self._create_evaluation_dir()

    def _determine_plot_params(self):
        """Determines fontsizes and checks whether ipython kernel is present.

        Plots will only be shown if in ipython, otherwise saved as files.
        """
        self.fontsize_title = 18 if self.n_classes < 4 else 18
        self.fontsize_label = 14 if self.n_classes < 4 else 14
        self.fontsize_ticks = 12 if self.n_classes < 4 else 12
        self.mode_ipython = True if self._is_in_ipython_mode() else False

    def _is_in_ipython_mode(self):
        try:
            __IPYTHON__
            return True
        except NameError:
            # TODO: Is this obsolete? Please remove!
            # Suppress figure window in terminal
            # https://matplotlib.org/faq/howto_faq.html#generate-images-without-having-a-window-appear
            import matplotlib
            matplotlib.use('Agg')

            return False

    def _load_best_model(self):
        """Loads best performing model from job_dir."""
        self.logger.info('\n****** Load model ******\n')
        best_model_file = self._determine_best_modelfile()
        self.best_model_file = Path(best_model_file).resolve()
        self.best_model = load_model(self.best_model_file)

        self.logger.info('loaded {}'.format(self.best_model_file))

    def _determine_best_modelfile(self):
        """Determines best performing model from job_dir."""
        job_path = self.job_dir / 'models'
        model_files = list(job_path.glob('**/*.hdf5'))
        # Model files are named model_<basenet>_<epoch>_<val_accuracy>.hdf5,
        # so the fourth underscore-separated field holds the validation accuracy.
        max_acc_idx = np.argmax(
            [float(m.name.split('_')[3][:5]) for m in model_files])
        return model_files[max_acc_idx]

    def _create_evaluation_dir(self):
        """Creates evaluation dir for reporting."""
        if not self.mode_ipython:
            evaluation_dir_name = self.best_model_file.name.split('.hdf5')[0]
            self.evaluation_dir = self.job_dir / 'evaluation_{}'.format(
                evaluation_dir_name)

            self.evaluation_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _get_probabilities_prediction(
            predictions_dist: List[List[float]]) -> List[float]:
        indices = np.argmax(predictions_dist, axis=1)
        prob = [pred[idx] for pred, idx in zip(predictions_dist, indices)]
        return prob

    def _make_prediction_on_test_set(self):
        """Makes prediction on test set."""

        self.classifier = ImageClassifier(
            base_model_name=self.base_model_name,
            n_classes=self.n_classes,
            weights=None,
            dropout_rate=None,
            learning_rate=None,
            loss=None,
        )
        self.classifier.model = self.best_model

        self.data_generator = ValDataGenerator(
            samples=self.samples_test,
            image_dir=self.image_dir,
            batch_size=self.batch_size,
            n_classes=self.n_classes,
            basenet_preprocess=self.classifier.get_preprocess_input(),
        )

        predictions_dist = self.classifier.predict_generator(
            data_generator=self.data_generator,
            workers=WORKERS,
            use_multiprocessing=USE_MULTIPROCESSING,
            verbose=1,
        )

        self.y_pred = np.argmax(predictions_dist, axis=1)
        self.y_pred_prob = self._get_probabilities_prediction(
            predictions_dist=predictions_dist)

    def _plot_test_set_distribution(self, figsize: Tuple[float, float] = (8, 5)):
        """Plots bars with number of samples for each label in test set."""
        assert self.mode_ipython, 'Plotting is only possible when in ipython-mode'

        if self.n_classes > MAX_N_CLASSES:
            self.logger.info(
                '\nPlotting only for max {} classes\n'.format(MAX_N_CLASSES))
            return

        x_tick_marks = np.arange(self.n_classes)
        y_values = np.bincount(self.y_true)
        title = 'Number of images in test set: {}'.format(
            len(self.samples_test))

        fig = plt.figure(figsize=figsize)
        plt.rcParams["axes.grid"] = True
        plt.bar(x_tick_marks, y_values)
        plt.title(title, fontsize=self.fontsize_title)
        plt.xlabel('Label', fontsize=self.fontsize_label)
        plt.ylabel('Number of images', fontsize=self.fontsize_label)
        plt.xticks(x_tick_marks,
                   self.classes,
                   fontsize=self.fontsize_ticks,
                   rotation=30)

        plt.tight_layout()
        plt.show()

    def _print_test_set_distribution(self):
        """Prints distribution for labels in test set."""
        assert not self.mode_ipython, 'Printing is only intended when not in ipython-mode'

        max_length = len(max(self.classes, key=len))
        y_values = np.bincount(self.y_true)
        for i, c in enumerate(self.classes):
            label = c + ' ' * (max_length - len(c))
            self.logger.info("{}\t{}".format(label, y_values[i]))

    def _print_classification_report(self):
        """Prints classification for labels in test set."""
        cr = classification_report(y_true=self.y_true,
                                   y_pred=self.y_pred,
                                   target_names=self.classes,
                                   output_dict=True)

        metrics = ['precision', 'recall', 'f1-score', 'support']
        categories = self.classes.copy()
        categories.extend(['macro avg', 'weighted avg'])

        max_length = len(max(self.classes, key=len))
        self.logger.info("{}\t{}\t{}\t{}\t{}".format(' ' * max_length, 'prec',
                                                     "rec", "f1", "support"))
        for c in categories:
            label = c + ' ' * (max_length - len(c))
            line_output = "{}\t".format(label)
            for m in metrics:
                if m == 'support':
                    line_output += "{}\t".format(cr[c][m])
                else:
                    line_output += "{0:.2f}\t".format(cr[c][m])
            self.logger.info(line_output)

    def _plot_confusion_matrix(self,
                               figsize: Tuple[float, float] = (9, 9),
                               precision: bool = False):
        """Plots normalized confusion matrix."""
        assert self.mode_ipython, 'Plotting is only possible when in ipython-mode'

        if self.n_classes > MAX_N_CLASSES:
            self.logger.info(
                '\nPlotting only for max {} classes\n'.format(MAX_N_CLASSES))
            return

        (title, xlabel, ylabel) = \
            ('Confusion matrix (precision)', 'True label', 'Predicted label') if precision \
            else ('Confusion matrix (recall)', 'Predicted label', 'True label')

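        # With precision=True, y_true and y_pred swap roles, so the row
        # normalization below yields per-predicted-label precision instead
        # of per-true-label recall.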
        cm = confusion_matrix(y_true=self.y_pred, y_pred=self.y_true) if precision \
            else confusion_matrix(y_true=self.y_true, y_pred=self.y_pred)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        tick_marks = np.arange(self.n_classes)

        fig = plt.figure(figsize=figsize)
        plt.rcParams["axes.grid"] = False
        plt.imshow(cm,
                   interpolation='nearest',
                   cmap=plt.cm.Blues,
                   vmin=0,
                   vmax=1)
        plt.colorbar()
        plt.title(title, fontsize=self.fontsize_title)
        plt.xlabel(xlabel, fontsize=self.fontsize_label)
        plt.ylabel(ylabel, fontsize=self.fontsize_label)
        plt.xticks(tick_marks,
                   self.classes,
                   rotation=45,
                   fontsize=self.fontsize_ticks,
                   ha='right')
        plt.yticks(tick_marks, self.classes, fontsize=self.fontsize_ticks)

        thresh = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(
                j,
                i,
                '{:.2f}'.format(cm[i, j]),
                horizontalalignment='center',
                color='white' if cm[i, j] > thresh else 'black',
                fontsize=self.fontsize_ticks,
            )

        plt.tight_layout()
        plt.show()

    def _print_confusion_matrix(self, precision: bool = False):
        """Prints normalized confusion matrix."""
        assert not self.mode_ipython, 'Printing is only intended when not in ipython-mode'

        cm = confusion_matrix(y_true=self.y_pred, y_pred=self.y_true) if precision \
            else confusion_matrix(y_true=self.y_true, y_pred=self.y_pred)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        max_length = len(max(self.classes, key=len))
        for i, c in enumerate(self.classes):
            label = c + ' ' * (max_length - len(c))
            line_output = "{}\t".format(label)
            for x in cm[i].tolist():
                line_output += "{0:.2f}\t".format(x)
            self.logger.info(line_output)

    def _plot_correct_wrong_examples(self):
        """Plots correct and wrong examples for each label in test set."""
        assert self.mode_ipython, 'Plotting is only possible when in ipython-mode'

        if self.n_classes > MAX_N_CLASSES:
            self.logger.info(
                '\nPlotting only for max {} classes\n'.format(MAX_N_CLASSES))
            return

        for i in range(len(self.classes)):
            c, w = self.get_correct_wrong_examples(label=i)
            self.visualize_images(
                c,
                title='Label: "{}" (correctly predicted)'.format(
                    self.classes[i]),
                show_heatmap=True,
                n_plot=3)
            self.visualize_images(w,
                                  title='Label: "{}" (wrongly predicted)'.format(
                                      self.classes[i]),
                                  show_heatmap=True,
                                  n_plot=3)

    def _create_report(self, report_kernel_name: str, report_export_html: bool,
                       report_export_pdf: bool):
        """Creates report from notebook-template and stores it in different formats all figures.

            - Jupyter Notebook
            - HTML
            - PDF
        """
        assert not self.mode_ipython, 'Create report is only possible when not in ipython mode'

        filepath_template = dirname(
            imageatm.notebooks.__file__) + '/evaluation_template.ipynb'
        filepath_notebook = self.evaluation_dir / 'evaluation_report.ipynb'
        filepath_html = self.evaluation_dir / 'evaluation_report.html'
        filepath_pdf = self.evaluation_dir / 'evaluation_report.pdf'

        pm.execute_notebook(str(filepath_template),
                            str(filepath_notebook),
                            parameters=dict(image_dir=str(self.image_dir),
                                            job_dir=str(self.job_dir)),
                            kernel_name=report_kernel_name)

        with open(filepath_notebook) as f:
            nb = nbformat.read(f, as_version=4)

        if report_export_html:
            self.logger.info('\n****** Create HTML ******\n')

            html_exporter = HTMLExporter()
            html_data, resources = html_exporter.from_notebook_node(nb)

            with open(filepath_html, 'w') as f:
                f.write(html_data)

        if report_export_pdf:
            self.logger.info('\n****** Create PDF ******\n')

            pdf_exporter = PDFExporter()
            pdf_exporter.template_file = dirname(
                imageatm.notebooks.__file__
            ) + '/tex_templates/evaluation_report.tplx'
            pdf_data, resources = pdf_exporter.from_notebook_node(
                nb, resources={'metadata': {
                    'name': 'Evaluation Report'
                }})

            with open(filepath_pdf, 'wb') as f:
                f.write(pdf_data)

    # TO-DO: Enforce string or integer but not both at the same time
    def get_correct_wrong_examples(
            self,
            label: Union[int, str]) -> Tuple[TYPE_IMAGE_LIST, TYPE_IMAGE_LIST]:
        """Gets correctly and wrongly predicted samples for a given label.

        Args:
            label: int or str (label for which the predictions should be considered).

        Returns:
            (correct, wrong): Tuple of two image lists.
        """
        correct = []
        wrong = []

        if isinstance(label, str):
            class_mapping_inv = {v: k for k, v in self.class_mapping.items()}
            label = int(class_mapping_inv[label])

        for i, sample in enumerate(self.samples_test):
            if self.y_true[i] == label:
                image_file = self.image_dir / sample['image_id']
                if self.y_true[i] == self.y_pred[i]:
                    correct.append([
                        i,
                        load_image(image_file, target_size=(224, 224)), sample
                    ])
                else:
                    wrong.append([
                        i,
                        load_image(image_file, target_size=(224, 224)), sample
                    ])

        return correct, wrong

    def visualize_images(self,
                         image_list: TYPE_IMAGE_LIST,
                         title: str = 'Images for visualisation',
                         show_heatmap: bool = False,
                         n_plot: int = 20):
        """Visualizes images in a sample list.

        Args:
            image_list: sample list.
            title: figure title (default 'Images for visualisation').
            show_heatmap: boolean (generates a gradient based class activation map (grad-CAM), default False).
            n_plot: maximum number of plots to be shown (default 20).
        """
        assert self.mode_ipython, 'Plotting is only possible when in ipython-mode'

        if len(image_list) == 0:
            print('Empty list.')
            return
        else:
            n_rows = min(n_plot, len(image_list))
            n_cols = 2 if show_heatmap else 1

            figsize = [5 * n_cols, 4 * n_rows]
            fig = plt.figure(figsize=figsize)
            fig.suptitle(title, fontsize=self.fontsize_title)

            plot_count = 1
            for (i, img, sample) in image_list[:n_rows]:
                plt.subplot(n_rows, n_cols, plot_count)
                plt.imshow(img)
                plt.axis('off')
                plt.title('true: {}, predicted: {} ({})'.format(
                    self.class_mapping[str(self.y_true[i])],
                    self.class_mapping[str(self.y_pred[i])],
                    str(round(self.y_pred_prob[i], 2)),
                ))
                plot_count += 1

                if show_heatmap is True:
                    heatmap = visualize_cam(
                        model=self.classifier.model,
                        layer_idx=89,
                        filter_indices=[self.y_pred[i]],
                        seed_input=self.classifier.get_preprocess_input()(
                            np.array(img).astype(np.float32)),
                    )
                    plt.subplot(n_rows, n_cols, plot_count)
                    plt.imshow(img)
                    plt.imshow(heatmap, alpha=0.7)
                    plt.axis('off')
                    plot_count += 1

            plt.show()

    def run(self,
            report_create: bool = False,
            report_kernel_name: str = 'imageatm',
            report_export_html: bool = False,
            report_export_pdf: bool = False):
        """Runs evaluation pipeline on the best model found in job directory for the specific test set:

            - Makes prediction on test set
            - Plots test set distribution
            - Plots classification report (accuracy, precision, recall)
            - Plots confusion matrix (on precsion and on recall)
            - Plots correct and wrong examples

           If not in ipython mode an evaluation report is created.

        Args:
        	report_create: boolean (create ipython kernel)
            report_kernel_name: str (name of ipython kernel)
            report_export_html: boolean (exports report to html).
            report_export_pdf: boolean (exports report to pdf).
        """
        if self.mode_ipython:
            self.logger.info('\n****** Make prediction on test set ******\n')
            self._make_prediction_on_test_set()

            self.logger.info('\n****** Plot distribution on test set ******\n')
            self._plot_test_set_distribution(figsize=[8, 5])

            self.logger.info('\n****** Plot classification report ******\n')
            # self._plot_classification_report(figsize=[4 + self.n_classes*0.5, 4 + self.n_classes*0.5])
            self._print_classification_report()

            self.logger.info(
                '\n****** Plot confusion matrix (recall) ******\n')
            self._plot_confusion_matrix(
                figsize=[4 + self.n_classes * 0.5, 4 + self.n_classes * 0.5])

            self.logger.info(
                '\n****** Plot confusion matrix (precision) ******\n')
            self._plot_confusion_matrix(
                figsize=[4 + self.n_classes * 0.5, 4 + self.n_classes * 0.5],
                precision=True)

            self.logger.info(
                '\n****** Plot correct and wrong examples ******\n')
            self._plot_correct_wrong_examples()

        elif report_create:
            self.logger.info(
                '\n****** Create Jupyter Notebook (this may take a while) ******\n'
            )
            self._create_report(report_kernel_name, report_export_html,
                                report_export_pdf)

        else:
            self.logger.info('\n****** Make prediction on test set ******\n')
            self._make_prediction_on_test_set()

            self.logger.info(
                '\n****** Print distribution on test set ******\n')
            self._print_test_set_distribution()

            self.logger.info('\n****** Print classification report ******\n')
            self._print_classification_report()

            self.logger.info(
                '\n****** Print confusion matrix (recall) ******\n')
            self._print_confusion_matrix()

            self.logger.info(
                '\n****** Print confusion matrix (precision) ******\n')
            self._print_confusion_matrix(precision=True)
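
A sketch of the non-interactive report path of Evaluation.run() above; the paths are hypothetical, and the named ipython kernel must be installed for papermill to execute the notebook template.

evaluation = Evaluation(image_dir='data/images', job_dir='jobs/job_1')

# Outside ipython: execute the notebook template and export the report.
evaluation.run(report_create=True,
               report_kernel_name='imageatm',
               report_export_html=True,
               report_export_pdf=False)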