def get_area_size_counts_multiple_thresholds(resmaps_val_uint8,
                                             th_min=128,
                                             th_max=255,
                                             binning=True,
                                             nbins=200):
    """Count connected-region area sizes for every threshold in a range.

    For each integer threshold in ``[th_min, th_max]`` the residual maps are
    thresholded with ``threshold_images``, regions are extracted with
    ``label_images`` and the areas of all images are pooled and counted.

    Args:
        resmaps_val_uint8: residual maps handed to ``threshold_images``
            (presumably uint8 validation resmaps -- confirm against caller).
        th_min: smallest threshold; ``get_max_area`` is evaluated at this
            threshold to fix the area axis shared by all thresholds.
        th_max: largest threshold (inclusive).
        binning: if truthy, areas are histogrammed into ``nbins`` equal-width
            bins over ``(0.5, max_area + 0.5)``. Otherwise the thresholded
            maps are additionally passed through ``filter_median_images`` and
            one count per exact area ``1..max_area`` is produced.
        nbins: number of histogram bins (only used with ``binning``).

    Returns:
        tuple: ``(counts, edges)`` with binning, where ``counts`` has shape
        ``(n_thresholds, nbins)`` and ``edges`` holds the ``nbins + 1`` bin
        edges; ``(counts, areas)`` without binning, where ``counts`` has
        shape ``(n_thresholds, max_area)`` and ``areas`` is
        ``np.arange(1, max_area + 1)``.
    """
    from collections import Counter

    max_area = get_max_area(resmaps_val_uint8, th_min)
    thresholds = np.arange(th_min, th_max + 1)

    if binning:
        range_area = (0.5, float(max_area) + 0.5)
        counts = np.zeros(shape=(len(thresholds), nbins))

        # Compute the edges once up front so the return value is defined even
        # when the threshold range is empty (th_min > th_max).
        _, edges = np.histogram([], bins=nbins, range=range_area)

        for i, threshold in enumerate(thresholds):
            resmaps_th = threshold_images(resmaps_val_uint8, threshold)
            _, areas_all = label_images(resmaps_th)
            areas_1d = [area for areas in areas_all for area in areas]
            counts[i], edges = np.histogram(areas_1d,
                                            bins=nbins,
                                            density=False,
                                            range=range_area)
        return counts, edges

    areas = np.arange(1, max_area + 1)
    counts = np.zeros(shape=(len(thresholds), max_area))

    for i, threshold in enumerate(thresholds):
        resmaps_th = threshold_images(resmaps_val_uint8, threshold)
        resmaps_th_fil = filter_median_images(resmaps_th)
        _, areas_all = label_images(resmaps_th_fil)
        # One pass over all areas instead of one list.count() scan per area
        # value; a Counter returns 0 for areas that never occur.
        occurrences = Counter(area for areas in areas_all for area in areas)
        counts[i] = [occurrences[area] for area in areas]

    return counts, areas
def get_max_area(resmaps, threshold):
    """Return the largest region area found in ``resmaps`` at ``threshold``.

    The residual maps are thresholded with ``threshold_images`` and labeled
    with ``label_images``; the areas of all images are pooled.

    Args:
        resmaps: residual maps handed to ``threshold_images``.
        threshold: threshold value handed to ``threshold_images``.

    Returns:
        The maximum area over all regions of all images, or ``0`` when no
        region is found at this threshold.
    """
    resmaps_th = threshold_images(resmaps, threshold)
    _, areas_all = label_images(resmaps_th)
    areas_all_1d = [item for sublist in areas_all for item in sublist]

    # np.amax raises ValueError on an empty array, which happens as soon as
    # no region survives the threshold.
    if not areas_all_1d:
        return 0
    return np.amax(np.array(areas_all_1d))
