def train(model_name: str, data: CATDOGDataset):
    """Build, compile and fit a fresh model, checkpointing the best epoch.

    Args:
        model_name: Used as the checkpoint path under ``MODELS_DIR`` and, with
            a ``model_`` prefix, as the TensorBoard log directory under
            ``LOGS_DIR``.
        data: Dataset wrapper supplying the train/validation splits plus
            ``img_shape``, ``num_classes`` and ``batch_size``.

    Returns:
        The value returned by ``model.fit`` (a Keras ``History``), so callers
        can plot loss/accuracy curves.
    """
    train_dataset = data.get_train_set()
    val_dataset = data.get_val_set()

    model = build_model(data.img_shape, data.num_classes)

    model.compile(loss=categorical_crossentropy,
                  optimizer=Adam(learning_rate=LEARNING_RATE),
                  metrics=["accuracy"])

    model.summary()

    model_log_dir = os.path.join(LOGS_DIR, f"model_{model_name}")

    tb_callback = TensorBoard(log_dir=model_log_dir,
                              histogram_freq=0,
                              profile_batch=0,
                              write_graph=False)

    # Only overwrite the saved model when val_loss improves.
    mc_callback = ModelCheckpoint(os.path.join(MODELS_DIR, model_name),
                                  monitor='val_loss',
                                  mode='auto',
                                  save_best_only=True,
                                  save_weights_only=False,
                                  verbose=1)

    # Stop training after 5 epochs without val_loss improvement.
    es_callback = EarlyStopping(
        monitor='val_loss',  # val_accuracy is worth trying as well
        mode='auto',
        patience=5,
        verbose=1)

    # Scale the learning rate by 0.9 after 3 epochs without val_accuracy
    # improvement.
    reduce_lr_callback = ReduceLROnPlateau(
        monitor='val_accuracy',  # val_loss is worth trying as well
        mode='auto',
        factor=0.9,
        patience=3,
        cooldown=0,
        min_lr=0,
        verbose=1)

    # NOTE(review): Keras advises against passing batch_size when the inputs
    # are already-batched datasets/generators; confirm what get_train_set()
    # returns.
    history = model.fit(
        train_dataset,
        epochs=EPOCHS,
        batch_size=data.batch_size,
        verbose=1,
        validation_data=val_dataset,
        # es_callback must be registered here, otherwise the early-stopping
        # configuration above has no effect.
        callbacks=[tb_callback, mc_callback, es_callback, reduce_lr_callback])

    return history
def evaluate(model_name: str, data: CATDOGDataset, history=None) -> None:
    """Load a saved model, report test metrics and optionally plot training curves.

    Args:
        model_name: Name of the saved model under ``MODELS_DIR``.
        data: Dataset wrapper providing the test split.
        history: Optional training history, either the object returned by
            ``model.fit`` (anything with a ``.history`` dict) or that dict
            itself. When omitted, the loss/accuracy curves are not plotted.
    """
    model_path = os.path.join(MODELS_DIR, model_name)

    model = load_model(model_path)

    test_dataset = data.get_test_set()

    y_pred = model.predict(test_dataset)
    plot_predicted_images(y_pred, test_dataset)

    score = model.evaluate(test_dataset)
    print(
        f'The Accuracy on the Test set is {score[1]:.2%} and the loss is {score[0]:.3f}'
    )

    # The training curves live only in the History produced by model.fit;
    # this function cannot recover them from the saved model.
    if history is None:
        return
    curves = getattr(history, "history", history)

    # plot loss
    plt.subplot(121)
    plt.title('Loss')
    plt.plot(curves['loss'], label='train')
    # val_* series come from the validation split passed to fit, not the test set.
    plt.plot(curves['val_loss'], label='validation')
    plt.grid()
    plt.legend()
    # plot accuracy
    plt.subplot(122)
    plt.title('Accuracy')
    plt.plot(curves['accuracy'], label='train')
    plt.plot(curves['val_accuracy'], label='validation')
    plt.grid()
    plt.legend()
    plt.show()


if __name__ == "__main__":
    data = CATDOGDataset()

    # Train through train(), which builds the TensorBoard / checkpoint /
    # early-stopping / LR callbacks in its own scope. Skip training when a
    # checkpoint with this name already exists so it can be reused for
    # evaluation.
    model_name = "catvsdog_4block_load_model"

    if not os.path.exists(os.path.join(MODELS_DIR, model_name)):
        train(model_name, data)


def evaluate(model_name: str, data: CATDOGDataset) -> None:
    """Report how a previously saved model performs on the test split.

    Loads ``MODELS_DIR/<model_name>``, plots its predictions on the test
    images, then prints the test accuracy and loss.

    Args:
        model_name: Name of the saved model under ``MODELS_DIR``.
        data: Dataset wrapper providing the test split.
    """
    trained = load_model(os.path.join(MODELS_DIR, model_name))
    test_split = data.get_test_set()

    # Visualise individual predictions before computing aggregate metrics.
    predictions = trained.predict(test_split)
    plot_predicted_images(predictions, test_split)

    # evaluate() returns [loss, accuracy] for the compiled "accuracy" metric.
    metrics = trained.evaluate(test_split)
    loss, accuracy = metrics[0], metrics[1]
    print(f'The Accuracy on the Test set is {accuracy:.2%} and the loss is {loss:.3f}')


if __name__ == "__main__":
    data = CATDOGDataset()

    # Plain string: the name has no interpolated fields.
    model_name = "catvsdog_4block_load_model"
    model_path = os.path.join(MODELS_DIR, model_name)

    # Reuse an existing checkpoint instead of retraining from scratch.
    if not os.path.exists(model_path):
        train(model_name, data)

    evaluate(model_name, data)