예제 #1
0
    def compare_sampler_grads(lds, num_samples, seed):
        """Assert that two backward-sampler implementations give matching gradients.

        ``lds`` is unpacked as ``(init_params, pair_params, node_params)``.
        For each implementation, filter messages are computed, and ``grad`` is
        taken of ``np.sum(np.sin(samples))`` with respect to those messages.
        The global ``npr`` RNG is re-seeded with ``seed`` on every evaluation,
        so both samplers start from the same random state. The second
        implementation's gradients are then unpacked with ``interleave`` into
        the layout of the first, and the two are compared with ``allclose``.
        """
        init_params, pair_params, node_params = lds

        def seeded_sin_sum(sampler):
            # Build an objective that resets the global RNG before drawing,
            # so repeated evaluations reuse the same random stream.
            def objective(msgs):
                npr.seed(seed)
                draws = sampler(msgs, pair_params, num_samples)
                return np.sum(np.sin(draws))
            return objective

        def to_general_layout(dense):
            # Zip each component of `dense` into per-step pairs, then
            # interleave the resulting sequences.
            return interleave(*[zip(*part) for part in dense])

        general_msgs, _ = natural_filter_forward_general(
            init_params, pair_params, node_params)
        general_grads = grad(seeded_sin_sum(natural_sample_backward_general))(
            general_msgs)

        dense_msgs, _ = _natural_filter_forward_general(
            init_params, pair_params, node_params)
        dense_grads = grad(seeded_sin_sum(_natural_sample_backward))(dense_msgs)

        assert allclose(general_grads, to_general_layout(dense_grads))
