Example #1
0
def test_normal_infer_shape():
    """Check that ``normal._infer_shape`` agrees with the shape of a draw.

    Each case pairs the positional arguments passed to ``normal`` (mean and
    standard deviation) with a ``size`` value, where ``None`` means no
    explicit size. For every case, the shape returned by ``_infer_shape`` and
    the random variable itself are evaluated through their test values, and
    the two shapes must match.
    """
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3
    sd_aet = scalar("sd")
    sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)

    def _scalar_mean():
        # Build a fresh constant scalar mean for each case that needs one.
        return aet.as_tensor_variable(np.array(1.0, dtype=config.floatX))

    test_params = [
        ([_scalar_mean(), sd_aet], None),
        ([_scalar_mean(), sd_aet], (M_aet,)),
        ([_scalar_mean(), sd_aet], (2, M_aet)),
        ([aet.zeros((M_aet,)), sd_aet], None),
        ([aet.zeros((M_aet,)), sd_aet], (M_aet,)),
        ([aet.zeros((M_aet,)), sd_aet], (2, M_aet)),
        ([aet.zeros((M_aet,)), aet.ones((M_aet,))], None),
        ([aet.zeros((M_aet,)), aet.ones((M_aet,))], (2, M_aet)),
        (
            [
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ],
            (3, 2, 2),
        ),
        (
            [
                np.array([1], dtype=config.floatX),
                np.array([10], dtype=config.floatX),
            ],
            (1, 2),
        ),
    ]

    for args, size in test_params:
        rv = normal(*args, size=size)
        # ``_infer_shape`` is given an empty tuple when no size is requested.
        rv_shape = tuple(normal._infer_shape(size or (), args, None))
        inferred = tuple(get_test_value(rv_shape))
        sampled = tuple(get_test_value(rv).shape)
        # Include the failing case in the message: this loop-driven test
        # would otherwise give no hint of which parametrization broke.
        assert inferred == sampled, (
            f"inferred shape {inferred} != sampled shape {sampled} "
            f"for args={args!r}, size={size!r}"
        )
Example #2
0
def test_normal_infer_shape(M, sd, size):
    """Check ``normal._infer_shape`` against a compiled draw of ``normal``.

    Parameters
    ----------
    M, sd
        Mean and standard deviation passed to ``normal``. They may be symbolic
        ``Variable`` objects or non-variable values; only the former are
        searched for graph inputs.
    size
        The ``size`` passed to ``normal``: ``None`` (no explicit size) or a
        sequence whose entries may be symbolic.

    The inferred shape entries and the random draw are compiled into a single
    function. Its inputs are the graph inputs that are neither constants nor
    shared variables, and each input is fed its ``tag.test_value``. The
    evaluated inferred shape must equal the shape of the evaluated draw.
    """
    rv = normal(M, sd, size=size)
    rv_shape = list(normal._infer_shape(size or (), [M, sd], None))

    # Normalize ``size`` the same way as above. ``(M, sd) + None`` (or
    # ``+ [list]``) would raise ``TypeError`` before the check even runs.
    all_args = (M, sd) + tuple(size or ())
    fn_inputs = [
        i
        for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
        if not isinstance(i, (Constant, SharedVariable))
    ]
    aesara_fn = function(fn_inputs, [at.as_tensor(o) for o in rv_shape + [rv]],
                         mode=py_mode)

    # ``fn_inputs`` already excludes constants and shared variables, so it can
    # be fed directly. This presumes the parametrization sets
    # ``tag.test_value`` on every remaining input (TODO confirm).
    *rv_shape_val, rv_val = aesara_fn(*[i.tag.test_value for i in fn_inputs])

    assert tuple(rv_shape_val) == tuple(rv_val.shape)