def plot_area_distro_for_multiple_thresholds(resmaps,
                                             thresholds_to_plot,
                                             th_min,
                                             nbins=200,
                                             method="line",
                                             title="provide a title"):
    """Plot the distribution of region area sizes for several thresholds.

    Args:
        resmaps: residual maps handed to ``threshold_images``.
        thresholds_to_plot: iterable of thresholds, one curve/histogram each.
        th_min: threshold at which ``get_max_area`` is evaluated to fix the
            common x-range of all distributions.
        nbins: number of histogram bins.
        method: ``"hist"`` draws histograms, ``"line"`` draws the bin counts
            as a line; any other value draws nothing for the thresholds.
        title: figure title.
    """
    # Use one fixed range so that the counts of all thresholds share the same
    # x values.
    largest = get_max_area(resmaps, th_min)
    x_range = (0.5, float(largest) + 0.5)

    plt.figure(figsize=(12, 5))

    for th in thresholds_to_plot:
        binary_maps = threshold_images(resmaps, th)
        _, areas_per_image = label_images(binary_maps)

        # pool the areas of all images into one flat list
        flat_areas = []
        for image_areas in areas_per_image:
            flat_areas.extend(image_areas)

        if method == "hist":
            plt.hist(
                flat_areas,
                bins=nbins,
                density=False,
                range=x_range,
                histtype="barstacked",
                label="threshold = {}".format(th),
            )
            continue

        if method != "line":
            continue

        freq, edges = np.histogram(flat_areas,
                                   bins=nbins,
                                   density=False,
                                   range=x_range)
        # shift the left edges by half a bin width to get the bin centers
        half_width = (edges[1] - edges[0]) / 2
        centers = edges[:-1] + half_width
        plt.plot(
            centers,
            freq,
            linestyle="-",
            linewidth=0.5,
            marker="x",
            markersize=0.5,
            label="threshold = {}".format(th),
        )

    plt.title(title)
    plt.legend()
    plt.xlabel("area size in pixel")
    plt.ylabel("count")
    plt.grid()
    plt.show()
