Пример #1
0
def train(normal_digit, anomalies, folder, file, p_train, p_test):
    """Detect MNIST anomalies with the reconstruction error of a linear PCA.

    A training set of ``TRAIN_SIZE`` images is sampled from ``normal_digit``
    and contaminated with a fraction ``p_train`` of images whose label is in
    ``anomalies``; a test set of ``TEST_SIZE`` images is built the same way
    with contamination ``p_test``. A PCA with ``N_COMP`` components is fit
    on the (contaminated) training images and the L2 reconstruction error is
    used as anomaly score. Test scores are converted into the empirical
    fraction of training errors they exceed.

    Precision / recall / F1 / AUC (and average precision on test) are
    printed, logged to the comet experiment, and the test metrics are
    appended as one row to ``<folder>/results_<file>.csv``.

    Args:
        normal_digit: MNIST digit treated as the normal class.
        anomalies: label(s) used as contamination (anything ``numpy.isin``
            accepts as test elements).
        folder: directory where the results CSV is read/written.
        file: suffix used to build the results file name.
        p_train: contamination fraction of the training set.
        p_test: contamination fraction of the test set; also used as the
            significance level ``ALPHA``.
    """

    # Create an experiment
    experiment = Experiment(project_name="deep-stats-thesis",
                            workspace="stecaron",
                            disabled=True)
    experiment.add_tag("mnist_kpca")

    # General parameters
    DOWNLOAD_MNIST = True
    PATH_DATA = os.path.join(os.path.expanduser("~"), 'Downloads/mnist')

    # Define training parameters
    hyper_params = {
        "TRAIN_SIZE": 2000,
        "TRAIN_NOISE": p_train,
        "TEST_SIZE": 800,
        "TEST_NOISE": p_test,
        # on which class we want to learn outliers
        "CLASS_SELECTED": [normal_digit],
        # which class we want to corrupt our dataset with
        "CLASS_CORRUPTED": anomalies,
        "INPUT_DIM": 28 * 28,  # In the case of MNIST
        "ALPHA": p_test,  # level of significance for the test
        # hyperparameters gamma in rbf kPCA (unused by the linear PCA below)
        "GAMMA": [1],
        # number of principal components kept for the reconstruction
        "N_COMP": [30]
    }

    # Log experiment parameters
    experiment.log_parameters(hyper_params)

    # Load data
    train_data, test_data = load_mnist(PATH_DATA, download=DOWNLOAD_MNIST)

    # Normalize data
    train_data.data = train_data.data / 255.
    test_data.data = test_data.data / 255.

    # Build "train" and "test" datasets
    id_maj_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TRAIN_NOISE"]) * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_min_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TRAIN_NOISE"] * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_train = numpy.concatenate((id_maj_train, id_min_train))

    id_maj_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TEST_NOISE"]) * hyper_params["TEST_SIZE"]),
        replace=False)
    id_min_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TEST_NOISE"] * hyper_params["TEST_SIZE"]),
        replace=False)
    id_test = numpy.concatenate((id_min_test, id_maj_test))

    train_data.data = train_data.data[id_train]
    train_data.targets = train_data.targets[id_train]

    test_data.data = test_data.data[id_test]
    test_data.targets = test_data.targets[id_test]

    # Replace digit labels by a binary "is anomaly" flag.
    # NOTE(review): assumes train_labels / test_labels alias ``targets`` (as
    # in torchvision's MNIST) so they reflect the subsets taken above.
    train_data.targets = numpy.isin(train_data.train_labels,
                                    hyper_params["CLASS_CORRUPTED"])
    test_data.targets = numpy.isin(test_data.test_labels,
                                   hyper_params["CLASS_CORRUPTED"])

    # Flatten the data and transform to numpy array
    train_data.data = train_data.data.view(-1, 28 * 28).numpy()
    test_data.data = test_data.data.view(-1, 28 * 28).numpy()

    # Fit a PCA restricted to N_COMP components. With the default
    # n_components=None every component is kept, the reconstruction is exact
    # up to round-off and the reconstruction error carries no signal.
    pca = PCA(n_components=hyper_params["N_COMP"][0])
    pca.fit(train_data.data)
    X_train_back = pca.inverse_transform(pca.transform(train_data.data))
    X_test_back = pca.inverse_transform(pca.transform(test_data.data))

    # Anomaly score: L2 distance between each image and its reconstruction
    dist_train = numpy.linalg.norm(train_data.data - X_train_back,
                                   ord=2,
                                   axis=1)
    dist_test = numpy.linalg.norm(test_data.data - X_test_back, ord=2, axis=1)

    # Use the real sample counts: int() truncation of noise * size can make
    # them differ from TRAIN_SIZE / TEST_SIZE.
    n_train = dist_train.shape[0]

    # Test performances on train: flag the ALPHA fraction with largest error
    train_anomalies_ind = numpy.argsort(dist_train)[
        int((1 - hyper_params["ALPHA"]) * n_train):]
    train_predictions = numpy.zeros(n_train)
    train_predictions[train_anomalies_ind] = 1

    train_recall = metrics.recall_score(train_data.targets, train_predictions)
    train_precision = metrics.precision_score(train_data.targets,
                                              train_predictions)
    train_f1_score = metrics.f1_score(train_data.targets, train_predictions)
    # ROC AUC is defined on a continuous score, not on 0/1 predictions.
    train_auc = metrics.roc_auc_score(train_data.targets, dist_train)

    print(f"Train Precision: {train_precision}")
    print(f"Train Recall: {train_recall}")
    print(f"Train F1 Score: {train_f1_score}")
    print(f"Train AUC: {train_auc}")
    experiment.log_metric("train_precision", train_precision)
    experiment.log_metric("train_recall", train_recall)
    experiment.log_metric("train_f1_score", train_f1_score)
    experiment.log_metric("train_auc", train_auc)

    # Test performances on test.
    # test_probs[j] = fraction of training errors <= dist_test[j]; counting
    # with searchsorted on the sorted training errors avoids an
    # n_test x n_train Python loop.
    test_probs = numpy.searchsorted(numpy.sort(dist_train), dist_test,
                                    side="right") / n_train
    test_predictions = (
        test_probs >= 1 - hyper_params["ALPHA"]).astype(float)

    test_recall = metrics.recall_score(test_data.targets, test_predictions)
    test_precision = metrics.precision_score(test_data.targets,
                                             test_predictions)
    test_f1_score = metrics.f1_score(test_data.targets, test_predictions)
    test_auc = metrics.roc_auc_score(test_data.targets, test_probs)
    # Average precision, like AUC, is computed from the continuous score.
    test_average_precision = metrics.average_precision_score(
        test_data.targets, test_probs)

    print(f"Test Precision: {test_precision}")
    print(f"Test Recall: {test_recall}")
    print(f"Test F1 Score: {test_f1_score}")
    print(f"Test AUC: {test_auc}")
    print(f"Test average Precision: {test_average_precision}")
    experiment.log_metric("test_precision", test_precision)
    experiment.log_metric("test_recall", test_recall)
    experiment.log_metric("test_f1_score", test_f1_score)
    experiment.log_metric("test_auc", test_auc)
    experiment.log_metric("test_average_precision", test_average_precision)

    # Save the results in the output file.
    # DataFrame.append was removed in pandas 2.0, so the new row is built
    # directly and concatenated (no .reshape on the metric values, which may
    # be plain Python floats).
    col_names = [
        "timestamp", "precision", "recall", "f1_score", "average_precision",
        "auc"
    ]
    new_row = pandas.DataFrame(
        [[datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          test_precision, test_recall, test_f1_score,
          test_average_precision, test_auc]],
        columns=col_names)
    results_file = os.path.join(folder, "results_" + file + ".csv")
    if os.path.exists(results_file):
        df_results = pandas.read_csv(results_file, names=col_names, header=0)
        df_results = pandas.concat([df_results, new_row], ignore_index=True)
    else:
        df_results = new_row

    df_results.to_csv(results_file)
