Exemple #1
0
def test_categorical_smoothed_node_data_smooth():
    """Compare data-driven smoothing of a categorical node with expected values.

    First checks that ``CategoricalSmoothedNode.smooth_freq_from_data`` on a
    single-column dataset (two 0s, three 1s) agrees with ``smooth_ll`` applied
    to ``[2 / 5, 3 / 5]`` under the same ``alpha``.  Then builds a node on
    variable 0 with ``instances={0, 2, 4}`` and checks that, after
    ``smooth_probs`` is called with either dataset, ``_var_probs`` agrees with
    ``smooth_ll([0, 1], alpha)``.
    """
    # one binary column: value 0 appears twice, value 1 three times
    one_column = numpy.array([[1], [0], [1], [0], [1]])

    # two binary columns; column 0 is the same as the single column above
    two_columns = numpy.array([
        [1, 0],
        [0, 1],
        [1, 1],
        [0, 1],
        [1, 0],
    ])

    alpha = 0

    # --- frequencies estimated straight from the data ---
    observed_freqs = CategoricalSmoothedNode.smooth_freq_from_data(one_column,
                                                                   alpha)
    print('freqs', observed_freqs)

    wanted_freqs = CategoricalSmoothedNode.smooth_ll([2 / 5, 3 / 5], alpha)
    print('exp freqs', wanted_freqs)
    assert_array_almost_equal(wanted_freqs, observed_freqs)

    # --- same estimate, but going through a node object ---
    # Rows 0, 2 and 4 all hold the value 1 in column 0 of both datasets.
    # NOTE(review): presumably `instances` picks the rows used for the
    # estimate, which would explain the expected [0, 1] -- confirm.
    node = CategoricalSmoothedNode(var=0,
                                   var_values=2,
                                   instances={0, 2, 4})

    node.smooth_probs(alpha, data=one_column)
    wanted_probs = CategoricalSmoothedNode.smooth_ll([0, 1], alpha)
    print('exp probs', wanted_probs)
    print('probs', node._var_probs)
    assert_log_array_almost_equal(wanted_probs, node._var_probs)

    # smoothing again on the wider dataset must give the same result
    node.smooth_probs(alpha, data=two_columns)
    assert_log_array_almost_equal(wanted_probs, node._var_probs)
Exemple #2
0
def test_categorical_smoothed_node_resmooth():
    """Check that a re-smoothed node still matches the reference log-values.

    For every entry of ``vars`` a node is built with ``alphas[0]`` and the
    matching entry of ``freqs``, evaluated on the matching entry of ``obs``
    and compared with ``compute_smoothed_ll``.  The node is then re-smoothed
    with each value of ``alphas`` in turn and the comparison is repeated.

    ``vars``, ``alphas``, ``freqs`` and ``obs`` are not defined in this
    function; they are expected to exist at module level (``vars`` in
    particular must shadow the builtin, which is not iterable).
    """

    def evaluate_and_compare(node, observation, frequencies, n_values,
                             smoothing):
        # Evaluate the node, print both sides, then compare to 15 decimals.
        node.eval(observation)
        print('smo values')
        print(node.log_val)
        reference_ll = compute_smoothed_ll(observation, frequencies,
                                           n_values, smoothing)
        print('log values')
        print(reference_ll)
        assert_almost_equal(reference_ll, node.log_val, 15)

    for index, n_values in enumerate(vars):
        starting_alpha = alphas[0]
        frequencies = freqs[index]
        node = CategoricalSmoothedNode(index, n_values, starting_alpha,
                                       frequencies)

        # the freshly built node has to match the reference first
        evaluate_and_compare(node, obs[index], frequencies, n_values,
                             starting_alpha)

        # then sweep over every smoothing level on the same node
        print('Changing smooth level')
        for next_alpha in alphas:
            node.smooth_probs(next_alpha)
            evaluate_and_compare(node, obs[index], frequencies, n_values,
                                 next_alpha)