Esempio n. 4
0
def main(args):
    """Find, for each minimum area, a matching threshold on validation data.

    Loads the model at ``args.path``, reconstructs the validation images and
    computes residual maps with ``calculate_resmaps(..., loss="SSIM")``. If
    ``args.area`` is given, every integer minimum area in the requested range
    is paired with the first threshold (scanning upwards from the smallest
    residual-map value) at which the largest region in the thresholded maps
    is smaller than that minimum area. The pairs are printed and written to
    ``validation_results.txt`` and ``validation_results.pkl``.

    Args:
        args: parsed arguments providing ``path`` (model file), ``save``
            (whether to store inputs, predictions and resmaps as .npy) and
            ``area`` (``None``, or one or two area values spanning the range
            of minimum areas).

    Raises:
        ValueError: if ``args.area`` does not contain one or two distinct
            values, or if the model architecture is not supported.
        SystemExit: after the optional save step when ``args.area`` is None.
    """
    import sys  # local import: only needed for the early exit below

    # Get validation arguments
    model_path = args.path
    save = args.save

    min_areas = None
    if args.area is not None:
        area_list = sorted(set(args.area))
        if not 1 <= len(area_list) <= 2:
            raise ValueError(
                "Exactly two area values must be passed: area_min and area_max"
            )
        # area_list[-1] (instead of [1]) lets a single distinct value, e.g.
        # area_min == area_max, produce a one-element range.
        min_areas = list(np.arange(area_list[0], area_list[-1] + 1))

    # ========================= SETUP ==============================

    # load model and setup (the training history is not needed here)
    model, setup, _ = utils.load_model_HDF5(model_path)

    # data setup
    directory = setup["data_setup"]["directory"]
    # validation images are taken from the training directory through the
    # generator's validation_split / subset="validation"
    val_data_dir = os.path.join(directory, "train")
    nb_validation_images = setup["data_setup"]["nb_validation_images"]

    # preprocessing_setup
    rescale = setup["preprocessing_setup"]["rescale"]
    shape = setup["preprocessing_setup"]["shape"]

    # train_setup
    color_mode = setup["train_setup"]["color_mode"]
    validation_split = setup["train_setup"]["validation_split"]
    architecture = setup["train_setup"]["architecture"]
    loss = setup["train_setup"]["loss"]

    # create a results directory if not existent
    model_dir_name = os.path.basename(str(Path(model_path).parent))
    save_dir = os.path.join(
        os.getcwd(),
        "results",
        directory,
        architecture,
        loss,
        model_dir_name,
        "validation",
    )
    os.makedirs(save_dir, exist_ok=True)

    # ============================= PREPROCESSING ===============================

    # select the preprocessing function matching the architecture
    if architecture in ["mvtec", "mvtec2"]:
        preprocessing_function = None
    elif architecture == "resnet":
        preprocessing_function = keras.applications.inception_resnet_v2.preprocess_input
    elif architecture == "nasnet":
        preprocessing_function = keras.applications.nasnet.preprocess_input
    else:
        # fail with a clear message instead of an unbound-variable error
        raise ValueError("Unknown architecture: {}".format(architecture))

    # same preprocessing as in training
    validation_datagen = ImageDataGenerator(
        rescale=rescale,
        data_format="channels_last",
        validation_split=validation_split,
        preprocessing_function=preprocessing_function,
    )

    # retrieve preprocessed validation images as a numpy array; the batch
    # size equals the number of validation images, so one batch holds all
    validation_generator = validation_datagen.flow_from_directory(
        directory=val_data_dir,
        target_size=shape,
        color_mode=color_mode,
        batch_size=nb_validation_images,
        shuffle=False,
        class_mode="input",
        subset="validation",
    )
    imgs_val_input = validation_generator.next()[0]

    # get reconstructed images (i.e predictions) on validation dataset
    imgs_val_pred = model.predict(imgs_val_input)

    # compute residual maps on validation dataset
    # NOTE(review): the loss is hard-coded to "SSIM" here regardless of the
    # loss stored in the setup -- confirm this is intended.
    resmaps_val = calculate_resmaps(imgs_val_input, imgs_val_pred, loss="SSIM")

    if color_mode == "rgb":
        resmaps_val = tf.image.rgb_to_grayscale(resmaps_val)

    if save:
        utils.save_np(imgs_val_input, save_dir, "imgs_val_input.npy")
        utils.save_np(imgs_val_pred, save_dir, "imgs_val_pred.npy")
        utils.save_np(resmaps_val, save_dir, "resmaps_val.npy")

    if min_areas is None:
        print("[INFO] exiting")
        sys.exit()

    # Convert to 8-bit unsigned int
    resmaps_val = img_as_ubyte(resmaps_val)

    # ========================= VALIDATION ALGORITHM ==============================

    # initialize validation dictionary
    dict_val = {"min_area": [], "threshold": []}

    # Set threshold boundaries. Cast to Python ints: with NumPy >= 2 adding 1
    # to a uint8 scalar of 255 wraps around to 0, which would make the
    # threshold range below empty.
    threshold_min = int(np.amin(resmaps_val))
    threshold_max = int(np.amax(resmaps_val))

    # initialize progress bar
    nb_min_areas = len(min_areas)
    printProgressBar(0,
                     nb_min_areas,
                     prefix="Progress:",
                     suffix="Complete",
                     length=50)

    # The largest region at a given threshold does not depend on min_area, so
    # it is computed once per threshold and reused for all minimum areas.
    largest_area_at = {}

    # loop over all min_areas and compute corresponding thresholds
    for i, min_area in enumerate(min_areas):

        for threshold in range(threshold_min, threshold_max + 1):
            if threshold not in largest_area_at:
                # threshold residual maps and compute connected components
                resmaps_th = threshold_images(resmaps_val, threshold)
                _, areas_all = label_images(resmaps_th)
                # default=0 covers thresholds at which no region remains
                largest_area_at[threshold] = max(
                    (area for areas in areas_all for area in areas),
                    default=0)

            # stop at the first threshold whose largest anomalous region is
            # below the minimum area
            if min_area > largest_area_at[threshold]:
                break
        # NOTE: if no threshold satisfies the condition, `threshold` is left
        # at threshold_max and recorded below.

        # print progress bar
        time.sleep(0.1)
        printProgressBar(i + 1,
                         nb_min_areas,
                         prefix="Progress:",
                         suffix="Complete",
                         length=50)

        # append min_area and corresponding threshold to validation dictionary
        dict_val["min_area"].append(min_area)
        dict_val["threshold"].append(threshold)

    # print validation DataFrame to console
    df_val = pd.DataFrame.from_dict(dict_val)
    with pd.option_context("display.max_rows", None, "display.max_columns",
                           None):
        print(df_val)

    # save validation DataFrame as .txt and .pkl files
    with open(os.path.join(save_dir, "validation_results.txt"), "w+") as f:
        f.write(df_val.to_string(header=True, index=True))
        print("validation_results.txt saved at {}".format(save_dir))
    df_val.to_pickle(os.path.join(save_dir, "validation_results.pkl"))
    print("validation_results.pkl saved at {}".format(save_dir))