예제 #2
0
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Run forward filtering in natural parameters and collect every message.

    The initial message ``(J, h)`` and starting log-normalizer ``logZ`` come
    from ``_canonical_init_params(init_params)``. The function then applies a
    sequence of steps produced by ``interleave``: conditioning steps built
    from the per-step node parameters and prediction steps built from the
    (repeated) pair parameters. Each step maps the latest ``(J, h)`` message
    to ``(new_message, term)``. The new message is appended and ``term`` is
    added to the running log-normalizer.

    Returns:
        ``(messages, lognorm)``. ``messages`` is the list of all messages,
        starting with the initial one. ``lognorm`` is the initial ``logZ``
        plus every step's term plus ``natural_lognorm`` of the final message.
    """
    init_params = _canonical_init_params(init_params)
    # list(...) so len() below also works on Python 3, where zip returns an
    # iterator; on Python 2 this is just a copy of the list zip already made.
    node_params = list(zip(*_canonical_node_params(node_params)))
    # One pair parameter per transition between consecutive nodes.
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        return [(J, h)], logZ

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named closures instead of `lambda (J, h): ...`: tuple-parameter
    # unpacking is a SyntaxError on Python 3 (PEP 3113). The message is
    # unpacked inside the body instead, which behaves the same on Python 2.
    def condition(node_param):
        def step(message):
            J, h = message
            return natural_condition_on_general(J, h, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    # Materialize with list(...) so Python 3 behaves like Python 2's map.
    steps = interleave(list(map(condition, node_params)),
                       list(map(predict, pair_params)))

    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])

    return messages, lognorm
예제 #3
0
def natural_filter_forward(natparam, data):
    """Run forward filtering in natural parameters over a ``(T, p)`` data array.

    ``natparam`` is split by ``_unpack_repeated(natparam, T)`` into initial,
    pair, and node parameters. Starting from the initial ``(J, h)``
    message, the function applies a sequence of steps built by ``interleave``:
    conditioning steps (one per node parameter / data row) and prediction
    steps (one per pair parameter). Each step maps the latest message to
    ``(new_message, term)``. The new message is appended and ``term`` is
    accumulated.

    Returns:
        ``(messages, lognorm)``. ``messages`` is the list of all messages,
        starting with the initial one. ``lognorm`` is the accumulated step
        terms, plus ``natural_lognorm`` of the final message, plus
        ``logZ_init``, minus ``0.5 * T * p * log(2*pi)``.
    """
    T, p = data.shape
    init_params, pair_params, node_params = _unpack_repeated(natparam, T)

    def unit(J, h):
        return [(J, h)], 0.

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named closures instead of `lambda (J, h): ...`: tuple-parameter
    # unpacking is a SyntaxError on Python 3 (PEP 3113).
    def condition(node_param, y):
        def step(message):
            J, h = message
            return natural_condition_on(J, h, y, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    # list(...) keeps Python 2's eager map semantics on Python 3 as well.
    steps = interleave(list(map(condition, node_params, data)),
                       list(map(predict, pair_params)))

    J_init, h_init, logZ_init = init_params
    messages, lognorm = monad_runner(bind)(unit(J_init, h_init), steps)
    lognorm += natural_lognorm(*messages[-1]) + logZ_init

    # Use 0.5 * T * p, not T*p/2: T and p are ints from data.shape. Under
    # Python 2 division (no `from __future__ import division`), T*p/2
    # truncates when T*p is odd and yields the wrong normalizing constant.
    return messages, lognorm - 0.5 * T * p * np.log(2 * np.pi)
예제 #4
0
    def compare_sampler_grads(lds, num_samples, seed):
        """Check gradient agreement between the general and dense samplers.

        ``lds`` unpacks to ``(init_params, pair_params, node_params)``. Each
        implementation pair (forward filter, backward sampler) is used to
        compute filter messages. ``grad`` is then taken of
        ``np.sum(np.sin(samples))`` with respect to those messages, with
        ``npr`` re-seeded by ``seed`` before every draw. The dense gradients
        are unpacked via ``interleave`` and compared with ``allclose``.
        """
        init_params, pair_params, node_params = lds

        def sampler_grads(filter_forward, sample_backward):
            # Messages from this implementation's filter, then the gradient
            # of the seeded sin-sum objective through its sampler.
            msgs, _ = filter_forward(init_params, pair_params, node_params)

            def objective(m):
                npr.seed(seed)
                return np.sum(np.sin(sample_backward(m, pair_params, num_samples)))

            return grad(objective)(msgs)

        grads_general = sampler_grads(natural_filter_forward_general,
                                      natural_sample_backward_general)
        grads_dense = sampler_grads(_natural_filter_forward_general,
                                    _natural_sample_backward)

        assert allclose(grads_general,
                        interleave(*[zip(*part) for part in grads_dense]))
예제 #5
0
def natural_filter_forward_general(init_params, pair_params, node_params):
    """Run forward filtering in natural parameters and collect every message.

    The initial message ``(J, h)`` and starting log-normalizer ``logZ`` come
    from ``_canonical_init_params(init_params)``. The function then applies a
    sequence of steps produced by ``interleave``: conditioning steps built
    from the per-step node parameters and prediction steps built from the
    (repeated) pair parameters. Each step maps the latest ``(J, h)`` message
    to ``(new_message, term)``. The new message is appended and ``term`` is
    added to the running log-normalizer.

    Returns:
        ``(messages, lognorm)``. ``messages`` is the list of all messages,
        starting with the initial one. ``lognorm`` is the initial ``logZ``
        plus every step's term plus ``natural_lognorm`` of the final message.
    """
    init_params = _canonical_init_params(init_params)
    # list(...) so len() below also works on Python 3, where zip returns an
    # iterator; on Python 2 this is just a copy of the list zip already made.
    node_params = list(zip(*_canonical_node_params(node_params)))
    # One pair parameter per transition between consecutive nodes.
    pair_params = _repeat_param(pair_params, len(node_params) - 1)

    def unit(J, h, logZ):
        return [(J, h)], logZ

    def bind(result, step):
        messages, lognorm = result
        new_message, term = step(messages[-1])
        return messages + [new_message], lognorm + term

    # Named closures instead of `lambda (J, h): ...`: tuple-parameter
    # unpacking is a SyntaxError on Python 3 (PEP 3113). The message is
    # unpacked inside the body instead, which behaves the same on Python 2.
    def condition(node_param):
        def step(message):
            J, h = message
            return natural_condition_on_general(J, h, *node_param)
        return step

    def predict(pair_param):
        def step(message):
            J, h = message
            return natural_predict(J, h, *pair_param)
        return step

    # Materialize with list(...) so Python 3 behaves like Python 2's map.
    steps = interleave(list(map(condition, node_params)),
                       list(map(predict, pair_params)))

    messages, lognorm = monad_runner(bind)(unit(*init_params), steps)
    lognorm += natural_lognorm(*messages[-1])

    return messages, lognorm
예제 #6
0
def unpack_dense_messages(messages):
    """Turn dense-format messages into one interleaved sequence of (J, h) pairs.

    ``messages`` is unpacked as ``((J_predict, h_predict),
    (J_filtered, h_filtered))``. Each half's ``J`` and ``h`` sequences are
    zipped into pairs. The prediction pairs and the filtered pairs are then
    passed to ``interleave``, in that order.
    """
    predicted, filtered = messages
    J_pred, h_pred = predicted
    J_filt, h_filt = filtered
    return interleave(zip(J_pred, h_pred), zip(J_filt, h_filt))
예제 #7
0
def unpack_dense_messages(messages):
    """Interleave prediction and filtered messages from their dense form.

    Expects ``messages`` to unpack as ``((J_predict, h_predict),
    (J_filtered, h_filtered))``. Returns ``interleave`` applied to the
    zipped prediction pairs followed by the zipped filtered pairs.
    """
    (predict_Js, predict_hs), (filter_Js, filter_hs) = messages
    return interleave(
        zip(predict_Js, predict_hs),  # prediction messages
        zip(filter_Js, filter_hs),  # filtered messages
    )