import os

import matplotlib.pyplot as pp
import numpy as np
import PIL.Image as PImage

import mstools.colors as colors
from mstools.mean_shift.topological import ModeSeeking

# Experiment 1: 1-D mode seeking on the lightness channel of the "swan" image,
# plotting the estimated density with its modes and saddles marked.

# Root of the BSDS500 image folders. Set the BSDS500_IMAGES environment
# variable to run elsewhere; the default is the original author's local copy.
BSDS_IMAGE_DIR = os.environ.get(
    "BSDS500_IMAGES",
    "/Users/mike/docs/classes/el612/proj/berk_data/BSR/BSDS500/data/images",
)

# swan
iname = "swan"
with PImage.open(os.path.join(BSDS_IMAGE_DIR, "test", "8068.jpg")) as img:
    rgb = np.array(img)

# Convert to Lab and keep channel 0 (lightness) as a single-channel image.
image = colors.rgb2lab(rgb)
gray_image = image[..., 0]

sigma = 4.0  # passed to ModeSeeking as its color bandwidth

ms = ModeSeeking(color_bw=sigma)
c = ms.train_on_image(gray_image, density_xform=None)

# Bin centres computed from the bin edges of the first feature axis.
x = c.cell_edges[0]
xc = (x[:-1] + x[1:]) / 2.0
# Last slice of the model grid: the estimated density plotted below.
p = c.mud_grid[..., -1]

# Grid indices and heights of the detected modes (clusters) and saddles.
peak_pts = [m.idx for m in ms.clusters.values()]
peak_hts = [m.elevation for m in ms.clusters.values()]
saddle_pts = [s.idx for s in ms.saddles]
saddle_hts = [s.elevation for s in ms.saddles]

f = pp.figure()
pp.plot(xc, p, label="Density")
pp.plot(xc[peak_pts], peak_hts, "go", label="Modes")
pp.plot(xc[saddle_pts], saddle_hts, "ro", label="Saddles")
pp.xlabel("Lightness")
pp.ylabel("Density")
pp.title("Estimated Density of Image Lightness")
# Without this the "Modes"/"Saddles" labels above are never displayed.
pp.legend()
# Experiment 2: 2-D mode seeking on the (green, blue) channels of one BSDS500
# image, displayed as a log-density heat map with contour lines.

# Root of the BSDS500 image folders. Set the BSDS500_IMAGES environment
# variable to run elsewhere; the default is the original author's local copy.
BSDS_IMAGE_DIR = os.environ.get(
    "BSDS500_IMAGES",
    "/Users/mike/docs/classes/el612/proj/berk_data/BSR/BSDS500/data/images",
)

# Candidate images as (split, filename) under BSDS_IMAGE_DIR.
bsds_samples = {
    "swan": ("test", "8068.jpg"),
    "llama": ("test", "6046.jpg"),  # noted as HARD
    "starfish": ("train", "12003.jpg"),
    "monkey": ("train", "16052.jpg"),
}
sample_name = "starfish"

with PImage.open(os.path.join(BSDS_IMAGE_DIR, *bsds_samples[sample_name])) as src:
    img = np.array(src)

# For an RGB image, channels 1:3 are (green, blue); 0:2 would be (red, green)
# and leave out blue entirely. Feature 0 (green) maps to image rows (y axis)
# and feature 1 (blue) to columns (x axis), matching the labels below.
gb_image = img[..., 1:3].astype("d")

sigma = 3
ms = ModeSeeking(color_bw=sigma)
c = ms.train_on_image(gb_image, bin_sigma=1.0)

# Float copy, so the in-place edits below do not modify the model's own grid.
p = np.array(c.mud_grid[..., -1], dtype="d")
zmask = p == 0
if zmask.all():
    raise ValueError("estimated density is zero everywhere; nothing to plot")
# Fill empty bins just below the smallest nonzero density to keep log finite.
min_p = p[~zmask].min()
p[zmask] = 0.9 * min_p
np.log(p, out=p)

f, ax = pp.subplots()
heat = ax.imshow(p, origin="lower", interpolation="nearest", cmap="hot")
ax.set_xlabel("blue")
ax.set_ylabel("green")
f.colorbar(heat, ax=ax)
# Contours start at -10 as before, unless the whole map lies at or below -10;
# then linspace would give non-increasing levels, which contour rejects.
lo = -10.0 if p.max() > -10.0 else p.min()
ax.contour(p, levels=np.linspace(lo, p.max(), 15), origin="lower")
ax.set_title("Log-Density")