Ejemplo n.º 1
0
def test_reduce_subset(dims, reduced_vars, op):
    """Check ``Tensor.reduce`` over a subset of the tensor's inputs.

    The expected data is built by applying the numeric counterpart of
    ``op`` (looked up in ``REDUCE_OP_TO_NUMERIC``) to the raw data, either
    over everything at once or one axis at a time.

    Args:
        dims: ordered names of the tensor's inputs (keys of ``sizes``).
        reduced_vars: iterable of names to reduce; names that are not in
            ``dims`` are ignored when building the expectation.
        op: the reduction op passed to ``Tensor.reduce``.
    """
    reduced_vars = frozenset(reduced_vars)
    sizes = {'a': 3, 'b': 4, 'c': 5}
    inputs = OrderedDict((name, bint(sizes[name])) for name in dims)
    data = rand(tuple(sizes[name] for name in dims)) + 0.5

    # Logical reductions run on uint8 data labelled with dtype 2.
    is_logical = op in [ops.and_, ops.or_]
    if is_logical:
        data = astype(data, 'uint8')
    dtype = 2 if is_logical else 'real'

    x = Tensor(data, inputs, dtype)
    actual = x.reduce(op, reduced_vars)
    expected_inputs = OrderedDict(
        (name, domain) for name, domain in inputs.items()
        if name not in reduced_vars)

    # Only the requested names that are actual inputs matter from here on.
    all_vars = frozenset(dims)
    effective_vars = reduced_vars & all_vars
    if not effective_vars:
        # Nothing to reduce: the very same object is expected back.
        assert actual is x
        return

    numeric_op = REDUCE_OP_TO_NUMERIC[op]
    if effective_vars == all_vars:
        data = numeric_op(data, None)
    else:
        # Reduce the highest axis first so that the positions of the
        # remaining axes are still valid after each step.
        axes = sorted((dims.index(name) for name in effective_vars),
                      reverse=True)
        for axis in axes:
            data = numeric_op(data, axis)

    check_funsor(actual, expected_inputs, Domain((), dtype))
    expected = Tensor(data, expected_inputs, dtype)
    assert_close(actual, expected, atol=1e-5, rtol=1e-5)
Ejemplo n.º 2
0
def test_to_funsor(shape, dtype):
    """Check ``funsor.to_funsor`` on raw array data.

    Converting the data must give a ``Tensor``; converting again with the
    matching ``reals(*shape)`` output must give back the identical object;
    an output with an extra leading dimension must raise ``ValueError``.
    """
    raw = astype(randn(shape), dtype)
    converted = funsor.to_funsor(raw)
    assert isinstance(converted, Tensor)

    # Same data with an explicit, matching output: identity is expected.
    matching_output = reals(*shape)
    assert funsor.to_funsor(raw, matching_output) is converted

    # An output whose shape does not match the data must be rejected.
    with pytest.raises(ValueError):
        funsor.to_funsor(raw, reals(5, *shape))
Ejemplo n.º 3
0
def test_binary_funsor_funsor(symbol, dims1, dims2):
    """Check a binary operation between two ``Tensor`` funsors.

    The expected data comes from applying ``binary_eval`` to the two
    tensors after ``align_tensors``; the actual result comes from applying
    ``binary_eval`` to the ``Tensor`` objects themselves.

    Args:
        symbol: the operator symbol understood by ``binary_eval``.
        dims1: ordered input names of the left operand.
        dims2: ordered input names of the right operand.
    """
    sizes = {'a': 3, 'b': 4, 'c': 5}
    is_boolean = symbol in BOOLEAN_OPS
    dtype = 2 if is_boolean else 'real'

    # Build the left operand first, then the right one, so that the random
    # draws happen in the same order for both operands on every run.
    operands = []
    for dims in (dims1, dims2):
        operand_inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
        operand_data = rand(tuple(sizes[d] for d in dims)) + 0.5
        if is_boolean:
            operand_data = astype(operand_data, 'uint8')
        operands.append(Tensor(operand_data, operand_inputs, dtype))
    x1, x2 = operands

    inputs, aligned = align_tensors(x1, x2)
    lhs, rhs = aligned[0], aligned[1]
    expected_data = binary_eval(symbol, lhs, rhs)

    actual = binary_eval(symbol, x1, x2)
    check_funsor(actual, inputs, Domain((), dtype), expected_data)
