def natural_smoother_general(forward_messages, init_params, pair_params, node_params):
    """Backward (RTS-style) pass that returns expected sufficient statistics.

    The pass is seeded with the last filtered message and then applies
    ``natural_rts_backward_step`` once per transition, from the last
    transition back to the first. For every time step it records a triple
    ``(mu, ExxT, ExxnT)``. For the seed step, ``ExxT = Sigma + outer(mu, mu)``
    and ``ExxnT`` is ``0.``, because there is no next state. The triples are
    then packaged into initial-state, pair and node statistics.

    Args:
        forward_messages: interleaved prediction/filter messages from a
            forward pass; split with ``uninterleave``.
        init_params: initial-state parameters. Only ``len(init_params)`` is
            read; it truncates the ``(ExxT, mu, 1., 1.)`` init statistics.
        pair_params: transition parameters. When ``depth(pair_params) == 2``
            they are treated as per-transition (inhomogeneous). Otherwise
            ``_repeat_param`` repeats them for every transition and the pair
            statistics are summed.
        node_params: node parameters, canonicalized with
            ``_canonical_node_params`` and regrouped per time step. If the
            first entry of the first step is 1-D, node statistics keep only
            ``np.diag(ExxT)``.

    Returns:
        ``(E_init_stats, E_pair_stats, E_node_stats)``.
        - ``E_pair_stats`` is a list of per-transition tuples
          ``(ExxT_t, ExxnT_t.T, ExxT_{t+1}, 1.)``. For homogeneous
          ``pair_params`` it is instead their ``reduce(add, ...)`` sum,
          seeded with ``zeros_like(pair_params)``.
        - ``E_node_stats`` is a list of arrays, one per statistic, each
          stacked over time with ``np.array``.
    """
    prediction_messages, filter_messages = uninterleave(forward_messages)
    inhomog = depth(pair_params) == 2
    num_steps = len(prediction_messages)
    orig_pair_params = pair_params
    repeated_pair_params = _repeat_param(orig_pair_params, num_steps - 1)
    per_step_node_params = list(zip(*_canonical_node_params(node_params)))

    def start(filtered_message):
        # Seed with the last filtered message; with no successor state the
        # cross moment is recorded as 0.
        J, h = filtered_message
        mu, Sigma = natural_to_mean(filtered_message)
        ExxT = Sigma + np.outer(mu, mu)
        return tuple_(J, h, mu), [(mu, ExxT, 0.)]

    def extend(result, step):
        next_smooth, stats = result
        J, h, (mu, ExxT, ExxnT) = step(next_smooth)
        # Prepend so the collected list ends up in forward time order.
        return tuple_(J, h, mu), [(mu, ExxT, ExxnT)] + stats

    def make_rts_step(next_pred, filtered, pair_param):
        # Factory (not a lambda inside a loop) so each step keeps its own
        # arguments instead of late-binding the loop variables.
        def rts_step(next_smooth):
            return natural_rts_backward_step(next_smooth, next_pred, filtered, pair_param)
        return rts_step

    forward_steps = [
        make_rts_step(next_pred, filtered, pair_param)
        for next_pred, filtered, pair_param
        in zip(prediction_messages[1:], filter_messages[:-1], repeated_pair_params)
    ]
    _, expected_stats = monad_runner(extend)(
        start(filter_messages[-1]), reversed(forward_steps))

    is_diagonal = per_step_node_params[0][0].ndim == 1

    first_mu, first_ExxT, _ = expected_stats[0]
    E_init_stats = (first_ExxT, first_mu, 1., 1.)[:len(init_params)]

    E_pair_stats = [
        (ExxT, ExxnT.T, ExnxnT, 1.)
        for (_, ExxT, ExxnT), (_, ExnxnT, _)
        in zip(expected_stats[:-1], expected_stats[1:])
    ]

    E_node_stats = [
        (np.diag(ExxT) if is_diagonal else ExxT, mu, 1.)
        for mu, ExxT, _ in expected_stats
    ]

    if not inhomog:
        E_pair_stats = reduce(add, E_pair_stats, zeros_like(orig_pair_params))

    # Transpose the per-time tuples into one stacked array per statistic.
    E_node_stats = [np.array(column) for column in zip(*E_node_stats)]

    return E_init_stats, E_pair_stats, E_node_stats
def natural_smoother_general(forward_messages, init_params, pair_params, node_params):
    """Backward (RTS-style) pass that returns expected sufficient statistics.

    The pass is seeded with the last filtered message and then applies
    ``natural_rts_backward_step`` once per transition, from the last
    transition back to the first, recording ``(mu, ExxT, ExxnT)`` per time
    step. The seed step records ``ExxnT = 0.``.

    Args:
        forward_messages: interleaved prediction/filter messages; split with
            ``uninterleave``.
        init_params: initial-state parameters; only ``len(init_params)`` is
            read, to truncate the init statistics.
        pair_params: transition parameters. A depth-2 structure is treated as
            per-transition (inhomogeneous). Otherwise the parameters are
            repeated with ``_repeat_param`` and the pair statistics are summed.
        node_params: node parameters, canonicalized and regrouped per time
            step. A 1-D first entry selects diagonal node statistics.

    Returns:
        ``(E_init_stats, E_pair_stats, E_node_stats)``.
        - ``E_pair_stats`` is a list of per-transition tuples. For homogeneous
          ``pair_params`` it is their sum instead.
        - ``E_node_stats`` is a list with one time-stacked array per
          statistic.

    Every ``zip``/``map`` result that is indexed, reversed or returned is
    materialized with ``list`` so this also works on Python 3, where they
    are lazy iterators.
    """
    prediction_messages, filter_messages = uninterleave(forward_messages)
    inhomog = depth(pair_params) == 2
    T = len(prediction_messages)
    pair_params, orig_pair_params = _repeat_param(pair_params, T - 1), pair_params
    # Indexed below (node_params[0][0]), so it must be a real list.
    node_params = list(zip(*_canonical_node_params(node_params)))

    def unit(filtered_message):
        J, h = filtered_message
        mu, Sigma = natural_to_mean(filtered_message)
        ExxT = Sigma + np.outer(mu, mu)
        return make_tuple(J, h, mu), [(mu, ExxT, 0.)]

    def bind(result, step):
        next_smooth, stats = result
        J, h, (mu, ExxT, ExxnT) = step(next_smooth)
        # Prepend so the collected list ends up in forward time order.
        return make_tuple(J, h, mu), [(mu, ExxT, ExxnT)] + stats

    def rts(next_pred, filtered, pair_param):
        def step(next_smooth):
            return natural_rts_backward_step(next_smooth, next_pred, filtered, pair_param)
        return step

    # reversed() needs a sequence, not a lazy map object.
    steps = reversed(list(
        map(rts, prediction_messages[1:], filter_messages[:-1], pair_params)))

    _, expected_stats = monad_runner(bind)(unit(filter_messages[-1]), steps)

    def make_init_stats(a):
        mu, ExxT, _ = a
        return (ExxT, mu, 1., 1.)[:len(init_params)]

    def make_pair_stats(a, b):
        (mu, ExxT, ExnxT), (mu_n, ExnxnT, _) = a, b
        return ExxT, ExnxT.T, ExnxnT, 1.

    is_diagonal = node_params[0][0].ndim == 1
    if is_diagonal:
        def make_node_stats(a):
            mu, ExxT, _ = a
            return np.diag(ExxT), mu, 1.
    else:
        def make_node_stats(a):
            mu, ExxT, _ = a
            return ExxT, mu, 1.

    E_init_stats = make_init_stats(expected_stats[0])
    # Materialized: returned to the caller as lists in the inhomogeneous case.
    E_pair_stats = list(map(make_pair_stats, expected_stats[:-1], expected_stats[1:]))
    E_node_stats = list(map(make_node_stats, expected_stats))

    if not inhomog:
        E_pair_stats = reduce(add, E_pair_stats, zeros_like(orig_pair_params))

    # Transpose per-time tuples into one stacked array per statistic.
    E_node_stats = [np.array(column) for column in zip(*E_node_stats)]

    return E_init_stats, E_pair_stats, E_node_stats
