from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Project-local helpers must be in scope before their first use below; the
# file otherwise only imports them further down, after these calls.
from dataset import get_pendigits
from bandwidth_selection import BandwidthSelection

# Load the pendigits dataset and apply the dataset's own rescale() step.
dataset = get_pendigits().rescale()

# Project the features onto two principal components so that a 2-D kernel
# density estimate can be drawn.
pca = PCA(2)
pca.fit(dataset.X)
pca_X = pca.transform(dataset.X)

# Candidate KDE bandwidths, keyed by a short label. The last entry passes an
# explicit 100-point search grid over [1e-4, 1]; 'likelihood_10' relies on
# the selector's default grid (presumably 10 points -- confirm against
# BandwidthSelection.cv_maximum_likelihood).
bandwidths = {
    'gaussian_dist': BandwidthSelection.gaussian_distribution(pca_X),
    'likelihood_10': BandwidthSelection.cv_maximum_likelihood(pca_X),
    'likelihood_100': BandwidthSelection.cv_maximum_likelihood(pca_X, search=np.linspace(1e-4, 1, 100)),
}

# Set up the figure: four plots per row, with enough rows for every bandwidth.
# The row count must be computed from len(bandwidths); dividing the dict
# itself by 4 raises TypeError.
n_cols = 4
n_rows = len(bandwidths) // n_cols + 1
f, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8))

# Split the projected samples into their two component columns.
x, y = pca_X[:, 0], pca_X[:, 1]

# From here on `bandwidths` is a list of (label, bandwidth) pairs ordered by
# label, so the plots appear in a deterministic order.
bandwidths = sorted(bandwidths.items(), key=lambda item: item[0])

# zip() stops at the shorter sequence, so any surplus axes are left empty.
for ax, (name, bandwidth) in zip(axes.flat, bandwidths):
    ax.set_aspect("equal")

    # Draw one density plot of the projected data per candidate bandwidth.
    # NOTE(review): positional x/y together with `bw`, `shade` and
    # `shade_lowest` look like the older seaborn kdeplot API -- confirm the
    # installed seaborn version still accepts them.
    sns.kdeplot(x, y, bw=bandwidth,
                cmap="Reds", shade=True, shade_lowest=False, ax=ax)
from dataset import *
from bandwidth_selection import BandwidthSelection

# Every loader below is called with the same bandwidth option.
_cv_ml = {'bandwidth': 'cv_ml'}

# Load and rescale each benchmark dataset, then order them from the fewest
# to the most samples (measured by len(dataset.X)).
datasets = sorted(
    [
        get_iris_with_test(**_cv_ml).rescale(),
        get_pendigits(**_cv_ml).rescale(),
        get_yeast_with_test(**_cv_ml).rescale(),
        get_satimage(**_cv_ml).rescale(),
        get_banknote_with_test(**_cv_ml).rescale(),
        get_spam_with_test(**_cv_ml).rescale(),
        get_drd_with_test(**_cv_ml).rescale(),
        get_imagesegment(**_cv_ml).rescale(),
        get_pageblock_with_test(**_cv_ml).rescale(),
        get_statlogsegment_with_test(**_cv_ml).rescale(),
        get_winequality_with_test('white', **_cv_ml).rescale(),
        get_winequality_with_test('red', **_cv_ml).rescale(),
    ],
    key=lambda loaded: len(loaded.X),
)

# For each dataset, print the bandwidth from
# BandwidthSelection.gaussian_distribution (labelled 'rot') next to the
# bandwidth stored on the dataset itself (labelled 'cv_ml').
for dataset in datasets:
    print('generating for', dataset.name)
    bw = BandwidthSelection.gaussian_distribution(dataset.X)
    print('bw:', 'rot:', bw, 'cv_ml:', dataset.bandwidth)