Esempio n. 1
0
def test_ubersum_total(equation, plates):
    """Check ubersum(..., modulo_total=True) against naive_ubersum on the first output.

    The example is built with ``fill=1`` and ``sizes=(2,)``. Only the first
    element of each result tuple is compared. Both sides are passed through
    ``_normalize`` (with the first output string and ``plates``) before
    ``assert_equal``.

    ``equation`` and ``plates`` are presumably supplied by pytest
    parametrization outside this view; confirm against the decorator.
    """
    _, outputs, operands, _ = make_example(equation, fill=1, sizes=(2,))
    target = outputs[0]

    # Reference first, then the implementation under test, matching the
    # original evaluation order.
    raw_expected = naive_ubersum(equation, *operands, plates=plates)[0]
    raw_actual = ubersum(equation, *operands, plates=plates, modulo_total=True)[0]
    expected, actual = (_normalize(raw_expected, target, plates),
                        _normalize(raw_actual, target, plates))

    report = u"Expected:\n{}\nActual:\n{}".format(
        expected.detach().cpu(), actual.detach().cpu())
    assert_equal(expected, actual, msg=report)
Esempio n. 2
0
def test_ubersum(equation, plates):
    """Compare each output of ubersum(..., modulo_total=True) with naive_ubersum.

    The test is skipped when ``ubersum`` raises ``NotImplementedError`` for
    the given equation and plates. Otherwise the result must be a tuple with
    one entry per output. Each entry is normalized with ``_normalize`` and
    checked against the matching ``naive_ubersum`` entry.
    """
    _, outputs, operands, _ = make_example(equation)

    try:
        actual = ubersum(equation, *operands, plates=plates, modulo_total=True)
    except NotImplementedError:
        # Unsupported case for ubersum: skip rather than fail.
        pytest.skip()

    assert isinstance(actual, tuple)
    assert len(actual) == len(outputs)

    expected = naive_ubersum(equation, *operands, plates=plates)
    for output, want, got in zip(outputs, expected, actual):
        got = _normalize(got, output, plates)
        want = _normalize(want, output, plates)
        report = u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
            output, want.detach().cpu(), got.detach().cpu())
        assert_equal(want, got, msg=report)
Esempio n. 3
0
def test_naive_ubersum(equation, plates):
    """Check naive_ubersum output shapes, and values too when there are no plates.

    Each output's shape must equal the ``sizes`` of its dimension letters.
    When ``plates`` is falsy, each output is also compared against
    ``opt_einsum.contract`` on ``'<inputs>-><output>'`` using the
    ``'pyro.ops.einsum.torch_log'`` backend.
    """
    inputs, outputs, operands, sizes = make_example(equation)

    results = naive_ubersum(equation, *operands, plates=plates)

    assert isinstance(results, tuple)
    assert len(results) == len(outputs)
    for output, result in zip(outputs, results):
        wanted_shape = tuple(sizes[dim] for dim in output)
        assert result.shape == wanted_shape
        if plates:
            # With plates there is no plain-einsum reference to compare against.
            continue
        subscripts = ','.join(inputs) + '->' + output
        reference = opt_einsum.contract(subscripts, *operands,
                                        backend='pyro.ops.einsum.torch_log')
        report = u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
            output, reference.detach().cpu(), result.detach().cpu())
        assert_equal(reference, result, msg=report)