Example #1
 def em(i, q_dyn_natparam, q_X_natparam, _, curr_elbo):
     # E-step: expected sufficient statistics of the current state posterior.
     q_X_ = stats.LDS(q_X_natparam, 'natural')
     ess = q_X_.expected_sufficient_statistics()
     batch_size = T.shape(ess)[0]
     # Moment blocks for the current state (x) and the next state (y).
     yyT = ess[..., :-1, ds:2 * ds, ds:2 * ds]
     xxT = ess[..., :-1, :ds, :ds]
     yxT = ess[..., :-1, ds:2 * ds, :ds]
     x = ess[..., :-1, -1, :ds]
     y = ess[..., :-1, -1, ds:2 * ds]
     xaT = T.outer(x, a)
     yaT = T.outer(y, a)
     # Joint second-moment block matrix for [x, a].
     xaxaT = T.concatenate([
         T.concatenate([xxT, xaT], -1),
         T.concatenate([T.matrix_transpose(xaT), aaT], -1),
     ], -2)
     ess = [
         yyT,
         T.concatenate([yxT, yaT], -1), xaxaT,
         T.ones([batch_size, self.horizon - 1])
     ]
     # M-step: conjugate update of the dynamics posterior, i.e. the initial
     # (prior) natural parameters plus the scaled, batch-summed statistics.
     q_dyn_natparam = [
         T.sum(a, [0]) * data_strength + b
         for a, b in zip(ess, initial_dyn_natparam)
     ]
     q_dyn_ = stats.MNIW(q_dyn_natparam, 'natural')
     q_stats = q_dyn_.expected_sufficient_statistics()
     p_X = stats.LDS((q_stats, state_prior, None,
                      q_A.expected_value(), self.horizon))
     q_X_ = stats.LDS((q_stats, state_prior, q_X,
                       q_A.expected_value(), self.horizon))
     # KL terms tracked as the objective for this EM iteration.
     elbo = (T.sum(stats.kl_divergence(q_X_, p_X)) +
             T.sum(stats.kl_divergence(q_dyn_, prior_dyn)))
     return i + 1, q_dyn_.get_parameters(
         'natural'), q_X_.get_parameters('natural'), curr_elbo, elbo
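The q_dyn_natparam update above is the usual conjugate-exponential-family pattern: posterior natural parameters equal the prior (initial) natural parameters plus the batch-summed expected sufficient statistics, optionally re-weighted by data_strength. A minimal NumPy sketch of that pattern for the mean of a unit-variance Gaussian (all names below are illustrative, not part of the library used above):

import numpy as np

# Conjugate update sketch: posterior natural parameters = prior + summed statistics.
# Model: x_i ~ N(mu, 1) with mu ~ N(m0, v0); the natural parameters over mu use
# the statistics [mu, mu**2], so eta = [mean / var, -1 / (2 * var)].
m0, v0 = 0.0, 10.0
prior_natparam = np.array([m0 / v0, -1.0 / (2.0 * v0)])

x = np.random.randn(50) + 3.0                      # observed data
stats_sum = np.array([x.sum(), -0.5 * len(x)])     # each point contributes [x_i, -1/2]

data_strength = 1.0                                # optional re-weighting, as above
posterior_natparam = prior_natparam + data_strength * stats_sum

precision = -2.0 * posterior_natparam[1]
posterior_mean = posterior_natparam[0] / precision  # shrinks toward m0 for small n
posterior_var = 1.0 / precision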
Example #2
 def kl_divergence(self, q_X, q_A, _):
     # q_Xt - [N, H, ds]
     # q_At - [N, H, da]
     # KL between q(x_{t+1}) and the model's one-step forward prediction,
     # cached per (q_X, q_A) pair.
     if (q_X, q_A) not in self.cache:
         # q(x_t), q(a_t): all but the last time step.
         q_Xt = q_X.__class__([
             q_X.get_parameters('regular')[0][:, :-1],
             q_X.get_parameters('regular')[1][:, :-1],
         ])
         q_At = q_A.__class__([
             q_A.get_parameters('regular')[0][:, :-1],
             q_A.get_parameters('regular')[1][:, :-1],
         ])
         # One-step forward prediction p(x_{t+1} | x_t, a_t) and q(x_{t+1})
         # from the given state posterior.
         p_Xt1 = self.forward(q_Xt, q_At)
         q_Xt1 = q_X.__class__([
             q_X.get_parameters('regular')[0][:, 1:],
             q_X.get_parameters('regular')[1][:, 1:],
         ])
         # Diagnostics: per-step RMSE between the two means plus both stdevs.
         rmse = T.sqrt(
             T.sum(T.square(q_Xt1.get_parameters('regular')[1] -
                            p_Xt1.get_parameters('regular')[1]),
                   axis=-1))
         model_stdev = T.sqrt(p_Xt1.get_parameters('regular')[0])
         encoding_stdev = T.sqrt(q_Xt1.get_parameters('regular')[0])
         self.cache[(q_X, q_A)] = (
             T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1),
             {
                 'rmse': rmse,
                 'encoding-stdev': encoding_stdev,
                 'model-stdev': model_stdev,
             },
         )
     return self.cache[(q_X, q_A)]
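stats.kl_divergence(q_Xt1, p_Xt1) above is a KL between two diagonal Gaussians over the next state, and the rmse/stdev entries are purely diagnostic. A small standalone NumPy version of the same quantities, assuming (as the indexing above suggests) that the 'regular' parameters are (variance, mean); shapes and names are illustrative only:

import numpy as np

def diag_gaussian_kl(var_q, mu_q, var_p, mu_p):
    # KL(N(mu_q, diag(var_q)) || N(mu_p, diag(var_p))), summed over dimensions.
    return 0.5 * np.sum(
        np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0,
        axis=-1)

mu_q, var_q = np.random.randn(4, 9, 3), np.ones((4, 9, 3))        # [N, H-1, ds]
mu_p, var_p = np.random.randn(4, 9, 3), 2.0 * np.ones((4, 9, 3))

kl = diag_gaussian_kl(var_q, mu_q, var_p, mu_p)          # [N, H-1]
rmse = np.sqrt(np.sum((mu_q - mu_p) ** 2, axis=-1))      # [N, H-1]
model_stdev = np.sqrt(var_p)                             # [N, H-1, ds]
encoding_stdev = np.sqrt(var_q)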
Example #3
 def kl_divergence(self, q_X, q_A, _):
     # q_Xt - [N, H, ds]
     # q_At - [N, H, da]
     if (q_X, q_A) not in self.cache:
         info = {}
         if self.smooth:
             # Smoothed variant: KL between the full trajectory posterior q(X)
             # and an LDS prior built from the model's dynamics statistics.
             state_prior = stats.GaussianScaleDiag([
                 T.ones(self.ds),
                 T.zeros(self.ds)
             ])
             p_X = stats.LDS(
                 (self.sufficient_statistics(), state_prior, None,
                  q_A.expected_value(), self.horizon), 'internal')
             kl = T.mean(stats.kl_divergence(q_X, p_X), axis=0)
             Q = self.get_dynamics()[1]
             info['model-stdev'] = T.sqrt(T.matrix_diag_part(Q))
         else:
             # Per-step variant: compare q(x_{t+1}) to the one-step forward
             # prediction, as in Example #2.
             q_Xt = q_X.__class__([
                 q_X.get_parameters('regular')[0][:, :-1],
                 q_X.get_parameters('regular')[1][:, :-1],
             ])
             q_At = q_A.__class__([
                 q_A.get_parameters('regular')[0][:, :-1],
                 q_A.get_parameters('regular')[1][:, :-1],
             ])
             p_Xt1 = self.forward(q_Xt, q_At)
             q_Xt1 = q_X.__class__([
                 q_X.get_parameters('regular')[0][:, 1:],
                 q_X.get_parameters('regular')[1][:, 1:],
             ])
             rmse = T.sqrt(
                 T.sum(T.square(q_Xt1.get_parameters('regular')[1] -
                                p_Xt1.get_parameters('regular')[1]),
                       axis=-1))
             kl = T.mean(T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=-1),
                         axis=0)
             Q = self.get_dynamics()[1]
             model_stdev = T.sqrt(T.matrix_diag_part(Q))
             info['rmse'] = rmse
             info['model-stdev'] = model_stdev
         self.cache[(q_X, q_A)] = kl, info
     return self.cache[(q_X, q_A)]
Example #4
 def kl_divergence(self, q_X, q_A, num_data):
     if (q_X, q_A) not in self.cache:
         if self.smooth:
             # Smoothed variant, with an additional global KL for the dynamics
             # parameters.
             state_prior = stats.GaussianScaleDiag(
                 [T.ones(self.ds), T.zeros(self.ds)])
             self.p_X = stats.LDS(
                 (self.sufficient_statistics(), state_prior, None,
                  q_A.expected_value(), self.horizon), 'internal')
             local_kl = stats.kl_divergence(q_X, self.p_X)
             if self.time_varying:
                 global_kl = T.sum(
                     stats.kl_divergence(self.A_variational, self.A_prior))
             else:
                 global_kl = stats.kl_divergence(self.A_variational,
                                                 self.A_prior)
             prior_kl = (T.mean(local_kl, axis=0) +
                         global_kl / T.to_float(num_data))
             A, Q = self.get_dynamics()
             model_stdev = T.sqrt(T.matrix_diag_part(Q))
             self.cache[(q_X, q_A)] = prior_kl, {
                 'local-kl': local_kl,
                 'global-kl': global_kl,
                 'model-stdev': model_stdev,
             }
         else:
             # Per-step variant with the same global KL term.
             q_Xt = q_X.__class__([
                 q_X.get_parameters('regular')[0][:, :-1],
                 q_X.get_parameters('regular')[1][:, :-1],
             ])
             q_At = q_A.__class__([
                 q_A.get_parameters('regular')[0][:, :-1],
                 q_A.get_parameters('regular')[1][:, :-1],
             ])
             p_Xt1 = self.forward(q_Xt, q_At)
             q_Xt1 = q_X.__class__([
                 q_X.get_parameters('regular')[0][:, 1:],
                 q_X.get_parameters('regular')[1][:, 1:],
             ])
             num_data = T.to_float(num_data)
             rmse = T.sqrt(
                 T.sum(T.square(q_Xt1.get_parameters('regular')[1] -
                                p_Xt1.get_parameters('regular')[1]),
                       axis=-1))
             A, Q = self.get_dynamics()
             model_stdev = T.sqrt(T.matrix_diag_part(Q))
             local_kl = T.sum(stats.kl_divergence(q_Xt1, p_Xt1), axis=1)
             if self.time_varying:
                 global_kl = T.sum(
                     stats.kl_divergence(self.A_variational, self.A_prior))
             else:
                 global_kl = stats.kl_divergence(self.A_variational,
                                                 self.A_prior)
             # Per-datapoint KL: batch-averaged local term plus the shared
             # global term spread over the dataset.
             self.cache[(q_X, q_A)] = (
                 T.mean(local_kl, axis=0) + global_kl / num_data,
                 {
                     'rmse': rmse,
                     'model-stdev': model_stdev,
                     'local-kl': local_kl,
                     'global-kl': global_kl,
                 },
             )
     return self.cache[(q_X, q_A)]
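Example #4 splits the penalty into a per-trajectory local KL and a single global KL for the dynamics parameters, dividing the global term by num_data so that a minibatch estimate of the ELBO stays correctly scaled. The scaling itself is just the following (illustrative NumPy, not the original code):

import numpy as np

num_data = 10000.0
local_kl = np.random.rand(32)     # one value per trajectory in the minibatch
global_kl = 5.3                   # KL for the shared dynamics parameters

# Batch-averaged local term plus the shared global term spread over the dataset.
per_datapoint_kl = np.mean(local_kl) + global_kl / num_data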
Example #5
 def kl_divergence(self, q_X, q_A, num_data):
     # KL against a fixed standard-normal prior on each latent dimension,
     # summed over dimensions and averaged over the batch.
     mu_shape = T.shape(q_X.get_parameters('regular')[1])
     p_X = stats.GaussianScaleDiag([T.ones(mu_shape), T.zeros(mu_shape)])
     return T.mean(T.sum(stats.kl_divergence(q_X, p_X), -1), 0), {}
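Example #5 drops the learned dynamics entirely and measures every latent dimension against a fixed prior that appears to be N(0, 1) (unit scale, zero mean). For diagonal Gaussians that KL has the familiar closed form, sketched here in NumPy (illustrative only):

import numpy as np

def kl_to_standard_normal(mu, sigma):
    # KL(N(mu, sigma**2) || N(0, 1)) per dimension.
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - np.log(sigma ** 2))

mu = np.random.randn(32, 10, 8)          # e.g. [N, H, ds]
sigma = np.full((32, 10, 8), 0.5)

# Same reduction as above: sum over latent dimensions, mean over the batch.
kl = np.mean(np.sum(kl_to_standard_normal(mu, sigma), axis=-1), axis=0)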
Example #6
# Packed Gaussian sufficient statistics of the data and the message from
# q(theta); x_tmessage, used below, is the NIW-packed per-point statistics
# (see Example #7).
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

# Scaling for unbiased stochastic natural gradients from a minibatch.
num_batches = N / T.to_float(batch_size)
nat_scale = 10.0

# Local step: update q(z) from the data and mixing-weight messages, then
# renormalize the natural parameters with logsumexp.
parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
q_z = Categorical(new_z - T.logsumexp(new_z, -1)[..., None],
                  parameter_type='natural')
p_z = Categorical(parent_z - T.logsumexp(parent_z, -1),
                  parameter_type='natural')
l_z = T.sum(kl_divergence(q_z, p_z))
z_pmessage = q_z.expected_sufficient_statistics()

# Global step for pi: natural gradient of the form (prior + scaled stats - current).
pi_stats = T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
current_pi = q_pi.get_parameters('natural')
pi_gradient = nat_scale / N * (parent_pi + num_batches * pi_stats - current_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

# Global step for theta: the same natural-gradient form with per-cluster statistics.
theta_stats = T.einsum('ia,ibc->abc', z_pmessage, x_tmessage)
parent_theta = p_theta.get_parameters('natural')[None]
current_theta = q_theta.get_parameters('natural')
theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats -
                                  current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))
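The pi_gradient and theta_gradient lines are stochastic natural gradients of the SVI form (prior + num_batches * minibatch_stats - current), scaled by nat_scale / N; stepping all the way along the unscaled direction would replace the current parameters with a minibatch estimate of the exact conjugate update used in Example #7. A minimal NumPy sketch for a Dirichlet factor over mixing weights (illustrative names, not the library's API):

import numpy as np

N, batch_size, nat_scale = 10000, 100, 10.0
num_batches = N / batch_size

prior_alpha = np.ones(5)                                   # prior pseudo-counts
current_alpha = np.array([3.0, 2.0, 1.0, 4.0, 5.0])        # current variational pseudo-counts
batch_counts = np.array([20.0, 30.0, 10.0, 25.0, 15.0])    # expected assignments in one minibatch

step = 1.0
nat_grad = nat_scale / N * (prior_alpha + num_batches * batch_counts - current_alpha)
current_alpha = current_alpha + step * nat_grad
# A step of N / nat_scale would land exactly on the full update
# prior_alpha + num_batches * batch_counts.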
Example #7
# Per-point sufficient statistics, packed for the NIW cluster-parameter factor
# and the Gaussian likelihood factor.
x_tmessage = NIW.pack([
    T.outer(X, X),
    X,
    T.ones(N),
    T.ones(N),
])
x_stats = Gaussian.pack([
    T.outer(X, X),
    X,
])
theta_cmessage = q_theta.expected_sufficient_statistics()

# Exact conjugate update for q(pi): prior natural parameters plus expected counts.
new_pi = p_pi.get_parameters('natural') + T.sum(z_pmessage, 0)
parent_pi = p_pi.get_parameters('natural')
pi_update = T.assign(q_pi.get_parameters('natural'), new_pi)
l_pi = T.sum(kl_divergence(q_pi, p_pi))

# Exact conjugate update for q(theta).
new_theta = (T.einsum('ia,ibc->abc', z_pmessage, x_tmessage) +
             p_theta.get_parameters('natural')[None])
parent_theta = p_theta.get_parameters('natural')
theta_update = T.assign(q_theta.get_parameters('natural'), new_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

# Local update for q(z), renormalized with logsumexp.
parent_z = q_pi.expected_sufficient_statistics()[None]
new_z = T.einsum('iab,jab->ij', x_tmessage, theta_cmessage) + parent_z
new_z = new_z - T.logsumexp(new_z, -1)[..., None]
z_update = T.assign(q_z.get_parameters('natural'), new_z)
l_z = T.sum(kl_divergence(q_z, Categorical(parent_z,
                                           parameter_type='natural')))
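Example #7 runs the same model as Example #6 but in coordinate-ascent form: instead of taking scaled natural-gradient steps, each factor's natural parameters are assigned directly to the prior plus the expected sufficient statistics, and the categorical factor is renormalized with logsumexp. One such sweep for a toy mixture in NumPy (illustrative only):

import numpy as np

def logsumexp(a, axis=-1):
    # Numerically stable log-sum-exp, keeping the reduced axis for broadcasting.
    m = np.max(a, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

N, K = 200, 3
log_resp = np.random.randn(N, K)            # stand-in for data + parameter messages
log_resp = log_resp - logsumexp(log_resp)   # normalize, as with new_z above
resp = np.exp(log_resp)                     # expected assignments E[z], rows sum to 1

prior_alpha = np.ones(K)
new_alpha = prior_alpha + resp.sum(0)       # q(pi) <- prior plus expected counts,
                                            # the analogue of the T.assign updates above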
Example #8
# T.ones([batch_size]),
# T.ones([batch_size]),
# ])

# q_theta = make_variable(NIW(map(lambda x: np.array(x).astype(T.floatx()), [np.eye(D), np.random.multivariate_normal(mean=np.zeros([D]), cov=np.eye(D) * 20), 1.0, 1.0])))

num_batches = N / T.to_float(batch_size)
nat_scale = 1.0

# Exact conjugate NIW posterior given all of the data (no minibatching here).
theta_stats = T.sum(x_tmessage, 0)
parent_theta = p_theta.get_parameters('natural')
q_theta = NIW(parent_theta + theta_stats, parameter_type='natural')
sigma, mu = Gaussian(q_theta.expected_sufficient_statistics(),
                     parameter_type='natural').get_parameters('regular')
# theta_gradient = nat_scale / N * (parent_theta + num_batches * theta_stats - current_theta)
l_theta = T.sum(kl_divergence(q_theta, p_theta))

# Likelihood term: a Gaussian built from the expected natural parameters of
# q(theta), tiled across the batch.
x_param = q_theta.expected_sufficient_statistics()[None]
q_x = Gaussian(T.tile(x_param, [batch_size, 1, 1]), parameter_type='natural')
l_x = T.sum(q_x.log_likelihood(X))

elbo = l_theta + l_x
elbos = []
l_thetas = []
l_xs = []

# natgrads = [(theta_gradient, q_theta.get_parameters('natural'))]

# nat_op = tf.group(*[T.assign(b, a + b) for a, b in natgrads])
# nat_opt = tf.train.GradientDescentOptimizer(1e-2)
# nat_op = nat_opt.apply_gradients([(-a, b) for a, b in natgrads])
Example #9
lr = 1e-4
batch_size = T.shape(x)[0]
num_batches = T.to_float(N / batch_size)

with T.initialization('xavier'):
    # stats_net = Relu(D + 1, 20) >> Relu(20) >> GaussianLayer(D)
    stats_net = GaussianLayer(D + 1, D)
# Per-example natural-parameter messages from the recognition network, summed
# over the minibatch.
net_out = stats_net(T.concat([x, y[..., None]], -1))
stats = T.sum(net_out.get_parameters('natural'), 0)[None]

# Stochastic natural gradient on q(w) and the proposed next natural parameters.
natural_gradient = (p_w.get_parameters('natural') + num_batches * stats -
                    q_w.get_parameters('natural')) / N
next_w = Gaussian(q_w.get_parameters('natural') + lr * natural_gradient,
                  parameter_type='natural')

l_w = kl_divergence(q_w, p_w)[0]

p_y = Bernoulli(T.sigmoid(T.einsum('jw,iw->ij', next_w.expected_value(), x)))
l_y = T.sum(p_y.log_likelihood(y[..., None]))
elbo = l_w + l_y

# One training step: the natural-gradient assignment for q(w) grouped with an
# RMSProp step on -elbo.
nat_op = T.assign(q_w.get_parameters('natural'),
                  next_w.get_parameters('natural'))
grad_op = tf.train.RMSPropOptimizer(1e-4).minimize(-elbo)
train_op = tf.group(nat_op, grad_op)
sess = T.interactive_session()

# Hard 0/1 predictions: adding 0.5 and casting to int rounds the sigmoid output.
predictions = T.cast(
    T.sigmoid(T.einsum('jw,iw->i', q_w.expected_value(), T.to_float(X))) + 0.5,
    np.int32)
accuracy = T.mean(
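The listing is cut off in the middle of the accuracy expression. The predictions line rounds the sigmoid of the posterior-mean logits by adding 0.5 and casting to int; a standalone NumPy sketch of that thresholding, together with the kind of accuracy computation the last line appears to be starting (illustrative, not the original code):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

X = np.random.randn(100, 5)
w = np.random.randn(5)
y = (np.random.rand(100) > 0.5).astype(np.int32)

probs = sigmoid(X @ w)
predictions = (probs + 0.5).astype(np.int32)   # adding 0.5 then truncating rounds at 0.5
accuracy = np.mean(predictions == y)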