Example #1
    def __init__(self, model, data=None, transform=tf.identity):
        if data is None:
            data = Data()

        # Point-mass approximating family over all latent variables;
        # fall back to zero if the model does not report how many.
        variational = Variational()
        num_vars = model.num_vars if hasattr(model, 'num_vars') else 0
        variational.add(PointMass(num_vars, transform))

        VariationalInference.__init__(self, model, variational, data)
Example #2
    def __init__(self, model, data=None, params=None):
        with tf.variable_scope("variational"):
            # Point-mass approximating family over all latent
            # variables; fall back to zero if the model does not
            # report how many it has. Pass params in both cases.
            variational = Variational()
            n_vars = model.n_vars if hasattr(model, 'n_vars') else 0
            variational.add(PointMass(n_vars, params))

        super(MAP, self).__init__(model, variational, data)
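A minimal usage sketch for the MAP class above. MyModel is a hypothetical stand-in for any Edward-style model exposing n_vars and log_prob; depending on the Edward version, data may be an ed.Data object or a dict.

# Hypothetical usage; MyModel is a stand-in, not part of Edward.
model = MyModel()
data = ed.Data(np.array([0.0, 1.0, 0.0]))
inference = MAP(model, data)
inference.run(n_iter=500)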
Example #3
def main():
    data = ed.Data(np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
    model = BetaBernoulli()
    variational = Variational()
    variational.add(Beta())

    # Run mean-field variational inference.
    inference = ed.MFVI(model, variational, data)

    inference.run(n_iter=10000)
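main() assumes a BetaBernoulli model class defined elsewhere. A sketch in the old Edward style, assuming the edward.stats module and a Beta(1, 1) prior (both assumptions; tf.pack/tf.unpack are the era's names for what later became tf.stack/tf.unstack):

import tensorflow as tf
from edward.stats import bernoulli, beta

class BetaBernoulli:
    """p(x, p) = Bernoulli(x | p) * Beta(p | 1, 1)."""
    def __init__(self):
        self.n_vars = 1

    def log_prob(self, xs, zs):
        # Beta log-prior for each sampled probability z, plus the
        # Bernoulli log-likelihood summed over the data.
        log_prior = beta.logpdf(zs, a=1.0, b=1.0)
        log_lik = tf.pack([tf.reduce_sum(bernoulli.logpmf(xs, z))
                           for z in tf.unpack(zs)])
        return log_lik + log_prior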
Example #4
    def _test(self, sess, data, n_minibatch, x=None, is_file=False):
        model = NormalModel()
        variational = Variational()
        variational.add(Normal())

        inference = ed.MFVI(model, variational, data)
        inference.initialize(n_minibatch=n_minibatch)

        if x is not None:
            # Placeholder setting.
            # Check the fetched data matches what was fed in.
            feed_dict = {inference.data['x']: x}
            # Avoid fetching the placeholder directly; wrap each
            # tensor in tf.identity first.
            data_id = {
                k: tf.identity(v)
                for k, v in six.iteritems(inference.data)
            }
            val = sess.run(data_id, feed_dict)
            assert np.all(val['x'] == x)
        elif is_file:
            # File reader setting.
            # Check data varies by session run.
            val = sess.run(inference.data)
            val_1 = sess.run(inference.data)
            assert not np.all(val['x'] == val_1['x'])
        elif n_minibatch is None:
            # Preloaded full setting.
            # Check the fetched data equals the full dataset.
            val = sess.run(inference.data)
            assert np.all(val['x'] == data['x'])
        elif n_minibatch == 1:
            # Preloaded batch setting, with n_minibatch=1.
            # Check data is randomly shuffled.
            assert not np.all([
                sess.run(inference.data)['x'] == data['x'][i]
                for i in range(10)
            ])
        else:
            # Preloaded batch setting.
            # Check data is randomly shuffled.
            val = sess.run(inference.data)
            assert not np.all(val['x'] == data['x'][:n_minibatch])
            # Check data varies by session run.
            val_1 = sess.run(inference.data)
            assert not np.all(val['x'] == val_1['x'])

        inference.finalize()

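A sketch of how this helper might be driven from a TensorFlow test case (the class and test names are hypothetical):

class TestData(tf.test.TestCase):
    # ... _test defined as above ...

    def test_preloaded_full(self):
        with self.test_session() as sess:
            data = {'x': np.random.normal(size=10)}
            self._test(sess, data, n_minibatch=None)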
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
from edward.models import Normal, Variational
from scipy.stats import norm


def build_toy_dataset(N=40, noise_std=0.1):
    ed.set_seed(0)
    # Two well-separated clusters of inputs with a linear response
    # plus Gaussian noise.
    x = np.concatenate(
        [np.linspace(0, 2, num=N // 2),
         np.linspace(6, 8, num=N // 2)])
    y = 0.075 * x + norm.rvs(0, noise_std, size=N)
    x = (x - 4.0) / 4.0  # center and rescale inputs
    x = x.reshape((N, 1))
    return {'x': x, 'y': y}
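The script below assumes a LinearModel class defined elsewhere. A sketch in the old Edward style; the z = (b, w) parameterization and the variance defaults are assumptions:

import tensorflow as tf

class LinearModel:
    """Bayesian linear regression:
    p((x, y), z) = Normal(y | x * w + b, lik_variance)
                   * Normal(z | 0, prior_variance), with z = (b, w).
    """
    def __init__(self, lik_variance=0.01, prior_variance=0.01):
        self.lik_variance = lik_variance
        self.prior_variance = prior_variance
        self.n_vars = 2

    def log_prob(self, xs, zs):
        # Log joint density, vectorized over a batch of samples zs.
        x, y = xs['x'], xs['y']
        log_prior = -tf.reduce_sum(zs * zs, 1) / self.prior_variance
        b = zs[:, 0]
        W = tf.expand_dims(zs[:, 1], 0)
        # Broadcast to (n_data, n_samples) predicted means.
        mus = tf.matmul(x, W) + b
        y = tf.expand_dims(y, 1)
        log_lik = -tf.reduce_sum(tf.pow(mus - y, 2), 0) / self.lik_variance
        return log_lik + log_prior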


ed.set_seed(42)
model = LinearModel()
variational = Variational()
variational.add(Normal(model.n_vars))
data = build_toy_dataset()

# Set up figure
fig = plt.figure(figsize=(8, 8), facecolor='white')
ax = fig.add_subplot(111, frameon=False)
plt.ion()
plt.show(block=False)

sess = ed.get_session()
inference = ed.MFVI(model, variational, data)
inference.initialize(n_samples=5, n_print=5)
for t in range(250):
    loss = inference.update()
    if t % inference.n_print == 0:
        print("iter {:d} loss {:.2f}".format(t, loss))
        # The interactive figure set up above would be redrawn here
        # with draws from the fitted posterior over the data.