output_mean = self.l3(h1)
        return {"mean": output_mean}


# Wrap the encoder and decoder networks with BF.BrancherFunction so they can
# be called on Brancher variables below. The encoder is built before the
# decoder, as in the rest of this script.
encoder_network = EncoderArchitecture(image_size=image_size,
                                      latent_size=latent_size)
encoder = BF.BrancherFunction(encoder_network)
decoder_network = DecoderArchitecture(latent_size=latent_size,
                                      image_size=image_size)
decoder = BF.BrancherFunction(decoder_network)

# Generative model: the latent code z gets a Normal prior parameterised by a
# vector of zeros and a vector of ones; the decoder's "mean" output is used as
# the logits of a Binomial observation x with total_count=1.
latent_shape = (latent_size, )
z = NormalVariable(np.zeros(latent_shape), np.ones(latent_shape), name="z")
decoder_output = DeterministicVariable(decoder(z), name="decoder_output")
pixel_logits = decoder_output["mean"]
x = BinomialVariable(total_count=1, logits=pixel_logits, name="x")
model = ProbabilisticModel([x, z])

# Amortized variational distribution: batches of 100 items from the dataset are
# fed through the encoder, whose "mean" and "sd" outputs parameterise Qz. The
# variational variables reuse the names "x" and "z" of the model variables
# (presumably this is how Brancher pairs them -- confirm against the library).
Qx = EmpiricalVariable(dataset, batch_size=100, name="x", is_observed=True)
encoder_output = DeterministicVariable(encoder(Qx), name="encoder_output")
Qz = NormalVariable(encoder_output["mean"], encoder_output["sd"], name="z")
variational_model = ProbabilisticModel([Qx, Qz])
model.set_posterior_model(variational_model)

# Joint-contrastive inference
inference.perform_inference(
    model,
    inference_method=ReverseKL(gradient_estimator=PathwiseDerivativeEstimator),
    number_iterations=1000,
    number_samples=1,
    optimizer="Adam",
# Minibatch sampling: a single observed index variable is shared by the input
# and label variables, so both are built from the same `minibatch_indices`.
minibatch_indices = RandomIndices(
    dataset_size=dataset_size,
    batch_size=minibatch_size,
    name="indices",
    is_observed=True,
)
shared_options = dict(indices=minibatch_indices, is_observed=True)
x = EmpiricalVariable(input_variable, name="x", **shared_options)
labels = EmpiricalVariable(output_labels, name="labels", **shared_options)

# Regression weights: Normal prior built from a (1, number_regressors) array
# of zeros and a same-shaped array filled with 0.5.
weights_shape = (1, number_regressors)
weights = NormalVariable(np.zeros(weights_shape),
                         np.full(weights_shape, 0.5),
                         "weights")

# The product of the weights and the inputs is passed to the Binomial variable
# through its `logit_p` argument; the first argument is 1.
logit_p = BF.matmul(weights, x)
k = BinomialVariable(1, logit_p=logit_p, name="k")
model = ProbabilisticModel([k])

# Disabled sampling / log-probability check, kept for reference:
# samples = model._get_sample(300)
# model.calculate_log_probability(samples)

# Condition the model on the minibatch labels.
k.observe(labels)

#observed_model = inference.get_observed_model(model)
#observed_samples = observed_model._get_sample(number_samples=1, observed=True)

# Variational Model
Qweights = NormalVariable(np.zeros((1, number_regressors)),
                          np.ones((1, number_regressors)),
                          "weights",
Beispiel #3
0
import matplotlib.pyplot as plt
import numpy as np

from brancher.variables import ProbabilisticModel
from brancher.standard_variables import BetaVariable, BinomialVariable
from brancher import inference
from brancher.visualizations import plot_posterior

# Constants of the experiment.
number_tosses = 1
p_real = 0.8

# Beta-Binomial model: p is a BetaVariable built with (1.0, 1.0) and k is a
# Binomial over `number_tosses` tosses with success probability p.
p = BetaVariable(1.0, 1.0, "p")
k = BinomialVariable(number_tosses, probs=p, name="k")
model = ProbabilisticModel([k, p])

# Generate 30 samples from the model with p set to `p_real` via input_values.
ground_truth = {p: p_real}
data = model.get_sample(number_samples=30, input_values=ground_truth)

# Condition k on the generated data.
k.observe(data)

# Run inference; no posterior model is set anywhere in this script before the
# call. Afterwards the loss curve is read from the model diagnostics.
inference.perform_inference(
    model,
    number_iterations=1000,
    number_samples=500,
    optimizer="SGD",
    lr=0.1,
)
loss_list = model.diagnostics["loss curve"]

# Plot the loss curve.
plt.plot(loss_list)
Beispiel #4
0
import matplotlib.pyplot as plt
import numpy as np

from brancher.variables import ProbabilisticModel
from brancher.standard_variables import BetaVariable, BinomialVariable
from brancher import inference

# Constants.
# NOTE(review): despite its name, `number_samples` is passed as the first
# argument of BinomialVariable (presumably the toss count, not a sample
# count) -- confirm before reusing it elsewhere.
number_samples = 1
p_real = 0.8

# "Real" data-generating variable: a Binomial with a fixed probability.
k_real = BinomialVariable(number_samples, probs=p_real, name="k")

# Beta-Binomial model whose probability p is a BetaVariable built with
# (1.0, 1.0); only k is listed in the ProbabilisticModel.
p = BetaVariable(1.0, 1.0, "p")
k = BinomialVariable(number_samples, probs=p, name="k")
model = ProbabilisticModel([k])

# Draw 50 samples from the real variable.
data = k_real._get_sample(number_samples=50)

# Condition k on the samples. The slice keeps index 0 along axis 1
# (assumes that axis is a singleton -- TODO confirm).
observed_tosses = data[k_real][:, 0, :]
k.observe(observed_tosses)

# Variational distribution: a learnable Beta sharing the name "p" with the
# model variable.
Qp = BetaVariable(1.0, 1.0, "p", learnable=True)
variational_model = ProbabilisticModel([Qp])
model.set_posterior_model(variational_model)

# Inference
inference.perform_inference(model,
                            number_iterations=3000,
                            number_samples=100,