Пример #2
0
def train(normal_digit, anomalies, folder, file, p_train, p_test):
    """Train a convolutional autoencoder on contaminated MNIST and test it.

    A training set of ``TRAIN_SIZE`` images is sampled from ``normal_digit``
    and contaminated with a fraction ``p_train`` of images whose label is in
    ``anomalies``; a test set of ``TEST_SIZE`` images is built the same way
    with contamination ``p_test``. The ``ConvAutoEncoder2`` model is trained
    with ``train_mnist`` (or loaded from ``LOAD_MODEL_NAME``). P-values are
    then obtained from ``compute_reconstruction_pval`` and metrics from
    ``test_performances`` (both project helpers). Plots are logged/saved and
    the metrics are appended as one row to ``<folder>/results_<file>.csv``.

    Args:
        normal_digit: MNIST digit treated as the normal class.
        anomalies: label(s) used as contamination.
        folder: directory for the p-value plot and the results CSV.
        file: suffix used to build the output file names.
        p_train: contamination fraction of the training set.
        p_test: contamination fraction of the test set; also used as the
            significance level ``ALPHA``.
    """

    # Create an experiment
    experiment = Experiment(project_name="deep-stats-thesis",
                            workspace="stecaron",
                            disabled=True)
    experiment.add_tag("mnist_conv_ae")

    # General parameters
    DOWNLOAD_MNIST = True
    PATH_DATA = os.path.join(os.path.expanduser("~"), 'Downloads/mnist')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define training parameters
    hyper_params = {
        "EPOCH": 75,
        "NUM_WORKERS": 10,
        "BATCH_SIZE": 256,
        "LR": 0.001,
        "TRAIN_SIZE": 4000,
        "TRAIN_NOISE": p_train,
        "TEST_SIZE": 800,
        "TEST_NOISE": p_test,
        # on which class we want to learn outliers
        "CLASS_SELECTED": [normal_digit],
        # which class we want to corrupt our dataset with
        "CLASS_CORRUPTED": anomalies,
        "ALPHA": p_test,
        "MODEL_NAME": "mnist_ae_model",
        "LOAD_MODEL": False,
        "LOAD_MODEL_NAME": "mnist_ae_model"
    }

    # Log experiment parameters
    experiment.log_parameters(hyper_params)

    # Load data
    train_data, test_data = load_mnist(PATH_DATA, download=DOWNLOAD_MNIST)

    # Build the autoencoder, its optimizer and the reconstruction loss
    model = ConvAutoEncoder2()
    optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["LR"])
    loss_func = nn.BCELoss()

    # Build "train" and "test" datasets
    id_maj_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TRAIN_NOISE"]) * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_min_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TRAIN_NOISE"] * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_train = numpy.concatenate((id_maj_train, id_min_train))

    id_maj_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TEST_NOISE"]) * hyper_params["TEST_SIZE"]),
        replace=False)
    id_min_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TEST_NOISE"] * hyper_params["TEST_SIZE"]),
        replace=False)
    id_test = numpy.concatenate((id_min_test, id_maj_test))

    train_data.data = train_data.data[id_train]
    train_data.targets = train_data.targets[id_train]

    test_data.data = test_data.data[id_test]
    test_data.targets = test_data.targets[id_test]

    # Replace digit labels by a binary "is anomaly" flag (1 = anomaly).
    # NOTE(review): assumes train_labels / test_labels alias ``targets`` so
    # they reflect the subsets taken above.
    train_data.targets = torch.from_numpy(
        numpy.isin(train_data.train_labels,
                   hyper_params["CLASS_CORRUPTED"])).type(torch.int32)
    test_data.targets = torch.from_numpy(
        numpy.isin(test_data.test_labels,
                   hyper_params["CLASS_CORRUPTED"])).type(torch.int32)

    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=hyper_params["BATCH_SIZE"],
                                   shuffle=True,
                                   num_workers=hyper_params["NUM_WORKERS"])

    # The whole test set in a single, non-shuffled batch so that the
    # p-values stay aligned with test_data.
    test_loader = Data.DataLoader(dataset=test_data,
                                  batch_size=test_data.data.shape[0],
                                  shuffle=False,
                                  num_workers=hyper_params["NUM_WORKERS"])
    model.train()
    if hyper_params["LOAD_MODEL"]:
        model = torch.load(hyper_params["LOAD_MODEL_NAME"])
    else:
        train_mnist(train_loader,
                    model,
                    criterion=optimizer,
                    n_epoch=hyper_params["EPOCH"],
                    experiment=experiment,
                    device=device,
                    model_name=hyper_params["MODEL_NAME"],
                    loss_func=loss_func,
                    loss_type="binary")

    # Compute p-values (the per-observation test errors are not used here)
    model.to(device)
    pval, _ = compute_reconstruction_pval(
        train_loader, model, test_loader, device)
    pval_order = numpy.argsort(pval)

    # Plot p-values
    n_test = len(test_data)
    x_line = numpy.arange(0, n_test, step=1)
    y_line = numpy.linspace(0, 1, n_test)
    y_adj = numpy.arange(0, n_test, step=1) / n_test * hyper_params["ALPHA"]
    zoom = int(0.2 * n_test)  # nb of points to zoom

    index = numpy.array(test_data.targets).astype(int)

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.scatter(numpy.arange(0, len(pval), 1),
                pval[pval_order],
                c=index[pval_order].reshape(-1))
    ax1.plot(x_line, y_line, color="green")
    ax1.plot(x_line, y_adj, color="red")
    ax1.set_title(
        f'Entire test dataset with {int(hyper_params["TEST_NOISE"] * 100)}% of noise'
    )
    ax1.set_xticklabels([])

    ax2.scatter(numpy.arange(0, zoom, 1),
                pval[pval_order][0:zoom],
                c=index[pval_order].reshape(-1)[0:zoom])
    ax2.plot(x_line[0:zoom], y_line[0:zoom], color="green")
    ax2.plot(x_line[0:zoom], y_adj[0:zoom], color="red")
    ax2.set_title('Zoomed in')
    ax2.set_xticklabels([])

    experiment.log_figure(figure_name="empirical_test_hypothesis",
                          figure=fig,
                          overwrite=True)
    fig.savefig(os.path.join(folder, "pvalues_" + file + ".png"))
    plt.show()
    # Close each figure once shown so repeated calls don't pile them up.
    plt.close(fig)

    # Compute some stats
    precision, recall, f1_score, average_precision, roc_auc = test_performances(
        pval, index, hyper_params["ALPHA"])
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1_score}")
    print(f"AUC: {roc_auc}")
    print(f"Average Precison: {average_precision}")
    experiment.log_metric("precision", precision)
    experiment.log_metric("recall", recall)
    experiment.log_metric("f1_score", f1_score)
    experiment.log_metric("auc", roc_auc)
    experiment.log_metric("average_precision", average_precision)

    # Show some examples: the 25 smallest p-values, then the 25 largest
    fig, axs = plt.subplots(5, 5)
    fig.tight_layout()
    axs = axs.ravel()

    for i in range(25):
        image = test_data.data[pval_order[i]]
        axs[i].imshow(image, cmap='gray')
        axs[i].axis('off')

    experiment.log_figure(figure_name="rejetcted_observations",
                          figure=fig,
                          overwrite=True)
    plt.show()
    plt.close(fig)

    fig, axs = plt.subplots(5, 5)
    fig.tight_layout()
    axs = axs.ravel()

    for i in range(25):
        image = test_data.data[pval_order[len(pval) - 1 - i]]
        axs[i].imshow(image, cmap='gray')
        axs[i].axis('off')

    experiment.log_figure(figure_name="better_observations",
                          figure=fig,
                          overwrite=True)
    plt.show()
    plt.close(fig)

    # Save the results in the output file.
    # DataFrame.append was removed in pandas 2.0, so the new row is built
    # directly and concatenated (no .reshape on the metric values).
    col_names = ["timestamp", "precision", "recall", "f1_score",
                 "average_precision", "auc"]
    new_row = pandas.DataFrame(
        [[datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          precision, recall, f1_score, average_precision, roc_auc]],
        columns=col_names)
    results_file = os.path.join(folder, "results_" + file + ".csv")
    if os.path.exists(results_file):
        df_results = pandas.read_csv(results_file, names=col_names, header=0)
        df_results = pandas.concat([df_results, new_row], ignore_index=True)
    else:
        df_results = new_row

    df_results.to_csv(results_file)
