def plt_1d(class1, class2, prior1=0.5, prior2=0.5, x_range=(-100, 100), num=100):
    """Plot the 1-D Gaussian discriminant functions of two classes.

    Only the first column of each sample array is used. For each class the
    sample mean and sample variance (``np.cov``, unbiased by default) of that
    column are fed to
    ``gdf.gen_discriminant_function_of_normal_distribution``. Both resulting
    discriminants are then plotted over an evenly spaced range of x values.

    Args:
        class1: 2-D array of samples for class 1 (samples along axis 0).
        class2: 2-D array of samples for class 2 (samples along axis 0).
        prior1: Prior probability of class 1 (default 0.5).
        prior2: Prior probability of class 2 (default 0.5).
        x_range: ``(start, stop)`` of the plotted x interval (default -100..100).
        num: Number of x sample points (default 100).
    """
    mean1 = np.array([np.mean(class1[:, 0])])
    mean2 = np.array([np.mean(class2[:, 0])])

    # np.cov of a single variable returns a 0-d array; wrap it so the
    # discriminant generator receives a 1x1 covariance matrix.
    cov1 = np.array([[np.cov([class1[:, 0]])]])
    cov2 = np.array([[np.cov([class2[:, 0]])]])

    discriminant_function1 = gdf.gen_discriminant_function_of_normal_distribution(mean1, cov1, prior1)
    discriminant_function2 = gdf.gen_discriminant_function_of_normal_distribution(mean2, cov2, prior2)

    X = np.linspace(x_range[0], x_range[1], num)

    # Each discriminant takes a length-1 feature vector.
    y1 = [discriminant_function1(np.array([x])) for x in X]
    y2 = [discriminant_function2(np.array([x])) for x in X]

    plt.plot(X, y1)
    plt.plot(X, y2)
    plt.show()
def classifier2(x):
    """Assign ``x`` to class 1 or 2 by comparing two discriminant functions.

    NOTE: the means, covariances and priors used here (``mean1``, ``cov1``,
    ``prior1``, ``mean2``, ``cov2``, ``prior2``) are not parameters. They are
    looked up as module-level names, which must be defined elsewhere before
    this is called.

    Args:
        x: Feature vector passed to both discriminant functions.

    Returns:
        1 if g1(x) > g2(x), 2 if g1(x) < g2(x), otherwise 0 (a tie, or an
        undefined comparison such as NaN).
    """
    discriminant_function1 = gdf.gen_discriminant_function_of_normal_distribution(mean1, cov1, prior1)
    discriminant_function2 = gdf.gen_discriminant_function_of_normal_distribution(mean2, cov2, prior2)

    # Evaluate each discriminant exactly once; previously the elif branch
    # recomputed both values.
    g1 = discriminant_function1(x)
    g2 = discriminant_function2(x)

    if g1 > g2:
        return 1
    if g1 < g2:
        return 2
    return 0
def plt_2d(class1, class2, prior1=0.5, prior2=0.5,
           x_range=(-100, 100), y_range=(-100, 100), num=100):
    """Plot the 2-D Gaussian discriminant surfaces of two classes.

    The first two columns of each sample array are used. For each class the
    sample mean vector and 2x2 sample covariance are fed to
    ``gdf.gen_discriminant_function_of_normal_distribution``. Both
    discriminants are then evaluated at every point of a ``num`` x ``num``
    grid and drawn as 3-D surfaces.

    Args:
        class1: 2-D array of samples for class 1 (samples along axis 0).
        class2: 2-D array of samples for class 2 (samples along axis 0).
        prior1: Prior probability of class 1 (default 0.5).
        prior2: Prior probability of class 2 (default 0.5).
        x_range: ``(start, stop)`` of the grid along x (default -100..100).
        y_range: ``(start, stop)`` of the grid along y (default -100..100).
        num: Number of grid points per axis (default 100).
    """
    mean1 = np.mean(class1[:, 0:2], axis=0)
    mean2 = np.mean(class2[:, 0:2], axis=0)

    # np.cov treats each row of its input as one variable, giving the 2x2
    # covariance of the first two columns.
    cov1 = np.cov([class1[:, 0], class1[:, 1]])
    cov2 = np.cov([class2[:, 0], class2[:, 1]])

    discriminant_function1 = gdf.gen_discriminant_function_of_normal_distribution(mean1, cov1, prior1)
    discriminant_function2 = gdf.gen_discriminant_function_of_normal_distribution(mean2, cov2, prior2)

    x = np.linspace(x_range[0], x_range[1], num)
    y = np.linspace(y_range[0], y_range[1], num)
    X, Y = np.meshgrid(x, y)

    def on_grid(discriminant_function):
        # Evaluate at every (X, Y) grid point (not just the diagonal) and
        # reshape so Z matches X/Y, as plot_surface requires a 2-D Z.
        values = [discriminant_function(np.array([px, py]))
                  for px, py in zip(X.ravel(), Y.ravel())]
        return np.asarray(values, dtype=float).reshape(X.shape)

    z1 = on_grid(discriminant_function1)
    z2 = on_grid(discriminant_function2)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.plot_surface(X, Y, z1, cmap="Greys")
    axes.plot_surface(X, Y, z2, cmap="Blues")
    plt.show()