def __init__(self, model, data=None, transform=tf.identity):
    """Set up MAP estimation as variational inference with a point mass.

    Args:
        model: Probability model. If it has a ``num_vars`` attribute the
            point mass is sized to ``model.num_vars``; otherwise a size-0
            ``PointMass`` is used.
        data: Data forwarded to ``VariationalInference.__init__``. When
            omitted (or ``None``), a new ``Data()`` is created for this call.
        transform: Forwarded to ``PointMass`` (default ``tf.identity``).
    """
    # A `Data()` default in the signature would be built once, when the
    # `def` statement runs, and then shared by every instance created without
    # explicit data. Build a fresh one per call instead.
    if data is None:
        data = Data()

    num_vars = model.num_vars if hasattr(model, 'num_vars') else 0
    variational = Variational()
    variational.add(PointMass(num_vars, transform))
    VariationalInference.__init__(self, model, variational, data)
def __init__(self, model, data=None, params=None):
    """Initialize MAP inference backed by a point-mass variational family.

    The point mass is constructed inside the ``"variational"`` variable
    scope. When ``model`` has an ``n_vars`` attribute it is sized to
    ``model.n_vars`` and given ``params``; otherwise a size-0 ``PointMass``
    is used.

    Args:
        model: Probability model, forwarded to the parent initializer.
        data: Forwarded unchanged to the parent initializer.
        params: Forwarded to ``PointMass`` only when ``model`` has ``n_vars``.
    """
    with tf.variable_scope("variational"):
        variational = Variational()
        if not hasattr(model, 'n_vars'):
            # NOTE(review): `params` is not forwarded in this branch;
            # confirm that ignoring it here is intentional.
            point_mass = PointMass(0)
        else:
            point_mass = PointMass(model.n_vars, params)
        variational.add(point_mass)

    super(MAP, self).__init__(model, variational, data)
def main():
    """Run mean-field variational inference on a Beta-Bernoulli model.

    Fits a single ``Beta`` variational factor to a fixed 11-element binary
    sequence for 10000 iterations.
    """
    observations = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1])
    data = ed.Data(observations)
    model = BetaBernoulli()

    # Mean-field family: a single Beta factor.
    variational = Variational()
    variational.add(Beta())

    ed.MFVI(model, variational, data).run(n_iter=10000)
def _test(self, sess, data, n_minibatch, x=None, is_file=False):
    """Check what ``inference.data`` yields under one data-feeding setup.

    Builds a Normal model/variational pair, initializes MFVI with the given
    ``n_minibatch``, then asserts the behavior expected for the setup chosen
    by the arguments (checked in this order):

    * ``x`` given: placeholder feeding; fetched ``'x'`` equals ``x``.
    * ``is_file``: file reader; two successive runs return different data.
    * ``n_minibatch is None``: preloaded full data; ``'x'`` equals
      ``data['x']``.
    * ``n_minibatch == 1``: ten successive single-row batches do not all
      match rows 0..9 of ``data['x']`` in order (i.e. data is shuffled).
    * otherwise: the first batch differs from the leading ``n_minibatch``
      rows, and two successive runs differ.

    ``inference.finalize()`` is always called, even when a check fails.
    """
    model = NormalModel()
    variational = Variational()
    variational.add(Normal())
    inference = ed.MFVI(model, variational, data)
    inference.initialize(n_minibatch=n_minibatch)
    try:
        if x is not None:
            # Placeholder setting.
            # Check data is same as data fed to it.
            feed_dict = {inference.data['x']: x}
            # avoid directly fetching placeholder
            data_id = {
                k: tf.identity(v) for k, v in six.iteritems(inference.data)
            }
            val = sess.run(data_id, feed_dict)
            assert np.all(val['x'] == x)
        elif is_file:
            # File reader setting.
            # Check data varies by session run.
            val = sess.run(inference.data)
            val_1 = sess.run(inference.data)
            assert not np.all(val['x'] == val_1['x'])
        elif n_minibatch is None:
            # Preloaded full setting.
            # Check data is full data.
            val = sess.run(inference.data)
            assert np.all(val['x'] == data['x'])
        elif n_minibatch == 1:
            # Preloaded batch setting, with n_minibatch=1.
            # Check data is randomly shuffled.
            assert not np.all([
                sess.run(inference.data)['x'] == data['x'][i]
                for i in range(10)
            ])
        else:
            # Preloaded batch setting.
            # Check data is randomly shuffled.
            val = sess.run(inference.data)
            assert not np.all(val['x'] == data['x'][:n_minibatch])
            # Check data varies by session run.
            val_1 = sess.run(inference.data)
            assert not np.all(val['x'] == val_1['x'])
    finally:
        # Tear down even when an assertion above fails, so whatever
        # initialize() set up does not leak into subsequent tests.
        inference.finalize()
def build_toy_dataset(N=40, noise_std=0.1): ed.set_seed(0) x = np.concatenate( [np.linspace(0, 2, num=N / 2), np.linspace(6, 8, num=N / 2)]) y = 0.075 * x + norm.rvs(0, noise_std, size=N) x = (x - 4.0) / 4.0 x = x.reshape((N, 1)) return {'x': x, 'y': y} ed.set_seed(42) model = LinearModel() variational = Variational() variational.add(Normal(model.n_vars)) data = build_toy_dataset() # Set up figure fig = plt.figure(figsize=(8, 8), facecolor='white') ax = fig.add_subplot(111, frameon=False) plt.ion() plt.show(block=False) sess = ed.get_session() inference = ed.MFVI(model, variational, data) inference.initialize(n_samples=5, n_print=5) for t in range(250): loss = inference.update() if t % inference.n_print == 0: