def test_ubersum_total(equation, plates):
    """Compare the first output of ``ubersum(..., modulo_total=True)`` with ``naive_ubersum``.

    The example is built by ``make_example`` with ``fill=1`` and ``sizes=(2,)``.
    Only the first output of each function is checked. Both results go through
    ``_normalize`` before comparison (presumably to factor out the total that
    ``modulo_total=True`` is allowed to differ by -- TODO confirm against
    ``_normalize``).
    """
    _, out_specs, tensors, _ = make_example(equation, fill=1, sizes=(2,))
    target_spec = out_specs[0]

    # Keep the original call order: naive first, then ubersum, then normalize each.
    reference = naive_ubersum(equation, *tensors, plates=plates)[0]
    candidate = ubersum(equation, *tensors, plates=plates, modulo_total=True)[0]
    reference = _normalize(reference, target_spec, plates)
    candidate = _normalize(candidate, target_spec, plates)

    failure_msg = u"Expected:\n{}\nActual:\n{}".format(
        reference.detach().cpu(), candidate.detach().cpu())
    assert_equal(reference, candidate, msg=failure_msg)
def test_ubersum(equation, plates):
    """Check every output of ``ubersum(..., modulo_total=True)`` against ``naive_ubersum``.

    The test is skipped when ``ubersum`` raises ``NotImplementedError`` for the
    given equation/plates combination. Otherwise the result must be a tuple
    with one entry per output in the example. Each entry is compared to the
    naive result after both pass through ``_normalize``.
    """
    _, out_specs, tensors, _ = make_example(equation)
    try:
        results = ubersum(equation, *tensors, plates=plates, modulo_total=True)
    except NotImplementedError:
        # Unsupported cases are skipped rather than failed.
        pytest.skip()

    assert isinstance(results, tuple)
    assert len(results) == len(out_specs)

    references = naive_ubersum(equation, *tensors, plates=plates)
    for spec, ref, got in zip(out_specs, references, results):
        got = _normalize(got, spec, plates)
        ref = _normalize(ref, spec, plates)
        message = u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
            spec, ref.detach().cpu(), got.detach().cpu())
        assert_equal(ref, got, msg=message)
def test_naive_ubersum(equation, plates):
    """Sanity-check the outputs of ``naive_ubersum``.

    The result must be a tuple with one entry per output in the example. Each
    entry's shape must match the example's sizes for that output's dims. When
    ``plates`` is falsy, each entry is also compared against
    ``opt_einsum.contract`` over the same operands, using the
    ``pyro.ops.einsum.torch_log`` backend.
    """
    in_specs, out_specs, tensors, dim_sizes = make_example(equation)
    results = naive_ubersum(equation, *tensors, plates=plates)

    assert isinstance(results, tuple)
    assert len(results) == len(out_specs)

    for spec, got in zip(out_specs, results):
        assert got.shape == tuple(dim_sizes[d] for d in spec)

        # The plain-einsum cross-check only applies when there are no plates.
        if plates:
            continue

        single_eq = ','.join(in_specs) + '->' + spec
        ref = opt_einsum.contract(single_eq, *tensors,
                                  backend='pyro.ops.einsum.torch_log')
        assert_equal(ref, got, msg=u"For output '{}':\nExpected:\n{}\nActual:\n{}".format(
            spec, ref.detach().cpu(), got.detach().cpu()))