Esempio n. 5
0
def main(args):
    """Classify the test images and report TPR/TNR per (min_area, threshold).

    Loads the model at ``args.path``, reconstructs the test images, computes
    residual maps with ``calculate_resmaps(..., loss="SSIM")`` and, for every
    (min_area, threshold) pair, thresholds and labels the residual maps and
    classifies each image with ``classify``. Per-image classifications are
    appended to ``classification.txt``; the true positive and true negative
    rates are printed and written to ``test_results.txt`` / ``.pkl``.

    The pairs come either from the ``validation_results.pkl`` written by the
    validation step (when neither ``args.threshold`` nor ``args.area`` is
    given) or from all combinations of the passed areas and thresholds.

    Args:
        args: parsed arguments providing ``path``, ``save``, ``threshold``
            and ``area`` (the latter two both ``None`` or both iterables).

    Raises:
        ValueError: if only one of threshold/area is passed, or if the model
            architecture is not supported.
    """
    # parse arguments
    model_path = args.path
    save = args.save

    if args.threshold is None and args.area is None:
        adopt_validation = True
        print("[INFO] Testing with validation areas and thresholds.")
    elif args.threshold is not None and args.area is not None:
        adopt_validation = False
        print("[INFO] Testing with user passed areas and thresholds")
    else:
        raise ValueError("Threshold and area must both be passed.")

    # ========================= SETUP ==============================

    # load model and setup (the training history is not needed here)
    model, setup, _ = utils.load_model_HDF5(model_path)

    # data setup
    directory = setup["data_setup"]["directory"]
    test_data_dir = os.path.join(directory, "test")

    # preprocessing_setup
    rescale = setup["preprocessing_setup"]["rescale"]
    shape = setup["preprocessing_setup"]["shape"]

    # train_setup
    color_mode = setup["train_setup"]["color_mode"]
    architecture = setup["train_setup"]["architecture"]
    loss = setup["train_setup"]["loss"]

    # create directory to save test results
    model_dir_name = os.path.basename(str(Path(model_path).parent))
    save_dir = os.path.join(
        os.getcwd(),
        "results",
        directory,
        architecture,
        loss,
        model_dir_name,
        "test",
    )
    os.makedirs(save_dir, exist_ok=True)

    if adopt_validation:
        # load areas and corresponding thresholds written by the validation
        # step (sibling "validation" directory of save_dir)
        parent_dir = str(Path(save_dir).parent)
        val_dir = os.path.join(parent_dir, "validation")
        # NOTE: unpickling executes arbitrary code; only load result files
        # produced by this project.
        df_val = pd.read_pickle(os.path.join(val_dir,
                                             "validation_results.pkl"))
        dict_val = df_val.to_dict(orient="list")
        # create a list containing (min_area, threshold) pairs
        elems = list(zip(dict_val["min_area"], dict_val["threshold"]))
    else:
        # compute combinations of given areas and thresholds
        areas = sorted(set(args.area))
        thresholds = sorted(set(args.threshold))
        # create a list containing (min_area, threshold) pairs
        elems = [(area, threshold) for threshold in thresholds
                 for area in areas]

    # start a fresh classification file; results are appended per pair below
    with open(os.path.join(save_dir, "classification.txt"), "w") as f:
        f.write("Classification of image files.\n\n")

    # ============================= PREPROCESSING ===============================

    # select the preprocessing function matching the architecture
    if architecture in ["mvtec", "mvtec2"]:
        preprocessing_function = None
    elif architecture == "resnet":
        preprocessing_function = keras.applications.inception_resnet_v2.preprocess_input
    elif architecture == "nasnet":
        preprocessing_function = keras.applications.nasnet.preprocess_input
    else:
        # fail with a clear message instead of an unbound-variable error
        raise ValueError("Unknown architecture: {}".format(architecture))

    test_datagen = ImageDataGenerator(
        rescale=rescale,
        data_format="channels_last",
        preprocessing_function=preprocessing_function,
    )

    total_number = utils.get_total_number_test_images(test_data_dir)

    # retrieve preprocessed test images as a numpy array; the batch size
    # equals the number of test images, so one batch holds all of them
    test_generator = test_datagen.flow_from_directory(
        directory=test_data_dir,
        target_size=shape,
        color_mode=color_mode,
        batch_size=total_number,
        shuffle=False,
        class_mode="input",
    )
    imgs_test_input = test_generator.next()[0]

    # retrieve test image names
    filenames = test_generator.filenames

    # predict on test images
    imgs_test_pred = model.predict(imgs_test_input)

    # calculate residual maps on test set
    resmaps_test = calculate_resmaps(imgs_test_input,
                                     imgs_test_pred,
                                     loss="SSIM")

    if color_mode == "rgb":
        resmaps_test = tf.image.rgb_to_grayscale(resmaps_test)

    if save:
        utils.save_np(imgs_test_input, save_dir, "imgs_test_input.npy")
        utils.save_np(imgs_test_pred, save_dir, "imgs_test_pred.npy")
        utils.save_np(resmaps_test, save_dir, "resmaps_test.npy")

    # Convert to 8-bit unsigned int
    resmaps_test = img_as_ubyte(resmaps_test)

    # =================== Classification Algorithm ==========================

    # Ground truth: an image is defect-free (0) when one of its path
    # components is "good", anomalous (1) otherwise. Path.parts is used
    # instead of splitting on "/" so the check also works with the platform's
    # own path separator. It does not depend on the pair being evaluated, so
    # it is computed once.
    y_true = [
        0 if "good" in Path(filename).parts else 1 for filename in filenames
    ]

    # condition positive (P) and condition negative (N)
    P = y_true.count(1)
    N = y_true.count(0)

    # initialize dictionary to store test results
    dict_test = {"min_area": [], "threshold": [], "TPR": [], "TNR": []}

    # initialize progress bar
    nb_elems = len(elems)
    printProgressBar(0,
                     nb_elems,
                     prefix="Progress:",
                     suffix="Complete",
                     length=50)

    # classify test images for all (min_area, threshold) pairs
    for i, (min_area, threshold) in enumerate(elems):
        # threshold residual maps with the given threshold
        resmaps_th = threshold_images(resmaps_test, threshold)

        # compute connected components
        _, areas_all = label_images(resmaps_th)

        # classify images
        y_pred = classify(areas_all, min_area)

        # save classification of image files in a .txt file
        classification = {
            "filenames": filenames,
            "predictions": y_pred,
            "truth": y_true,
        }
        df_clf = pd.DataFrame.from_dict(classification)
        with open(os.path.join(save_dir, "classification.txt"), "a") as f:
            f.write("min_area = {}, threshold = {}, index = {}\n\n".format(
                min_area, threshold, i))
            f.write(df_clf.to_string(header=True, index=True))
            f.write("\n" + "_" * 50 + "\n\n")

        # true positives (TP) and true negatives (TN)
        TP = sum(1 for pred, true in zip(y_pred, y_true) if pred == true == 1)
        TN = sum(1 for pred, true in zip(y_pred, y_true) if pred == true == 0)

        # sensitivity, recall, hit rate, or true positive rate (TPR);
        # undefined (NaN) when the test set contains no anomalous image
        TPR = TP / P if P else float("nan")

        # specificity, selectivity or true negative rate (TNR);
        # undefined (NaN) when the test set contains no defect-free image
        TNR = TN / N if N else float("nan")

        # append test results to dictionary
        dict_test["min_area"].append(min_area)
        dict_test["threshold"].append(threshold)
        dict_test["TPR"].append(TPR)
        dict_test["TNR"].append(TNR)

        # print progress bar
        time.sleep(0.1)
        printProgressBar(i + 1,
                         nb_elems,
                         prefix="Progress:",
                         suffix="Complete",
                         length=50)

    # print test results to console
    df_test = pd.DataFrame.from_dict(dict_test)
    df_test.sort_values(by=["min_area", "threshold"], inplace=True)
    with pd.option_context("display.max_rows", None, "display.max_columns",
                           None):
        print(df_test)

    # save DataFrame (as .txt and .pkl)
    with open(os.path.join(save_dir, "test_results.txt"), "w+") as f:
        f.write(df_test.to_string(header=True, index=True))
        print("test_results.txt saved at {}".format(save_dir))

    df_test.to_pickle(os.path.join(save_dir, "test_results.pkl"))
    print("test_results.pkl saved at {}".format(save_dir))