def natural_smooth_backward(forward_messages, natparam):
    """Backward pass over the filtered messages of a forward filtering run.

    The pass is seeded with the last filtered message ``(J, h)``. It then
    applies ``natural_rts_backward_step`` once per transition, from the last
    transition back to the first. Each result is fed to the next (earlier)
    step and prepended to the output, so the collected list is in forward
    time order.

    Args:
        forward_messages: interleaved prediction/filter messages; split with
            ``uninterleave``.
        natparam: natural parameters, unpacked with ``_unpack_repeated``.
            Only the pair parameters are used, and only their first three
            entries.

    Returns:
        A list with element ``[2]`` of every collected result (see the review
        note on the seed element below).
    """
    prediction_messages, filter_messages = uninterleave(forward_messages)
    _, pair_params, _ = _unpack_repeated(natparam)
    pair_params = list(map(itemgetter(0, 1, 2), pair_params))

    def unit(filtered_message):
        J, h = filtered_message
        return [(J, h)]

    def bind(result, step):
        # Prepend so the finished list runs forward in time.
        return [step(result[0])] + result

    def rts(next_prediction, filtered, pair_param):
        def step(next_smoothed):
            return natural_rts_backward_step(
                next_smoothed, next_prediction, filtered, pair_param)
        return step

    # Step t pairs prediction t+1 with filter t, so only filter messages
    # 0..T-2 are used. Passing all T would, under Python 2's map, pad the
    # extra step with None arguments.
    # Steps must run from the end backward, since the seed is the LAST
    # filtered message.
    steps = reversed(list(
        map(rts, prediction_messages[1:], filter_messages[:-1], pair_params)))

    results = monad_runner(bind)(unit(filter_messages[-1]), steps)
    # NOTE(review): assuming monad_runner returns the accumulated list, its
    # last element is the 2-tuple seed (J, h) from unit(). itemgetter(2) will
    # raise IndexError on it. Confirm what natural_rts_backward_step returns
    # here and how the seed should be represented.
    return list(map(itemgetter(2), results))
def natural_sample_backward_general(forward_messages, pair_params, num_samples=None):
    """Draw latent-state samples backward from the filtered messages.

    Odd-indexed entries of ``forward_messages`` (``[1::2]``) are taken as the
    filtered messages ``(J, h)``. The last state is sampled from the last
    filtered message, with ``num_samples`` forwarded to ``natural_sample``.
    Each earlier state is then sampled from its filtered message conditioned
    on the already-drawn next sample, via ``natural_condition_on`` with the
    first two pair parameters ``(J11, J12)``.

    Args:
        forward_messages: interleaved forward messages; only ``[1::2]`` is read.
        pair_params: transition parameters, repeated with ``_repeat_param``
            for every transition.
        num_samples: passed to ``natural_sample`` for the last state only.

    Returns:
        The samples in forward time order, each given a new leading axis and
        concatenated along it.
    """
    filtered_messages = forward_messages[1::2]
    pair_params = _repeat_param(pair_params, len(filtered_messages) - 1)
    pair_params = list(map(itemgetter(0, 1), pair_params))

    def unit(first_draw):
        return [first_draw]

    def bind(result, step):
        # Prepend so the finished list runs forward in time.
        return [step(result[0])] + result

    def make_step(pair_param, filtered_message):
        # Named factory instead of a tuple-parameter lambda, which is
        # Python-2-only syntax (removed by PEP 3113).
        J11, J12 = pair_param
        J_filt, h_filt = filtered_message

        def step(next_sample):
            return natural_sample(
                *natural_condition_on(J_filt, h_filt, next_sample, J11, J12))
        return step

    # reversed() needs a sequence; map is lazy on Python 3.
    steps = reversed(list(map(make_step, pair_params, filtered_messages[:-1])))
    last_sample = natural_sample(*filtered_messages[-1], num_samples=num_samples)
    samples = monad_runner(bind)(unit(last_sample), steps)
    return np.concatenate([draw[None, ...] for draw in samples])