Пример #3
0
def train(normal_digit, anomalies, folder, file, p_train, p_test):
    """Train a convolutional VAE on contaminated MNIST and score with an SVM.

    A training set of ``TRAIN_SIZE`` images is sampled from ``normal_digit``
    and contaminated with a fraction ``p_train`` of images whose label is in
    ``anomalies``; a test set of ``TEST_SIZE`` images is built the same way
    with contamination ``p_test``. A ``ConvLargeVAE`` is trained with
    ``train_mnist_vae`` (or loaded from ``LOAD_MODEL_NAME``) and
    ``compute_pval_loaders_svm`` (project helper) produces the test
    predictions. Precision / recall / F1 / average precision are printed,
    logged, and appended as one row (AUC left as NaN) to
    ``<folder>/results_<file>.csv``.

    Args:
        normal_digit: MNIST digit treated as the normal class.
        anomalies: label(s) used as contamination.
        folder: directory of the results CSV.
        file: suffix used to build the results file name.
        p_train: contamination fraction of the training set.
        p_test: contamination fraction of the test set; also used as the
            significance level ``ALPHA``.
    """

    # Create an experiment
    experiment = Experiment(project_name="deep-stats-thesis",
                            workspace="stecaron",
                            disabled=False)
    experiment.add_tag("mnist_vae_svm")

    # General parameters
    DOWNLOAD_MNIST = True
    PATH_DATA = os.path.join(os.path.expanduser("~"), 'Downloads/mnist')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define training parameters
    hyper_params = {
        "EPOCH": 75,
        "BATCH_SIZE": 500,
        "NUM_WORKERS": 10,
        "LR": 0.001,
        "TRAIN_SIZE": 4000,
        "TRAIN_NOISE": p_train,
        "TEST_SIZE": 1000,
        "TEST_NOISE": p_test,
        # on which class we want to learn outliers
        "CLASS_SELECTED": [normal_digit],
        # which class we want to corrupt our dataset with
        "CLASS_CORRUPTED": anomalies,
        "INPUT_DIM": 28 * 28,  # In the case of MNIST
        # hidden layer dimensions (before the representations)
        "HIDDEN_DIM": 500,
        "LATENT_DIM": 25,  # latent distribution dimensions
        "ALPHA": p_test,  # level of significance for the test
        "BETA_epoch": [5, 10, 25],
        "BETA": [0, 5, 1],  # hyperparameter to weight KLD vs RCL
        "MODEL_NAME": "mnist_vae_svm_model",
        "LOAD_MODEL": True,
        "LOAD_MODEL_NAME": "mnist_vae_svm_model"
    }

    # Log experiment parameters
    experiment.log_parameters(hyper_params)

    # Load data
    train_data, test_data = load_mnist(PATH_DATA, download=DOWNLOAD_MNIST)

    # Build the VAE and its optimizer
    model = ConvLargeVAE(z_dim=hyper_params["LATENT_DIM"])
    optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["LR"])

    # Build "train" and "test" datasets
    id_maj_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TRAIN_NOISE"]) * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_min_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TRAIN_NOISE"] * hyper_params["TRAIN_SIZE"]),
        replace=False)
    id_train = numpy.concatenate((id_maj_train, id_min_train))

    id_maj_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TEST_NOISE"]) * hyper_params["TEST_SIZE"]),
        replace=False)
    id_min_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TEST_NOISE"] * hyper_params["TEST_SIZE"]),
        replace=False)
    id_test = numpy.concatenate((id_min_test, id_maj_test))

    train_data.data = train_data.data[id_train]
    train_data.targets = train_data.targets[id_train]

    test_data.data = test_data.data[id_test]
    test_data.targets = test_data.targets[id_test]

    # Replace digit labels by a binary "is anomaly" flag (1 = anomaly).
    # NOTE(review): assumes train_labels / test_labels alias ``targets`` so
    # they reflect the subsets taken above.
    train_data.targets = torch.from_numpy(
        numpy.isin(train_data.train_labels,
                   hyper_params["CLASS_CORRUPTED"])).type(torch.int32)
    test_data.targets = torch.from_numpy(
        numpy.isin(test_data.test_labels,
                   hyper_params["CLASS_CORRUPTED"])).type(torch.int32)

    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=hyper_params["BATCH_SIZE"],
                                   shuffle=True,
                                   num_workers=hyper_params["NUM_WORKERS"])

    # The whole test set in a single, non-shuffled batch so that the
    # predictions stay aligned with test_data.
    test_loader = Data.DataLoader(dataset=test_data,
                                  batch_size=test_data.data.shape[0],
                                  shuffle=False,
                                  num_workers=hyper_params["NUM_WORKERS"])

    if hyper_params["LOAD_MODEL"]:
        model = torch.load(hyper_params["LOAD_MODEL_NAME"])
    else:
        train_mnist_vae(
            train_loader,
            model,
            criterion=optimizer,
            n_epoch=hyper_params["EPOCH"],
            experiment=experiment,
            beta_list=hyper_params["BETA"],
            beta_epoch=hyper_params["BETA_epoch"],
            model_name=hyper_params["MODEL_NAME"],
            device=device,
            loss_type="binary",
            flatten=False)

    # Compute predictions.
    # NOTE(review): assumes preds is a 0/1 array (1 = anomaly) aligned with
    # test_data — confirm against compute_pval_loaders_svm.
    model.to(device)
    preds = compute_pval_loaders_svm(train_loader,
                                     test_loader,
                                     model,
                                     device=device,
                                     experiment=experiment,
                                     flatten=False)

    index = numpy.array(test_data.targets).astype(int)

    # Compute some stats
    precision = metrics.precision_score(index, preds)
    recall = metrics.recall_score(index, preds)
    f1_score = metrics.f1_score(index, preds)
    average_precision = metrics.average_precision_score(index, preds)
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1_score}")
    print(f"Average Precision: {average_precision}")
    experiment.log_metric("precision", precision)
    experiment.log_metric("recall", recall)
    experiment.log_metric("f1_score", f1_score)
    experiment.log_metric("average_precision", average_precision)

    # Show some examples among the true anomalies (index == 1): those whose
    # prediction disagrees with the label, then those predicted correctly.
    # Sample without replacement and cap at what is available:
    # numpy.random.choice raises on an empty pool when a positive size is
    # requested (e.g. a perfect detector has no errors to show).
    anomaly_errors = numpy.where((index != preds) & (index == 1))[0]
    anomaly_ok = numpy.where((index == preds) & (index == 1))[0]

    for figure_name, pool in (("rejetcted_observations", anomaly_errors),
                              ("better_observations", anomaly_ok)):
        sample = numpy.random.choice(pool, min(25, pool.shape[0]),
                                     replace=False)

        fig, axs = plt.subplots(5, 5)
        fig.tight_layout()
        axs = axs.ravel()
        for ax in axs:
            ax.axis('off')
        for ax, obs in zip(axs, sample):
            ax.imshow(test_data.data[obs], cmap='gray')

        experiment.log_figure(figure_name=figure_name,
                              figure=fig,
                              overwrite=True)
        plt.show()
        # Close each figure once shown so repeated calls don't pile them up.
        plt.close(fig)

    # Save the results in the output file (no AUC for hard SVM predictions).
    # DataFrame.append was removed in pandas 2.0, so the new row is built
    # directly and concatenated (no .reshape on the metric values).
    col_names = [
        "timestamp", "precision", "recall", "f1_score", "average_precision",
        "auc"
    ]
    new_row = pandas.DataFrame(
        [[datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          precision, recall, f1_score, average_precision, numpy.nan]],
        columns=col_names)
    results_file = os.path.join(folder, "results_" + file + ".csv")
    if os.path.exists(results_file):
        df_results = pandas.read_csv(results_file, names=col_names, header=0)
        df_results = pandas.concat([df_results, new_row], ignore_index=True)
    else:
        df_results = new_row

    df_results.to_csv(results_file)
