def log_prior(self, leaf_values):
    """Return the mean of ``log_normal`` evaluated on ``leaf_values``.

    ``log_normal`` is called with a float32 tensor of zeros and a float32
    tensor of ones (both shaped like ``leaf_values``) as its second and
    third arguments, followed by ``self.embedding_size`` and ``dim=2``.

    NOTE(review): presumably the zeros/ones are the mean and scale of a
    standard normal prior -- confirm against ``log_normal``'s signature.

    Args:
        leaf_values: tensor of values to score under the prior.

    Returns:
        The ``T.mean`` of the ``log_normal`` result.
    """
    zeros = T.zeros_like(leaf_values, dtype='float32')
    ones = T.ones_like(leaf_values, dtype='float32')
    log_density = log_normal(
        leaf_values,
        zeros,
        ones,
        self.embedding_size,
        dim=2,
    )
    return T.mean(log_density)
def vmp(graph, data, max_iter=100, tol=1e-4):
    """Run iterated message passing over the hidden nodes of ``graph``.

    Nodes found in ``data`` are treated as observed (their values are
    converted with ``T.to_float``); every other node is initialized with
    ``initialize_node(node, {})`` and updated by ``message_passing`` inside a
    ``T.while_loop``.

    Args:
        graph: graph object accepted by ``top_sort``.
        data: mapping from observed node to its value.
        max_iter: upper bound on the number of loop iterations.
        tol: the loop continues only while ``abs(elbo - prev_elbo) > tol``.

    Returns:
        A ``(q, elbo)`` tuple. ``q`` maps each hidden node to a new instance
        of that node's class built from its final natural parameters;
        ``elbo`` is the last ELBO value produced by the loop.
    """
    q, visible = {}, {}
    for node in top_sort(graph)[::-1]:
        if node in data:
            visible[node] = T.to_float(data[node])
        else:
            q[node] = initialize_node(node, {})

    # Fixed ordering so the flat parameter list carried through the loop can
    # be zipped back onto the nodes it belongs to.
    ordering = list(q)
    params = [q[var].get_parameters('natural') for var in ordering]

    def rebuild(natural_params):
        # Re-wrap a flat list of natural parameters as node instances.
        return {
            var: var.__class__(param, 'natural')
            for var, param in zip(ordering, natural_params)
        }

    def cond(i, elbo, prev_elbo, q_params):
        return T.logical_and(i < max_iter, abs(elbo - prev_elbo) > tol)

    def step(i, elbo, prev_elbo, q_params):
        updated, new_elbo = message_passing(rebuild(q_params), visible)
        # The incoming ``elbo`` becomes the next iteration's ``prev_elbo``.
        return i + 1, new_elbo, elbo, [
            updated[var].get_parameters('natural') for var in ordering
        ]

    # Start from (elbo=inf, prev_elbo=0.0): their difference is inf, so the
    # convergence test cannot stop the loop before the first iteration.
    # (inf, inf) would give nan, and ``nan > tol`` is false.
    _, elbo, _, final_params = T.while_loop(
        cond, step, [0, float('inf'), 0.0, params])
    return rebuild(final_params), elbo
def get_child_message(x, y, hidden=None, visible=None):
    """Compute the message that child ``y`` sends to ``x``.

    The message is the gradient of ``y``'s summed log-likelihood with
    respect to each statistic of ``x``, evaluated inside a graph context
    built from the merged ``hidden`` and ``visible`` mappings.

    Args:
        x: node receiving the message; must provide ``statistics()`` and
            ``get_statistic(name)``.
        y: child node; must provide ``log_likelihood(data)``.
        hidden: optional mapping merged into the graph context
            (``None`` means empty).
        visible: optional mapping merged into the graph context after
            ``hidden``, so its entries win on key collisions
            (``None`` means empty).

    Returns:
        Dict mapping each statistic of ``x`` to the corresponding gradient.
    """
    # ``None`` defaults replace the original shared ``{}`` defaults, so no
    # dict object is shared between calls.
    hidden = {} if hidden is None else hidden
    visible = {} if visible is None else visible

    # NOTE(review): the original source was collapsed onto one line, so the
    # exact extent of this ``with`` block is an assumption -- confirm.
    with graph_context({**hidden, **visible}):
        data = context(y)
        log_likelihood = y.log_likelihood(data)
        stats = x.statistics()
        param = T.grad(
            T.sum(log_likelihood),
            [x.get_statistic(s) for s in stats],
        )
        return {s: param[i] for i, s in enumerate(stats)}
def coerce(x, shape=None):
    """Wrap ``x`` in a ``DeterministicTensor`` when its type is recognized.

    Python ``float``/``int`` values and ``np.ndarray`` instances are first
    turned into a constant with ``T.constant``; ``T.core.Tensor`` instances
    are wrapped as they are.

    Args:
        x: value to wrap.
        shape: accepted but not used by this function.

    Returns:
        A ``DeterministicTensor``, or ``None`` when ``x`` is none of the
        types listed above.
    """
    # Imported inside the function, as in the original.
    from .deterministic_tensor import DeterministicTensor

    is_python_number = isinstance(x, (float, int))
    if is_python_number or isinstance(x, np.ndarray):
        return DeterministicTensor(T.constant(x))
    if isinstance(x, T.core.Tensor):
        return DeterministicTensor(x)
    return None