def natural_sample_backward(forward_messages, natparam):
    """Draw one latent-state trajectory backward from the filtered messages.

    The last state is sampled from the last filtered message ``(J, h)``. Each
    earlier state is sampled from its filtered message conditioned on the
    already-drawn next sample, via ``natural_condition_on`` with the first two
    pair parameters ``(J11, J12)``.

    Args:
        forward_messages: interleaved prediction/filter messages; split with
            ``uninterleave``.
        natparam: natural parameters, unpacked with ``_unpack_repeated`` for
            ``len(filter_messages)`` time steps. Only the pair parameters are
            used.

    Returns:
        ``np.array`` of the samples in forward time order.
    """
    _, filter_messages = uninterleave(forward_messages)
    _, pair_params, _ = _unpack_repeated(natparam, len(filter_messages))
    pair_params = list(map(itemgetter(0, 1), pair_params))

    def unit(first_draw):
        return [first_draw]

    def bind(result, step):
        # Prepend so the finished list runs forward in time.
        return [step(result[0])] + result

    def make_step(pair_param, filtered_message):
        # Named factory instead of a tuple-parameter lambda (Python-2-only).
        J11, J12 = pair_param
        J_filt, h_filt = filtered_message

        def step(next_sample):
            return natural_sample(
                *natural_condition_on(J_filt, h_filt, next_sample, J11, J12))
        return step

    # reversed() needs a sequence; map is lazy on Python 3.
    steps = reversed(list(map(make_step, pair_params, filter_messages[:-1])))
    last_sample = natural_sample(*filter_messages[-1])
    samples = monad_runner(bind)(unit(last_sample), steps)
    return np.array(samples)