Esempio n. 6
0
def main(args):
    """Collect region statistics over increasing thresholds for finetuning.

    Loads the model at ``args.path``, reconstructs the validation images and
    computes residual maps as ``input - prediction``. For every threshold
    from an estimated lower bound up to the largest residual value, the maps
    are thresholded, median filtered and labeled; the number of regions and
    the mean, standard deviation and sum of their areas are recorded until a
    threshold yields no region. The table is printed and written to
    ``test_results_all.txt`` and a summary plot to ``plot_stat.pdf``.
    Optionally, one validation image (``args.val``) and one test image
    (``args.test``) are plotted next to their reconstruction and residual
    map.

    Args:
        args: parsed arguments providing ``path``, ``save``, ``val`` and
            ``test`` (file names as listed by the data generators, or None).
    """
    # Get finetuning arguments
    model_path = args.path
    save = args.save
    img_val = args.val
    img_test = args.test

    # ========================= SETUP ==============================

    # load model and setup (the training history is not needed here)
    model, setup, _ = utils.load_model_HDF5(model_path)

    # data setup
    directory = setup["data_setup"]["directory"]
    # validation images are taken from the training directory through the
    # generator's validation_split / subset="validation"
    val_data_dir = os.path.join(directory, "train")
    nb_validation_images = setup["data_setup"]["nb_validation_images"]

    # preprocessing_setup
    rescale = setup["preprocessing_setup"]["rescale"]
    shape = setup["preprocessing_setup"]["shape"]

    # train_setup
    color_mode = setup["train_setup"]["color_mode"]
    validation_split = setup["train_setup"]["validation_split"]
    architecture = setup["train_setup"]["architecture"]
    loss = setup["train_setup"]["loss"]

    # create directory to save results
    model_dir_name = os.path.basename(str(Path(model_path).parent))
    save_dir = os.path.join(
        os.getcwd(),
        "results",
        directory,
        architecture,
        loss,
        model_dir_name,
        "finetune",
    )
    os.makedirs(save_dir, exist_ok=True)

    # ============================= PREPROCESSING ===============================

    # get preprocessing function corresponding to model
    preprocessing_function = utils.get_preprocessing_function(architecture)

    # same preprocessing as in training
    validation_datagen = ImageDataGenerator(
        rescale=rescale,
        data_format="channels_last",
        validation_split=validation_split,
        preprocessing_function=preprocessing_function,
    )

    # retrieve preprocessed validation images as a numpy array; the batch
    # size equals the number of validation images, so one batch holds all
    validation_generator = validation_datagen.flow_from_directory(
        directory=val_data_dir,
        target_size=shape,
        color_mode=color_mode,
        batch_size=nb_validation_images,
        shuffle=False,
        class_mode="input",
        subset="validation",
    )
    imgs_val_input = validation_generator.next()[0]

    # ============== RECONSTRUCT IMAGES AND COMPUTE RESIDUAL MAPS ==============

    # get reconstructed images (i.e predictions) on validation dataset
    print("computing reconstructions of validation images...")
    imgs_val_pred = model.predict(imgs_val_input)

    # compute residual maps on validation dataset
    resmaps_val = imgs_val_input - imgs_val_pred

    if color_mode == "rgb":
        resmaps_val = tf.image.rgb_to_grayscale(resmaps_val)

    if save:
        utils.save_np(imgs_val_input, save_dir, "imgs_val_input.npy")
        utils.save_np(imgs_val_pred, save_dir, "imgs_val_pred.npy")
        utils.save_np(resmaps_val, save_dir, "resmaps_val.npy")

    # plot a sample validation image alongside its corresponding
    # reconstruction and resmap for inspection
    if img_val is not None:
        plt.style.use("default")
        # compute image index
        index_val = validation_generator.filenames.index(img_val)
        fig = utils.plot_input_pred_resmaps_val(imgs_val_input, imgs_val_pred,
                                                resmaps_val, index_val)
        fig.savefig(os.path.join(save_dir, "val_plots.png"))
        print("figure saved at {}".format(
            os.path.join(save_dir, "val_plots.png")))

    # scale pixel values linearly to [0,1]
    resmaps_val = scale_pixel_values(architecture, resmaps_val)

    # Convert to 8-bit unsigned int for further processing
    resmaps_val = img_as_ubyte(resmaps_val)

    # ======= NUMBER OF REGIONS, THEIR MEAN SIZE AND STD DEVIATION WITH INCREASING THRESHOLDS =======

    # compute descriptive values
    mu = resmaps_val.mean()
    sigma = resmaps_val.std()

    # Set relevant threshold interval: start just below the 97th percentile
    # of a normal distribution fitted to the pixel values.
    threshold_min = int(round(scipy.stats.norm(mu, sigma).ppf(0.97), 1)) - 1
    # Cast to a Python int: with NumPy >= 2 adding 1 to a uint8 scalar of 255
    # wraps around to 0, which would make the threshold range below empty.
    threshold_max = int(np.amax(resmaps_val))

    dict_out = {
        "threshold": [],
        "nb_regions": [],
        "mean_areas_size": [],
        "std_areas_size": [],
        "sum_areas_size": [],
    }

    # compute number of anomalous regions and their area sizes with
    # increasing thresholds
    print(
        "computing anomalous regions and area sizes with increasing thresholds..."
    )
    for threshold in range(threshold_min, threshold_max + 1):
        # threshold residual maps
        resmaps_th = threshold_images(resmaps_val, threshold)

        # filter images to remove salt noise
        resmaps_fil = filter_median_images(resmaps_th, kernel_size=3)

        # compute anomalous regions and their size for current threshold
        _, areas_all = label_images(resmaps_fil)
        areas_all_1d = [item for sublist in areas_all for item in sublist]

        # stop as soon as a threshold leaves no region
        nb_regions = len(areas_all_1d)
        if nb_regions == 0:
            break

        # append values to dictionary
        dict_out["threshold"].append(threshold)
        dict_out["nb_regions"].append(nb_regions)
        dict_out["mean_areas_size"].append(np.mean(areas_all_1d))
        dict_out["std_areas_size"].append(np.std(areas_all_1d))
        dict_out["sum_areas_size"].append(np.sum(areas_all_1d))
        print("threshold: {}".format(threshold))

    # The seaborn styles were renamed to "seaborn-v0_8-*" in newer matplotlib
    # releases; use whichever name this installation provides.
    darkgrid_style = "seaborn-v0_8-darkgrid"
    if darkgrid_style not in plt.style.available:
        darkgrid_style = "seaborn-darkgrid"
    plt.style.use(darkgrid_style)

    # print DataFrame to console
    df_out = pd.DataFrame.from_dict(dict_out)
    with pd.option_context("display.max_rows", None, "display.max_columns",
                           None):
        print(df_out)

    # save DataFrame as text, overwriting any previous results
    with open(os.path.join(save_dir, "test_results_all.txt"), "w") as f:
        f.write(df_out.to_string(header=True, index=True))

    # sum of region areas (left axis) and number of regions (right axis)
    # plotted against the threshold
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.plot(df_out.threshold, df_out.sum_areas_size, "#1f77b4")
    ax1.set_ylabel("sum of anomalous region's area size", color="#1f77b4")
    for tl in ax1.get_yticklabels():
        tl.set_color("#1f77b4")

    ax2 = ax1.twinx()
    ax2.plot(df_out.threshold, df_out.nb_regions, "#ff7f0e")
    ax2.set_ylabel("number of anomalous regions", color="#ff7f0e")
    for tl in ax2.get_yticklabels():
        tl.set_color("#ff7f0e")

    fig.savefig(os.path.join(save_dir, "plot_stat.pdf"))

    # plot a sample test image alongside its corresponding reconstruction and
    # resmap for inspection
    if img_test is not None:
        plt.style.use("default")
        test_data_dir = os.path.join(directory, "test")
        total_number = utils.get_total_number_test_images(test_data_dir)

        test_datagen = ImageDataGenerator(
            rescale=rescale,
            data_format="channels_last",
            preprocessing_function=preprocessing_function,
        )

        # retrieve preprocessed test images as a numpy array
        test_generator = test_datagen.flow_from_directory(
            directory=test_data_dir,
            target_size=shape,
            color_mode=color_mode,
            batch_size=total_number,
            shuffle=False,
            class_mode="input",
        )
        imgs_test_input = test_generator.next()[0]

        # predict on test images
        print("computing reconstructions of test images...")
        imgs_test_pred = model.predict(imgs_test_input)

        # compute residual maps on test set
        resmaps_test = imgs_test_input - imgs_test_pred

        if color_mode == "rgb":
            resmaps_test = tf.image.rgb_to_grayscale(resmaps_test)

        # compute image index
        index_test = test_generator.filenames.index(img_test)

        # save three images
        fig = utils.plot_input_pred_resmaps_test(imgs_test_input,
                                                 imgs_test_pred, resmaps_test,
                                                 index_test)
        fig.savefig(os.path.join(save_dir, "test_plots.png"))
        print("figure saved at {}".format(
            os.path.join(save_dir, "test_plots.png")))