def get_stat(x, name, feed_dict=None):
    """Return the statistic called ``name`` for ``x``.

    If the current graph has a node for ``x``, the lookup is delegated to
    that node's ``get_stat``. Otherwise the statistic is computed directly
    from ``x`` for the supported names.

    Args:
        x: node or tensor whose statistic is requested.
        name: one of ``'x'``, ``'xxT'``, ``'-0.5S^-1'``, ``'-0.5log|S|'``
            when ``x`` has no node in the current graph; otherwise any name
            the node understands.
        feed_dict: optional dict forwarded to the node's ``get_stat``
            (``None`` means a fresh empty dict).

    Returns:
        The requested statistic.

    Raises:
        ValueError: if ``x`` has no node in the current graph and ``name``
            is not one of the supported names.
    """
    # A fresh dict per call replaces the original shared ``{}`` default.
    if feed_dict is None:
        feed_dict = {}

    node = get_current_graph().get_node(x)
    if node is not None:
        return node.get_stat(name, feed_dict=feed_dict)

    if name == 'x':
        return x
    if name == 'xxT':
        return T.outer(x, x)
    if name == '-0.5S^-1':
        return -0.5 * T.matrix_inverse(x)
    if name == '-0.5log|S|':
        return -0.5 * T.logdet(x)
    # ValueError is a subclass of Exception, so handlers written for the
    # original bare ``Exception`` still catch it.
    raise ValueError('Unknown statistic name: {!r}'.format(name))
def next_state(self, state, action, t):
    """Build the Gaussian over the next state from ``state`` and ``action``.

    ``state`` and ``action`` are concatenated along the last axis and passed
    through ``self.network``. The two 'regular' parameters of the result are
    unpacked as ``sigma`` and ``delta_mu``; the returned Gaussian uses
    ``sigma`` unchanged and ``delta_mu + state`` as its second parameter.

    NOTE(review): presumably the network predicts a change in the mean
    relative to the current state -- confirm.

    Args:
        state: current state tensor.
        action: action tensor, concatenated to ``state`` on the last axis.
        t: accepted but not used by this method.

    Returns:
        ``stats.Gaussian([sigma, delta_mu + state])``.
    """
    network_input = T.concatenate([state, action], -1)
    prediction = self.network(network_input)
    sigma, delta_mu = prediction.get_parameters('regular')
    shifted_mu = delta_mu + state
    return stats.Gaussian([sigma, shifted_mu])
def log_z(self, parameter_type='regular', stop_gradient=False):
    """Return ``T.logsumexp`` over the last axis of the parameter ``eta``.

    Args:
        parameter_type: ``'regular'`` computes ``eta`` as ``Stats.LogX``
            applied to the regular parameters. Any other value reads
            ``eta`` from the natural parameters under the ``Stats.X`` key.
        stop_gradient: forwarded to ``self.get_parameters``.

    Returns:
        ``T.logsumexp(eta, -1)``.
    """
    if parameter_type != 'regular':
        natural = self.get_parameters('natural', stop_gradient=stop_gradient)
        eta = natural[Stats.X]
    else:
        pi = self.get_parameters('regular', stop_gradient=stop_gradient)
        eta = Stats.LogX(pi)
    return T.logsumexp(eta, -1)
def message_passing(hidden, visible):
    """Perform one sweep of message passing over the hidden nodes.

    Hidden nodes are visited in reverse ``top_sort`` order. For each one, the
    messages from its children are added to its own natural parameters to
    form an updated distribution, which replaces the node's entry in
    ``hidden``. The KL divergence between the updated distribution and the
    one built from the node's own parameters is subtracted from the ELBO.
    Finally the summed log-likelihood of every visible node, evaluated in the
    context of the updated ``hidden``, is added.

    Args:
        hidden: dict mapping hidden node to its current distribution.
            Mutated in place.
        visible: dict mapping observed node to its value.

    Returns:
        ``(hidden, elbo)``: the same (now updated) ``hidden`` dict and the
        accumulated ELBO.
    """
    elbo = 0.0
    for var in top_sort(hidden)[::-1]:
        # Messages are computed with ``var`` itself removed from the hidden
        # set; nodes updated earlier in this sweep are already reflected.
        child_messages = []
        for child in var.children():
            others = {k: v for k, v in hidden.items() if k != var}
            child_messages.append(
                get_child_message(var, child, hidden=others, visible=visible)
            )

        stats = var.statistics()
        parent_message = var.get_parameters('natural')
        e_p = var.__class__(parent_message, 'natural', graph=False)

        natparam = {}
        for s in stats:
            incoming = sum([message[s] for message in child_messages])
            natparam[s] = parent_message[s] + incoming

        q = var.__class__(natparam, 'natural', graph=False)
        elbo -= kl_divergence(q, e_p)
        hidden[var] = q

    for var, observed in visible.items():
        with graph_context(hidden):
            elbo += T.sum(var.log_likelihood(observed))
    return hidden, elbo