def natural_sample_backward_general(forward_messages, pair_params, num_samples=None):
    """Draw latent-state samples backward from the filtered messages.

    Odd-indexed entries of ``forward_messages`` (``[1::2]``) are taken as the
    filtered messages ``(J, h)``. The last state is sampled from the last
    filtered message, with ``num_samples`` forwarded to ``natural_sample``.
    Each earlier state is sampled from its filtered message conditioned on
    the already-drawn next sample, via ``natural_condition_on`` with the
    first two pair parameters ``(J11, J12)``.

    Args:
        forward_messages: interleaved forward messages; only ``[1::2]`` is read.
        pair_params: transition parameters, repeated with ``_repeat_param``
            for every transition.
        num_samples: passed to ``natural_sample`` for the last state only.

    Returns:
        The samples in forward time order, each given a new leading axis and
        concatenated along it.
    """
    filtered_messages = forward_messages[1::2]
    pair_params = _repeat_param(pair_params, len(filtered_messages) - 1)
    pair_params = list(map(itemgetter(0, 1), pair_params))

    def unit(first_draw):
        return [first_draw]

    def bind(result, step):
        # Prepend so the finished list runs forward in time.
        return [step(result[0])] + result

    def make_step(pair_param, filtered_message):
        # Named factory instead of a tuple-parameter lambda, which is
        # Python-2-only syntax (removed by PEP 3113).
        J11, J12 = pair_param
        J_filt, h_filt = filtered_message

        def step(next_sample):
            return natural_sample(
                *natural_condition_on(J_filt, h_filt, next_sample, J11, J12))
        return step

    # reversed() needs a sequence; map is lazy on Python 3.
    steps = reversed(list(map(make_step, pair_params, filtered_messages[:-1])))
    last_sample = natural_sample(*filtered_messages[-1], num_samples=num_samples)
    samples = monad_runner(bind)(unit(last_sample), steps)
    return np.concatenate([draw[None, ...] for draw in samples])
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Forward filtering pass over natural-parameter messages ``(J, h)``.

    The pass starts from the canonicalized initial parameters
    ``(J, h, logZ)``. It then runs one ``natural_condition_on_general`` step
    per node and one ``natural_predict`` step per transition, ordered by
    ``interleave``. Every step returns ``(new_message, term)``: the message
    is appended and the term is added to the running log normalizer.

    Args:
        init_params: initial-state parameters, canonicalized with
            ``_canonical_init_params``.
        pair_params: transition parameters, repeated with ``_repeat_param``
            for ``len(node_params) - 1`` transitions.
        node_params: node parameters, canonicalized and regrouped per time
            step.

    Returns:
        ``(messages, lognorm)``.
        - ``messages`` is every ``(J, h)`` message, starting with the initial
          one.
        - ``lognorm`` is ``logZ`` plus all step terms plus
          ``natural_lognorm`` of the final message.
    """
    init_params = _canonical_init_params(init_params)
    # A list, because len() is taken below (zip is lazy on Python 3).
    node_params = list(zip(*_canonical_node_params(node_params)))
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        return [(J, h)], logZ

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named factories replace Python-2-only tuple-parameter lambdas.
    def condition(node_param):
        def step(message):
            J, h = message
            return natural_condition_on_general(J, h, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    steps = interleave(list(map(condition, node_params)),
                       list(map(predict, pair_params)))
    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])
    return messages, lognorm
def natural_filter_forward(natparam, data):
    """Forward filtering pass over ``data`` using natural-parameter messages.

    The pass starts from ``(J_init, h_init)``. It then alternates, in the
    order produced by ``interleave``, two kinds of steps:
    - ``natural_condition_on``, once per observation row;
    - ``natural_predict``, once per transition.
    Each step returns ``(new_message, term)``; the message is appended and
    the term is accumulated.

    Args:
        natparam: natural parameters, unpacked with
            ``_unpack_repeated(natparam, T)``. The init part is
            ``(J_init, h_init, logZ_init)``.
        data: 2-D array unpacked as ``T, p = data.shape``; each row is passed
            to one conditioning step.

    Returns:
        ``(messages, lognorm)``.
        - ``messages`` is every ``(J, h)`` message, starting with the initial
          one.
        - ``lognorm`` is the sum of the step terms, ``natural_lognorm`` of the
          final message and ``logZ_init``, minus ``T * p / 2 * log(2 * pi)``.
    """
    T, p = data.shape
    init_params, pair_params, node_params = _unpack_repeated(natparam, T)

    def unit(J, h):
        return [(J, h)], 0.

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named factories replace Python-2-only tuple-parameter lambdas.
    def condition(node_param, y):
        def step(message):
            J, h = message
            return natural_condition_on(J, h, y, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    steps = interleave(list(map(condition, node_params, data)),
                       list(map(predict, pair_params)))
    J_init, h_init, logZ_init = init_params
    messages, lognorm = monad_runner(bind)(unit(J_init, h_init), steps)
    lognorm += natural_lognorm(*messages[-1]) + logZ_init
    # Divide by a float. Under Python 2 division `T*p/2` truncates, which
    # drops half a log(2*pi) whenever T*p is odd.
    return messages, lognorm - T * p / 2. * np.log(2 * np.pi)
def sample_backward(filtered_mus, filtered_sigmas, A, sigma_states):
    """Draw one state trajectory backward from filtered means and covariances.

    The last state is drawn with ``sample(filtered_mus[-1],
    filtered_sigmas[-1])``. Each earlier state is drawn with ``sample`` from
    the mean/covariance that ``condition_on`` returns when the filtered
    belief at that step is conditioned, through ``A`` and ``sigma_states``,
    on the already-drawn next state.

    Returns:
        ``np.array`` of the samples in forward time order.
    """
    def filtered_sampler(mu_filt, sigma_filt):
        def sample_cond(next_state):
            (mu_cond, sigma_cond), _ = condition_on(
                mu_filt, sigma_filt, A, next_state, sigma_states)
            return sample(mu_cond, sigma_cond)
        return sample_cond

    def unit(first_draw):
        return [first_draw]

    def bind(result, step):
        # Prepend so the finished list runs forward in time. The new draw is
        # not named `sample`, to avoid shadowing the sampling function.
        new_draw = step(result[0])
        return [new_draw] + result

    last_sample = sample(filtered_mus[-1], filtered_sigmas[-1])
    # reversed() needs a sequence; map is lazy on Python 3.
    steps = reversed(list(
        map(filtered_sampler, filtered_mus[:-1], filtered_sigmas[:-1])))
    samples = monad_runner(bind)(unit(last_sample), steps)
    return np.array(samples)
def filter_forward(data, mu_init, sigma_init, A, sigma_states, C, sigma_obs):
    """Forward filtering pass in mean/covariance form.

    The first observation ``data[0]`` conditions the initial belief
    ``(mu_init, sigma_init)`` directly, via ``condition_on`` with ``C`` and
    ``sigma_obs``. Each later observation first pushes the previous filtered
    belief through ``predict`` (with ``A`` and ``sigma_states``), then
    conditions on the observation the same way.

    Returns:
        ``((filtered_mus, filtered_sigmas), loglike)``.
        - ``filtered_mus`` and ``filtered_sigmas`` are lists with the
          filtered mean and covariance after each observation.
        - ``loglike`` is the sum of the ``ll`` terms from every
          ``condition_on`` call.
    """
    def step_for(y):
        # Build the filtering step for observation y: predict, then condition.
        def advance(mu_prev, sigma_prev):
            mu_pred, sigma_pred = predict(mu_prev, sigma_prev, A, sigma_states)
            (mu_new, sigma_new), step_ll = condition_on(
                mu_pred, sigma_pred, C, y, sigma_obs)
            return (mu_new, sigma_new), step_ll
        return advance

    def unit(mu, sigma):
        return ([mu], [sigma]), 0.

    def bind(state, step):
        (mus, sigmas), running_ll = state
        (mu, sigma), term = step(mus[-1], sigmas[-1])
        return (mus + [mu], sigmas + [sigma]), running_ll + term

    (mu_first, sigma_first), first_ll = condition_on(
        mu_init, sigma_init, C, data[0], sigma_obs)
    observation_steps = map(step_for, data[1:])
    (filtered_mus, filtered_sigmas), loglike = monad_runner(bind)(
        unit(mu_first, sigma_first), observation_steps)
    return (filtered_mus, filtered_sigmas), loglike + first_ll
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Forward filtering pass over natural-parameter messages ``(J, h)``.

    The pass starts from the canonicalized initial parameters
    ``(J, h, logZ)``. It then runs one ``natural_condition_on_general`` step
    per node and one ``natural_predict`` step per transition, ordered by
    ``interleave``. Every step returns ``(new_message, term)``: the message
    is appended and the term is added to the running log normalizer.

    Args:
        init_params: initial-state parameters, canonicalized with
            ``_canonical_init_params``.
        pair_params: transition parameters, repeated with ``_repeat_param``
            for ``len(node_params) - 1`` transitions.
        node_params: node parameters, canonicalized and regrouped per time
            step.

    Returns:
        ``(messages, lognorm)``.
        - ``messages`` is every ``(J, h)`` message, starting with the initial
          one.
        - ``lognorm`` is ``logZ`` plus all step terms plus
          ``natural_lognorm`` of the final message.
    """
    init_params = _canonical_init_params(init_params)
    # A list, because len() is taken below (zip is lazy on Python 3).
    node_params = list(zip(*_canonical_node_params(node_params)))
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        return [(J, h)], logZ

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named factories replace Python-2-only tuple-parameter lambdas.
    def condition(node_param):
        def step(message):
            J, h = message
            return natural_condition_on_general(J, h, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    steps = interleave(list(map(condition, node_params)),
                       list(map(predict, pair_params)))
    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])
    return messages, lognorm