def get_stats(resmaps, th_min=128, th_max=255, plot=False):
    """Compute region statistics of residual maps for increasing thresholds.

    For every integer threshold in ``[th_min, th_max]`` the residual maps are
    thresholded with ``threshold_images`` and labeled with ``label_images``.
    The number of regions and the mean, standard deviation and sum of their
    areas (pooled over all images) are recorded. The scan stops at the first
    threshold that yields no region.

    Args:
        resmaps: residual maps handed to ``threshold_images``.
        th_min: first threshold.
        th_max: last threshold (inclusive).
        plot: if True, plot the statistics against the threshold.

    Returns:
        pandas.DataFrame with the columns ``threshold``, ``nb_regions``,
        ``mean_areas_size``, ``std_areas_size`` and ``sum_areas_size``; it
        is empty if already ``th_min`` yields no region.
    """
    dict_stat = {
        "threshold": [],
        "nb_regions": [],
        "mean_areas_size": [],
        "std_areas_size": [],
        "sum_areas_size": [],
    }

    # compute number of anomalous regions and their area sizes with
    # increasing thresholds
    print(
        "computing anomalous regions and area sizes with increasing thresholds..."
    )
    for threshold in range(th_min, th_max + 1):
        # threshold residual maps
        resmaps_th = threshold_images(resmaps, threshold)

        # compute anomalous regions and their size for current threshold
        _, areas_all = label_images(resmaps_th)
        areas_all_1d = [item for sublist in areas_all for item in sublist]

        # stop as soon as a threshold leaves no region
        nb_regions = len(areas_all_1d)
        if nb_regions == 0:
            break

        # append values to dictionary
        dict_stat["threshold"].append(threshold)
        dict_stat["nb_regions"].append(nb_regions)
        dict_stat["mean_areas_size"].append(np.mean(areas_all_1d))
        dict_stat["std_areas_size"].append(np.std(areas_all_1d))
        dict_stat["sum_areas_size"].append(np.sum(areas_all_1d))

    df_stat = pd.DataFrame.from_dict(dict_stat)

    if plot:
        fig = plt.figure()
        # The seaborn styles were renamed to "seaborn-v0_8-*" in newer
        # matplotlib releases; use whichever name this installation provides.
        darkgrid_style = "seaborn-v0_8-darkgrid"
        if darkgrid_style not in plt.style.available:
            darkgrid_style = "seaborn-darkgrid"
        plt.style.use(darkgrid_style)

        # mean and standard deviation of the area sizes on the left axis
        ax1 = fig.add_subplot(111)
        lns1 = ax1.plot(df_stat.threshold,
                        df_stat.mean_areas_size,
                        "C0",
                        label="mean_areas_size")
        lns2 = ax1.plot(df_stat.threshold,
                        df_stat.std_areas_size,
                        "C1",
                        label="std_areas_size")
        ax1.set_xlabel("Thresholds")
        ax1.set_ylabel("areas size [number of pixels]")

        # number of regions on a second y-axis sharing the same x-axis
        ax2 = ax1.twinx()
        lns3 = ax2.plot(df_stat.threshold,
                        df_stat.nb_regions,
                        "C2",
                        label="nb_regions")
        ax2.set_ylabel("number of anomalous regions", color="C2")
        for tl in ax2.get_yticklabels():
            tl.set_color("C2")

        # one combined legend for the lines of both axes
        lns = lns1 + lns2 + lns3
        labs = [line.get_label() for line in lns]
        ax1.legend(lns, labs, loc=0)

        plt.show()
    return df_stat
