from brancher.standard_variables import NormalVariable as Normal
from brancher.standard_variables import BetaVariable as Beta
from brancher.standard_variables import LogNormalVariable as LogNormal
from brancher.inference import perform_inference
from brancher.inference import ReverseKL

## Create latent time series ##
# Simulate synthetic data from a latent Markov chain X observed through a
# Normal likelihood Y, then condition Y on that simulated data.
# NOTE(review): `MarkovProcess` and `plt` (presumably matplotlib.pyplot) are
# used here but are not imported at the top of this file; this section raises
# NameError unless they are provided elsewhere — confirm.
x0 = Normal(0, 0.5, "x_0")  # initial latent state
# Transition: each step is Normal centred on the previous state x, with 0.2 as
# the second Normal argument, and named "x_<t>" per time step.
# NOTE(review): assumes MarkovProcess calls the transition as f(t, x_prev);
# the second script in this file passes a one-argument lambda — confirm.
X = MarkovProcess(x0, lambda t, x: Normal(x, 0.2, "x_{}".format(t)))

## Create observation model ##
Y = Normal(X, 1., "y")  # observations centred on the latent process

## Sample ##
num_timepoints = 30
# Presumably draws one sample trajectory over `num_timepoints` time points.
temporal_sample = Y.get_timeseries_sample(1, query_points=num_timepoints)
# Presumably a pandas DataFrame, plotted via its .plot() method — confirm.
temporal_sample.plot()
plt.show()

## Observe model ##
# Condition Y on the simulated data at time indices 0 .. num_timepoints - 1.
data = temporal_sample
query_points = range(num_timepoints)
Y.observe(data, query_points)

## Variational model
# Variational factor for the initial latent state. It reuses the model name
# "x_0" (same name as x0 above), presumably so brancher pairs it with the
# model variable by name — confirm.
Qx0 = Normal(0, 0.5, "x_0")
# List of variational factors, seeded with Qx0; the loop that follows appends
# one factor per later time step, each centred on the previous factor.
QX = [Qx0]
for idx in range(1, 30):
    QX.append(
        Normal(QX[idx - 1],
               0.25,
## Create latent time series ##
# Latent process: x_t ~ Normal(b * x_{t-1}, sigma) with learnable b and sigma,
# observed through Y ~ Normal(X, chi). Synthetic data is simulated from the
# model and then used to condition Y.
# NOTE(review): `MarkovProcess` and `plt` (presumably matplotlib.pyplot) are
# used here but are not imported at the top of this file — confirm they are
# provided elsewhere, otherwise this section raises NameError.
x0 = Normal(0, 1, "x_0")  # initial latent state
b = Beta(1, 1, "b")  # coefficient multiplying the previous state
sigma = LogNormal(0, 1, "sigma")  # transition noise (second Normal argument)
# NOTE(review): this transition lambda takes one argument and uses a fixed name
# "x", whereas the first script in this file passes `lambda t, x: ...` with a
# per-step name "x_<t>"; confirm which signature MarkovProcess expects.
X = MarkovProcess(x0, lambda x: Normal(b * x, sigma, "x"))

## Create observation model ##
# Observation noise. Named "chi" because "sigma" is already the name of the
# transition noise above; two distinct variables of one model sharing a name
# are indistinguishable wherever variables are looked up or labelled by name.
chi = LogNormal(1, 0.5, "chi")
Y = Normal(X, chi, "y")

## Sample ##
num_timepoints = 20
# Simulate synthetic data with the transition parameters clamped to
# sigma = 1 and b = 1 (presumably a Gaussian random walk).
temporal_sample = Y.get_timeseries_sample(1,
                                          query_points=num_timepoints,
                                          input_values={
                                              sigma: 1.,
                                              b: 1.
                                          })
temporal_sample.plot()
plt.show()

## Observe model ##
# Condition Y on the simulated data at time indices 0 .. num_timepoints - 1.
data = temporal_sample
query_points = range(num_timepoints)
Y.observe(data, query_points)

## Perform ML inference ##
perform_inference(Y,
                  inference_method=ReverseKL(),
                  number_iterations=1000,
                  optimizer="SGD",