def compare_sampler_grads(lds, num_samples, seed):
    """Assert that gradients through the two backward samplers agree.

    Runs the general filter + sampler and the dense (underscore-prefixed)
    filter + sampler on the same LDS, differentiates the same scalar loss
    through each, and checks the gradients match after re-packing the
    dense layout into the general one.
    """
    init_params, pair_params, node_params = lds

    def make_loss(sampler):
        # Both losses reseed so the two samplers draw identical noise.
        def loss(messages):
            npr.seed(seed)
            draws = sampler(messages, pair_params, num_samples)
            return np.sum(np.sin(draws))
        return loss

    messages, _ = natural_filter_forward_general(
        init_params, pair_params, node_params)
    grads_general = grad(make_loss(natural_sample_backward_general))(messages)

    messages, _ = _natural_filter_forward_general(
        init_params, pair_params, node_params)
    grads_dense = grad(make_loss(_natural_sample_backward))(messages)

    # Re-interleave the dense gradient layout to match the general one.
    def unpack_dense_grads(x):
        return interleave(*[zip(*y) for y in x])

    assert allclose(grads_general, unpack_dense_grads(grads_dense))
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Forward (Kalman-style) filtering pass in natural parameters.

    Alternates condition and predict steps over the node potentials,
    accumulating the filtered messages ``(J, h)`` and the log-normalizer.

    Args:
        init_params: initial-state natural parameters (canonicalized here).
        pair_params: pairwise (dynamics) natural parameters, repeated once
            per transition.
        node_params: per-timestep node potentials (canonicalized here).

    Returns:
        A pair ``(messages, lognorm)`` where ``messages`` is the list of
        ``(J, h)`` filtered messages and ``lognorm`` is the accumulated
        log-normalizer including the final-message contribution.
    """
    init_params = _canonical_init_params(init_params)
    # Materialize as a list: Python 3's zip is lazy and has no len().
    # (lambda (J, h): ... tuple-unpacking params were removed in py3, PEP 3113.)
    node_params = list(zip(*_canonical_node_params(node_params)))
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        # Start the monadic state: one message, initial log-normalizer.
        return [(J, h)], logZ

    def bind(result, step):
        # Apply one filter step to the latest message, accumulate lognorm.
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    def condition(node_param):
        def step(message):
            J, h = message
            return natural_condition_on_general(J, h, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    steps = interleave(map(condition, node_params), map(predict, pair_params))
    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])
    return messages, lognorm
def natural_filter_forward(natparam, data):
    """Forward filtering pass over observed data in natural parameters.

    Args:
        natparam: packed LDS natural parameters; unpacked into init, pair,
            and node parameters repeated for ``T`` timesteps.
        data: observation array of shape ``(T, p)``.

    Returns:
        A pair ``(messages, log_evidence)`` where ``messages`` is the list
        of filtered ``(J, h)`` messages and ``log_evidence`` includes the
        Gaussian base-measure term ``-(T*p/2) * log(2*pi)``.
    """
    T, p = data.shape
    init_params, pair_params, node_params = _unpack_repeated(natparam, T)

    def unit(J, h):
        # Monadic start state: one message, zero accumulated log-normalizer.
        return [(J, h)], 0.

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Plain closures instead of lambda-with-tuple-params, which is a
    # SyntaxError on Python 3 (PEP 3113).
    def condition(node_param, y):
        def step(message):
            J, h = message
            return natural_condition_on(J, h, y, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    steps = interleave(map(condition, node_params, data), map(predict, pair_params))
    J_init, h_init, logZ_init = init_params
    messages, lognorm = monad_runner(bind)(unit(J_init, h_init), steps)
    lognorm += natural_lognorm(*messages[-1]) + logZ_init
    # Explicit true division: under py2, T*p/2 floor-divides when T*p is
    # odd, which silently corrupts the log-normalizer constant.
    return messages, lognorm - T * p / 2. * np.log(2 * np.pi)
def compare_sampler_grads(lds, num_samples, seed):
    """Assert gradient agreement between the two backward samplers.

    Differentiates the same seeded loss through the general sampler and
    through the dense sampler, then checks the gradients coincide once
    the dense layout is re-interleaved.
    """
    init_params, pair_params, node_params = lds

    # Gradients through the general (reference) filter + sampler.
    messages, _ = natural_filter_forward_general(
        init_params, pair_params, node_params)

    def loss_general(msgs):
        npr.seed(seed)
        draws = natural_sample_backward_general(msgs, pair_params, num_samples)
        return np.sum(np.sin(draws))

    reference_grads = grad(loss_general)(messages)

    # Gradients through the dense implementation.
    messages, _ = _natural_filter_forward_general(
        init_params, pair_params, node_params)

    def loss_dense(msgs):
        npr.seed(seed)
        draws = _natural_sample_backward(msgs, pair_params, num_samples)
        return np.sum(np.sin(draws))

    dense_grads = grad(loss_dense)(messages)

    # Dense grads come back packed per-array; re-interleave to the
    # general per-message layout before comparing.
    unpack_dense_grads = lambda x: interleave(*map(lambda y: zip(*y), x))
    assert allclose(reference_grads, unpack_dense_grads(dense_grads))
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Natural-parameter forward filtering (general parameterization).

    Interleaves condition steps (one per node potential) with predict
    steps (one per transition) and threads the log-normalizer alongside
    the growing list of ``(J, h)`` messages.

    Returns:
        ``(messages, lognorm)`` — the filtered messages and total
        log-normalizer, including the final message's contribution.
    """
    init_params = _canonical_init_params(init_params)
    # list(...) is required on py3, where zip is a lazy iterator without
    # len(); the tuple-unpacking lambdas of the py2 original are likewise
    # replaced (removed by PEP 3113).
    node_params = list(zip(*_canonical_node_params(node_params)))
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        return [(J, h)], logZ

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Each step consumes the latest (J, h) message and returns the next
    # message plus a log-normalizer increment.
    condition = lambda node_param: lambda msg: natural_condition_on_general(
        msg[0], msg[1], *node_param)
    predict = lambda pair_param: lambda msg: natural_predict(
        msg[0], msg[1], *pair_param)

    steps = interleave(map(condition, node_params), map(predict, pair_params))
    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])
    return messages, lognorm
def unpack_dense_messages(messages):
    """Convert dense message arrays into an interleaved message list.

    ``messages`` is a pair of pairs: stacked prediction parameters
    ``(J_predict, h_predict)`` and stacked filtered parameters
    ``(J_filtered, h_filtered)``. Each stack is re-paired per timestep
    and the two sequences are interleaved prediction-first.
    """
    (J_predict, h_predict), (J_filtered, h_filtered) = messages
    return interleave(zip(J_predict, h_predict),
                      zip(J_filtered, h_filtered))