# Exploratory script: visual inspection of residual maps and of the effect of
# thresholding.
# NOTE(review): imgs_test_input, imgs_test_pred, resmaps_test and resmaps_val
# are not defined in this part of the file -- presumably they are created
# earlier in the script; confirm before running.

# show one test sample: input, reconstruction and residual map, followed by
# the histogram of all test residual maps
index_test = 12
plot_img(imgs_test_input[index_test])
plot_img(imgs_test_pred[index_test])
plot_img(resmaps_test[index_test], title="resmaps_test")
hist_image(resmaps_test, title="resmaps_test")

# transform to uint8 ----------------------------------------------------------
# convert validation and test residual maps with img_as_ubyte and plot the
# histogram of each converted set
resmaps_val_uint8 = img_as_ubyte(resmaps_val)
hist_image_uint8(resmaps_val_uint8, title="resmaps_val_uint8")

resmaps_test_uint8 = img_as_ubyte(resmaps_test)
hist_image_uint8(resmaps_test_uint8, title="resmaps_test_uint8")

# investigate thresholding image with various thresholds----------------------
# validation residual maps thresholded at 210; show the first image
threshold = 210
resmaps_val_uint8_th = threshold_images(resmaps_val_uint8,
                                        threshold)  # resmaps_val_th_uint8
index_val = 0
plot_img(resmaps_val_uint8_th[index_val],
         title="resmaps_val_uint8_th[{}]\n threshold = {}".format(
             index_val, threshold))

# test residual maps thresholded at 151; show the same sample as above
threshold = 151
resmaps_test_uint8_th = threshold_images(resmaps_test_uint8, threshold)
index_test = 12
plot_img(resmaps_test_uint8_th[index_test],
         title="resmaps_test_uint8_th[{}]\n threshold = {}".format(
             index_test, threshold))

#================= AREA SIZE DISTRIBUTION FOR ALL THRESHOLDS ==================