Пример #1
0
def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
    """Check the ``local_subtensor_rv_lift`` rewrite on an indexed random variable.

    ``apply_local_opt_to_rv`` builds ``dist_op(*dist_params, size=size)`` and
    indexes it with ``indices``. It then applies ``local_subtensor_rv_lift``.

    * ``lifted`` true: the rewritten output must be produced directly by a
      ``RandomVariable``. Every one of its inputs from position 3 onward that
      has an owner must be an (Advanced)Subtensor op. The optimized function
      must also agree numerically (``rtol=1e-3``) with a ``no_mode``
      compilation of the unrewritten graph.
    * ``lifted`` false: the output must still be an (Advanced)Subtensor op,
      and no numeric comparison is done.
    """
    from aesara.tensor.subtensor import as_index_constant

    subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)
    rng = shared(np.random.RandomState(1233532), borrow=False)

    # Turn the raw Python indices into graph-level index objects. Slices are
    # kept as-is; everything else records the raw index as its test value.
    idx_vars = []
    for raw_idx in indices:
        idx_var = as_index_constant(raw_idx)
        if not isinstance(idx_var, slice):
            idx_var.tag.test_value = raw_idx
        idx_vars.append(idx_var)
    idx_vars = tuple(idx_vars)

    def index_rv(rv):
        return rv[idx_vars]

    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
        local_subtensor_rv_lift, index_rv, dist_op, dist_params, size, rng
    )

    out_op = new_out.owner.op
    if not lifted:
        # The rewrite must have left the indexing operation on top.
        assert isinstance(out_op, subtensor_ops)
        return

    assert isinstance(out_op, RandomVariable)
    # Inputs 3+ are presumably the distribution parameters (0-2 being
    # rng/size/dtype) -- NOTE(review): confirm against RandomVariable.
    assert all(
        isinstance(inp.owner.op, subtensor_ops)
        for inp in new_out.owner.inputs[3:]
        if inp.owner
    )

    f_base = function(f_inputs, dist_st, mode=no_mode)

    # Evaluate the unoptimized graph first, then the optimized one.
    test_args = [inp.get_test_value() for inp in f_inputs]
    np.testing.assert_allclose(f_base(*test_args), f_opt(*test_args), rtol=1e-3)
Пример #2
0
def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
    """Check the ``local_subtensor_rv_lift`` rewrite on an indexed random variable.

    The function first builds symbolic inputs, each carrying a test value:

    * one per entry of ``dist_params``, typed after the raw value;
    * one ``iscalar`` per entry of ``size``;
    * index objects from ``indices``.

    It then compiles ``dist_op(*params, size=sizes, rng=rng)[indices]``. The
    only rewrite applied is ``local_subtensor_rv_lift``, driven by an
    ``EquilibriumOptimizer``.

    * ``lifted`` true: the compiled output must be produced directly by a
      ``RandomVariable``. Every one of its inputs from position 3 onward that
      has an owner must be an (Advanced)Subtensor op. Results must match a
      ``no_mode`` compilation (``rtol=1e-3``).
    * ``lifted`` false: the output must still be an (Advanced)Subtensor op,
      and no numeric comparison is done.
    """
    from aesara.tensor.subtensor import as_index_constant

    subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)
    rng = shared(np.random.RandomState(1233532), borrow=False)

    def with_test_value(var, value):
        # Attach the concrete value used later to evaluate the compiled graphs.
        var.tag.test_value = value
        return var

    # Fresh symbolic variables with the same type as each raw parameter.
    param_vars = [with_test_value(aet.as_tensor(p).type(), p) for p in dist_params]
    # One symbolic integer scalar per size entry.
    size_vars = [with_test_value(iscalar(), s) for s in size]

    # Slices are kept as-is; every other index records its raw value.
    idx_vars = []
    for raw_idx in indices:
        idx_var = as_index_constant(raw_idx)
        if not isinstance(idx_var, slice):
            idx_var.tag.test_value = raw_idx
        idx_vars.append(idx_var)
    idx_vars = tuple(idx_vars)

    dist_st = dist_op(*param_vars, size=size_vars, rng=rng)[idx_vars]

    # Slices and Constants are part of the graph itself, so only the
    # remaining variables become function inputs.
    candidates = [*param_vars, *size_vars, *idx_vars]
    f_inputs = [v for v in candidates if not isinstance(v, (slice, Constant))]

    rewrite = EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100)
    f_opt = function(f_inputs, dist_st, mode=Mode("py", rewrite))

    (new_out,) = f_opt.maker.fgraph.outputs
    out_op = new_out.owner.op

    if not lifted:
        # The rewrite must have left the indexing operation on top.
        assert isinstance(out_op, subtensor_ops)
        return

    assert isinstance(out_op, RandomVariable)
    # Inputs 3+ are presumably the distribution parameters (0-2 being
    # rng/size/dtype) -- NOTE(review): confirm against RandomVariable.
    assert all(
        isinstance(inp.owner.op, subtensor_ops)
        for inp in new_out.owner.inputs[3:]
        if inp.owner
    )

    f_base = function(f_inputs, dist_st, mode=no_mode)

    # Evaluate the unoptimized graph first, then the optimized one.
    test_args = [v.get_test_value() for v in f_inputs]
    np.testing.assert_allclose(f_base(*test_args), f_opt(*test_args), rtol=1e-3)