# Bind the names this section needs before their first use; both imports are
# idempotent if the same names were already imported earlier in the file.
from numpy import array
from inference.mcmc import HamiltonianChain

# create the chain object
# NOTE(review): `posterior` must already be defined at this point in the
# script — confirm an earlier definition exists.
chain = HamiltonianChain(posterior=posterior,
                         grad=posterior.gradient,
                         start=[1, 0.1, 0.1])

# advance the chain to generate the sample
chain.advance(6000)

# choose how many samples will be thrown away from the start
# of the chain as 'burn-in'
chain.burn = 2000

chain.matrix_plot(filename='hmc_matrix_plot.png')

# extract sample and probability data from the chain.
# Convert to an array before doing arithmetic: if get_probabilities() returns
# a plain list, `probs - max(probs)` would raise a TypeError.
probs = array(chain.get_probabilities())
# colour values in [0, 1]: the largest entry of `probs` maps to exactly 1
colors = exp(probs - max(probs))
xs, ys, zs = [chain.get_parameter(i) for i in [0, 1, 2]]

import plotly.graph_objects as go
from plotly import offline

# Interactive 3D scatter of the sample; each marker is coloured by `colors`
# using the Viridis colour scale.
# NOTE(review): this figure is not shown or saved in this part of the script
# (`offline` appears unused here, and `fig` is later rebound to a matplotlib
# figure) — confirm whether an `offline.plot(fig, ...)` call is missing.
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode='markers',
            marker={
                'size': 5,
                'color': colors,
                'colorscale': 'Viridis',
                'opacity': 0.6,
            },
        )
    ]
)
from inference.mcmc import HamiltonianChain

# Second run: sample a fresh ToroidalGaussian posterior with HMC, supplying
# the posterior's gradient method and the starting point [1, 0.1, 0.1].
posterior = ToroidalGaussian()
hmc = HamiltonianChain(
    posterior=posterior,
    grad=posterior.gradient,
    start=[1, 0.1, 0.1],
)

# generate 6000 samples, then mark the first 1000 as burn-in
hmc.advance(6000)
hmc.burn = 1000

# presumably imported so that projection='3d' is registered (required on
# older matplotlib versions); Axes3D itself is not referenced below
from mpl_toolkits.mplot3d import Axes3D

# static 3D scatter of the HMC sample, saved as 'gallery_hmc.png'
fig = plt.figure(figsize=(5, 4))
ax = fig.add_subplot(111, projection='3d')

# identical tick positions on all three axes...
for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
    set_ticks([-1, -0.5, 0., 0.5, 1.])

# ...and identical symmetric limits [-L, L]
L = 1.1
for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_lim([-L, L])

probs = array(hmc.get_probabilities())
# NOTE(review): `inds` is never used in this section — confirm whether the
# points were meant to be plotted in order of increasing probability.
inds = argsort(probs)
# colour values: the largest entry of `probs` maps to exactly 1
colors = exp(probs - max(probs))
xs, ys, zs = (array(hmc.get_parameter(k)) for k in range(3))

ax.scatter(xs, ys, zs, c=colors, marker='.', alpha=0.5)
plt.subplots_adjust(left=0., right=1., top=1., bottom=0.03)
plt.savefig('gallery_hmc.png')
plt.show()