Ejemplo n.º 4
0
def test_reduce_all(dims, op):
    """Check ``Tensor.reduce(op)`` when no explicit variables are given.

    The expected data is the numeric counterpart of ``op`` (looked up in
    ``REDUCE_OP_TO_NUMERIC``) applied to the raw data with axis ``None``.
    The result is expected to have no inputs left.

    Args:
        dims: ordered names of the tensor's inputs (keys of ``sizes``).
        op: the reduction op passed to ``Tensor.reduce``.
    """
    sizes = {'a': 3, 'b': 4, 'c': 5}
    shape = tuple(sizes[d] for d in dims)
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    data = rand(shape) + 0.5

    # Logical reductions run on uint8 data. Label that data with dtype 2
    # (as test_reduce_subset does) instead of leaving the tensor declared
    # as 'real' while it actually holds uint8 values.
    dtype = 'real'
    if op in [ops.and_, ops.or_]:
        data = astype(data, 'uint8')
        dtype = 2
    expected_data = REDUCE_OP_TO_NUMERIC[op](data, None)

    x = Tensor(data, inputs, dtype)
    actual = x.reduce(op)

    # The scalar output domain must carry the same dtype as the tensor.
    expected_output = bint(2) if dtype == 2 else reals()
    check_funsor(actual, {}, expected_output, expected_data)
Ejemplo n.º 5
0
def test_reduce_event(op, event_shape, dims):
    """Check the reduction methods of ``Tensor`` that act on the event shape.

    The data has shape ``batch_shape + event_shape``. The expected result
    flattens the trailing event dimensions into one axis and reduces over
    it with the numeric counterpart of ``op``; all inputs are expected to
    be kept.

    Args:
        op: the reduction op, used as a key into ``REDUCE_OP_TO_NUMERIC``.
        event_shape: tuple appended to the batch shape of the data.
        dims: ordered names of the tensor's inputs (keys of ``sizes``).
    """
    sizes = {'a': 3, 'b': 4, 'c': 5}
    batch_shape = tuple(sizes[d] for d in dims)
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    numeric_op = REDUCE_OP_TO_NUMERIC[op]

    data = rand(batch_shape + event_shape) + 0.5
    # NOTE(review): dtype stays 'real' even when the data is cast to uint8
    # below; other tests switch to dtype 2 here - confirm this is intended.
    dtype = 'real'
    if op in [ops.and_, ops.or_]:
        data = astype(data, 'uint8')

    # Collapse every event dimension into one trailing axis and reduce it.
    flattened = data.reshape(batch_shape + (-1, ))
    expected_data = numeric_op(flattened, -1)

    x = Tensor(data, inputs, dtype=dtype)

    # The method on Tensor is named after the numeric function; for min and
    # max the first character of the numeric function's name is dropped.
    method_name = numeric_op.__name__
    if op in [ops.min, ops.max]:
        method_name = method_name[1:]
    reduce_method = getattr(x, method_name)
    actual = reduce_method()

    check_funsor(actual, inputs, Domain((), dtype), expected_data)
Ejemplo n.º 6
0
def test_unary(symbol, dims):
    """Check a unary operation applied to a ``Tensor``.

    The expected data is computed on the raw data: through the function of
    the same name in ``ops`` for a fixed list of symbols when the backend
    is not "torch", and through ``unary_eval`` otherwise. The actual result
    is ``unary_eval`` applied to the ``Tensor`` itself.

    Args:
        symbol: the operator symbol or function name to apply.
        dims: ordered names of the tensor's inputs (keys of ``sizes``).
    """
    sizes = {'a': 3, 'b': 4}
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    data = rand(tuple(sizes[d] for d in dims)) + 0.5

    # Bitwise negation runs on uint8 data labelled with dtype 2.
    dtype = 'real'
    if symbol == '~':
        data, dtype = astype(data, 'uint8'), 2

    # These symbols are evaluated through ``ops`` on non-torch backends.
    ops_symbols = ("abs", "sqrt", "exp", "log", "log1p", "sigmoid")
    evaluate_with_ops = get_backend() != "torch" and symbol in ops_symbols
    if evaluate_with_ops:
        numeric_fn = getattr(ops, symbol)
        expected_data = numeric_fn(data)
    else:
        expected_data = unary_eval(symbol, data)

    x = Tensor(data, inputs, dtype)
    actual = unary_eval(symbol, x)
    expected_output = funsor.Domain((), dtype)
    check_funsor(actual, inputs, expected_output, expected_data)