Пример #4
0
    plt.show()


def plot_outliers_idx(dt, idx_in, idx_out, shape=(2, 5), image_shape=(28, 28)):
    """Show inlier and outlier images together in one subplot grid.

    The images ``dt[i]`` for ``i`` in ``idx_in`` are drawn first (titled
    'Inliers'), followed by those for ``idx_out`` (titled 'Outliers'),
    filling a ``shape[0] x shape[1]`` grid in row-major order with the
    reversed gray colormap. Axes left without an image are hidden.

    Args:
        dt: indexable collection of images; each item must be reshapeable to
            ``image_shape``.
        idx_in: indices of the inlier images.
        idx_out: indices of the outlier images.
        shape: (rows, cols) of the grid; it must hold at least
            ``len(idx_in) + len(idx_out)`` images (IndexError otherwise).
        image_shape: 2-D shape each image is reshaped to before display
            (28x28 for MNIST).
    """
    # squeeze=False always yields a 2-D array of axes, so .flatten() also
    # works for a 1x1 grid (where subplots would return a bare Axes).
    _, axes = plt.subplots(shape[0], shape[1], squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')

    entries = ([(img_idx, 'Inliers') for img_idx in idx_in] +
               [(img_idx, 'Outliers') for img_idx in idx_out])
    for pos, (img_idx, title) in enumerate(entries):
        ax = axes[pos]
        ax.imshow(numpy.reshape(dt[img_idx], image_shape), cmap='gray_r')
        ax.set_title(title)

    plt.show()


if __name__ == '__main__':
    import os

    from src.mnist.data import load_mnist

    # Resolve the MNIST directory under the current user's home directory
    # instead of a hard-coded, user-specific absolute path.
    train_data, test_data = load_mnist(
        os.path.join(os.path.expanduser("~"), 'Downloads/mnist'))
    plot_n_images(train_data, 4)
Пример #5
0
def train(normal_digit, anomalies, folder, file, p_train, p_test):
    """Train a convolutional VAE on contaminated MNIST and test it with p-values.

    A training set of ``TRAIN_SIZE`` images is sampled from ``normal_digit``
    and contaminated with a fraction ``p_train`` of images whose label is in
    ``anomalies``; a test set of ``TEST_SIZE`` images is built the same way
    with contamination ``p_test``. A ``ConvLargeVAE`` is trained with
    ``train_mnist_vae`` (or loaded from ``LOAD_MODEL_NAME``), p-values come
    from ``compute_pval_loaders`` and metrics from ``test_performances``
    (project helpers). Plots are logged/saved as PDFs in ``folder`` and the
    metrics are appended as one row to ``<folder>/results_<file>.csv``.

    Args:
        normal_digit: MNIST digit treated as the normal class.
        anomalies: label(s) used as contamination.
        folder: directory for the saved model, plots and results CSV.
        file: suffix used to build the output file names.
        p_train: contamination fraction of the training set.
        p_test: contamination fraction of the test set; also used as the
            significance level ``ALPHA``.
    """
    # Create an experiment
    experiment = Experiment(project_name="deep-stats-thesis",
                            workspace="stecaron",
                            disabled=True)
    experiment.add_tag("mnist_vae")

    # General parameters
    DOWNLOAD_MNIST = True
    PATH_DATA = os.path.join(os.path.expanduser("~"),
                             'Downloads/mnist')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Define training parameters
    hyper_params = {
        "EPOCH": 75,
        "BATCH_SIZE": 500,
        "NUM_WORKERS": 10,
        "LR": 0.0001,
        "TRAIN_SIZE": 4000,
        "TRAIN_NOISE": p_train,
        "TEST_SIZE": 800,
        "TEST_NOISE": p_test,
        # on which class we want to learn outliers
        "CLASS_SELECTED": [normal_digit],
        # which class we want to corrupt our dataset with
        "CLASS_CORRUPTED": anomalies,
        "INPUT_DIM": 28 * 28,  # In the case of MNIST
        # hidden layer dimensions (before the representations)
        "HIDDEN_DIM": 500,
        "LATENT_DIM": 2,  # latent distribution dimensions
        "ALPHA": p_test,  # level of significance for the test
        "BETA_epoch": [5, 10, 25],
        "BETA": [0, 5, 1],  # hyperparameter to weight KLD vs RCL
        "MODEL_NAME": "mnist_vae_model",
        "LOAD_MODEL": False,
        "LOAD_MODEL_NAME": "mnist_vae_model"
    }

    # Log experiment parameters
    experiment.log_parameters(hyper_params)

    # Load data
    train_data, test_data = load_mnist(PATH_DATA, download=DOWNLOAD_MNIST)

    # Build the VAE and its optimizer
    model = ConvLargeVAE(z_dim=hyper_params["LATENT_DIM"])
    optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["LR"])

    # Build "train" and "test" datasets
    id_maj_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TRAIN_NOISE"]) * hyper_params["TRAIN_SIZE"]),
        replace=False
    )
    id_min_train = numpy.random.choice(
        numpy.where(numpy.isin(train_data.train_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TRAIN_NOISE"] * hyper_params["TRAIN_SIZE"]),
        replace=False
    )
    id_train = numpy.concatenate((id_maj_train, id_min_train))

    id_maj_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_SELECTED"]))[0],
        int((1 - hyper_params["TEST_NOISE"]) * hyper_params["TEST_SIZE"]),
        replace=False
    )
    id_min_test = numpy.random.choice(
        numpy.where(numpy.isin(test_data.test_labels,
                               hyper_params["CLASS_CORRUPTED"]))[0],
        int(hyper_params["TEST_NOISE"] * hyper_params["TEST_SIZE"]),
        replace=False
    )
    id_test = numpy.concatenate((id_min_test, id_maj_test))

    train_data.data = train_data.data[id_train]
    train_data.targets = train_data.targets[id_train]

    test_data.data = test_data.data[id_test]
    test_data.targets = test_data.targets[id_test]

    # Replace digit labels by a binary "is anomaly" flag (1 = anomaly).
    # NOTE(review): assumes train_labels / test_labels alias ``targets`` so
    # they reflect the subsets taken above.
    train_data.targets = torch.from_numpy(numpy.isin(
        train_data.train_labels, hyper_params["CLASS_CORRUPTED"])).type(torch.int32)
    test_data.targets = torch.from_numpy(numpy.isin(
        test_data.test_labels, hyper_params["CLASS_CORRUPTED"])).type(torch.int32)

    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=hyper_params["BATCH_SIZE"],
                                   shuffle=True,
                                   num_workers=hyper_params["NUM_WORKERS"])

    # Non-shuffled so that the p-values stay aligned with test_data.
    test_loader = Data.DataLoader(dataset=test_data,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=hyper_params["NUM_WORKERS"])

    model_save = os.path.join(folder, hyper_params["MODEL_NAME"] + file)

    if hyper_params["LOAD_MODEL"]:
        model = torch.load(hyper_params["LOAD_MODEL_NAME"])
    else:
        train_mnist_vae(train_loader,
                        model,
                        criterion=optimizer,
                        n_epoch=hyper_params["EPOCH"],
                        experiment=experiment,
                        beta_list=hyper_params["BETA"],
                        beta_epoch=hyper_params["BETA_epoch"],
                        model_name=model_save,
                        device=device,
                        loss_type="binary",
                        flatten=False)

    # Compute p-values
    model.to(device)
    pval, _ = compute_pval_loaders(train_loader,
                                   test_loader,
                                   model,
                                   device=device,
                                   experiment=experiment,
                                   folder=folder,
                                   file=file,
                                   flatten=False)

    pval_order = numpy.argsort(pval)

    # Plot p-values
    x_line = numpy.arange(0, len(test_data), step=1)
    y_line = numpy.linspace(0, 1, len(test_data))
    zoom = int(0.2 * len(test_data))  # nb of points to zoom

    index = numpy.array(test_data.targets).astype(int)

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.scatter(numpy.arange(0, len(pval), 1),
                pval[pval_order],
                c=index[pval_order].reshape(-1))
    ax1.plot(x_line, y_line, color="green")
    ax1.axhline(hyper_params["ALPHA"], color="red")
    ax1.set_ylabel(r"Score $(1 - \gamma)$")
    ax1.set_title(
        f'Jeu de données test avec {int(hyper_params["TEST_NOISE"] * 100)}% de contamination'
    )
    ax1.set_xticklabels([])

    ax2.scatter(numpy.arange(0, zoom, 1),
                pval[pval_order][0:zoom],
                c=index[pval_order].reshape(-1)[0:zoom])
    ax2.plot(x_line[0:zoom], y_line[0:zoom], color="green")
    ax2.axhline(hyper_params["ALPHA"], color="red")
    ax2.set_ylabel(r"Score $(1 - \gamma)$")
    ax2.set_title('Vue rapprochée')
    ax2.set_xticklabels([])

    experiment.log_figure(figure_name="empirical_test_hypothesis",
                          figure=fig,
                          overwrite=True)
    fig.savefig(os.path.join(folder, "pvalues_" + file + ".pdf"))
    plt.show()
    # Close each figure once shown so repeated calls don't pile them up.
    plt.close(fig)

    # Compute some stats
    precision, recall, f1_score, average_precision, roc_auc = test_performances(
        pval, index, hyper_params["ALPHA"])
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1_score}")
    print(f"AUC: {roc_auc}")
    print(f"Average Precision: {average_precision}")
    experiment.log_metric("precision", precision)
    experiment.log_metric("recall", recall)
    experiment.log_metric("f1_score", f1_score)
    experiment.log_metric("auc", roc_auc)
    experiment.log_metric("average_precision", average_precision)

    # Show some examples: the 25 smallest p-values, then the 25 largest
    fig, axs = plt.subplots(5, 5)
    fig.tight_layout()
    axs = axs.ravel()

    for i in range(25):
        image = test_data.data[pval_order[i]]
        axs[i].imshow(image, cmap='gray')
        axs[i].axis('off')

    experiment.log_figure(figure_name="rejetcted_observations",
                          figure=fig,
                          overwrite=True)
    fig.savefig(os.path.join(folder, "rejected_observations_" + file + ".pdf"))
    plt.show()
    plt.close(fig)

    fig, axs = plt.subplots(5, 5)
    fig.tight_layout()
    axs = axs.ravel()

    for i in range(25):
        image = test_data.data[pval_order[len(pval) - 1 - i]]
        axs[i].imshow(image, cmap='gray')
        axs[i].axis('off')

    experiment.log_figure(figure_name="better_observations",
                          figure=fig,
                          overwrite=True)
    fig.savefig(os.path.join(folder, "better_observations_" + file + ".pdf"))
    plt.show()
    plt.close(fig)

    # Plot some errors. index == 1 marks a true anomaly and preds == 1 a
    # rejected observation (p-value <= ALPHA), so a false positive is a
    # normal image that was rejected and a false negative is an anomaly that
    # was not. (The two masks were previously swapped.)
    preds = (pval <= hyper_params["ALPHA"]).astype(int)
    false_positive = numpy.where((preds == 1) & (index == 0))[0]
    false_negative = numpy.where((preds == 0) & (index == 1))[0]

    for prefix, pool in (("false_positive_sample_", false_positive),
                         ("false_negative_sample_", false_negative)):
        nb_errors = min(16, pool.shape[0])
        sample_errors = numpy.random.choice(pool, nb_errors, replace=False)

        fig, axs = plt.subplots(4, 4)
        fig.tight_layout()
        axs = axs.ravel()
        for ax in axs:
            ax.axis('off')
        for ax, obs in zip(axs, sample_errors):
            ax.imshow(test_data.data[obs], cmap='gray')

        fig.savefig(os.path.join(folder, prefix + file + ".pdf"))
        plt.show()
        plt.close(fig)

    # Save the results in the output file.
    # DataFrame.append was removed in pandas 2.0, so the new row is built
    # directly and concatenated (no .reshape on the metric values).
    col_names = ["timestamp", "precision", "recall", "f1_score",
                 "average_precision", "auc"]
    new_row = pandas.DataFrame(
        [[datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
          precision, recall, f1_score, average_precision, roc_auc]],
        columns=col_names)
    results_file = os.path.join(folder, "results_" + file + ".csv")
    if os.path.exists(results_file):
        df_results = pandas.read_csv(results_file, names=col_names, header=0)
        df_results = pandas.concat([df_results, new_row], ignore_index=True)
    else:
        df_results = new_row

    df_results.to_csv(results_file)