def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
    """Exercise ``local_subtensor_rv_lift`` on an indexed random variable.

    The indexed graph is built and optimized by ``apply_local_opt_to_rv``;
    this test then inspects the op that owns the optimized output and, when
    the rewrite is expected to apply, compares numerical results against an
    unoptimized compilation of the same graph.

    Parameters
    ----------
    indices
        Iterable of index values; each one is converted with
        ``as_index_constant`` and the resulting tuple is used to index the
        random variable.
    lifted
        Truthy when the optimized output is expected to be owned by a
        ``RandomVariable`` op; falsy when it is expected to remain a
        subtensor op.
    dist_op, dist_params, size
        Forwarded unchanged to ``apply_local_opt_to_rv``.
    """
    # NOTE(review): SOURCE shows a second function with this exact name right
    # after this one; if both live in the same module the later definition
    # replaces this one -- confirm which variant is meant to be kept.
    from aesara.tensor.subtensor import as_index_constant

    subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)

    def _as_symbolic_index(raw_index):
        # Slices carry no ``tag``; every other index gets the raw value
        # attached as its test value.
        converted = as_index_constant(raw_index)
        if not isinstance(converted, slice):
            converted.tag.test_value = raw_index
        return converted

    rng = shared(np.random.RandomState(1233532), borrow=False)
    indices_aet = tuple(_as_symbolic_index(idx) for idx in indices)

    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
        local_subtensor_rv_lift,
        lambda rv: rv[indices_aet],
        dist_op,
        dist_params,
        size,
        rng,
    )

    out_op = new_out.owner.op

    if not lifted:
        # The rewrite must not have applied: the outermost op is still an
        # indexing op, and there is nothing further to compare.
        assert isinstance(out_op, subtensor_ops)
        return

    assert isinstance(out_op, RandomVariable)
    # Inputs from position 3 onward (presumably the distribution parameters
    # -- confirm against ``RandomVariable``) that have an owner must each be
    # produced by an indexing op.
    assert all(
        isinstance(inp.owner.op, subtensor_ops)
        for inp in new_out.owner.inputs[3:]
        if inp.owner
    )

    # Reference compilation of the same graph using ``no_mode``.
    f_base = function(f_inputs, dist_st, mode=no_mode)

    test_values = [var.get_test_value() for var in f_inputs]
    expected = f_base(*test_values)
    actual = f_opt(*test_values)

    np.testing.assert_allclose(expected, actual, rtol=1e-3)
def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
    """Exercise ``local_subtensor_rv_lift`` on an indexed random variable.

    Symbolic inputs are created for the distribution parameters, the size
    entries and the indices, each tagged with its concrete test value. The
    indexed random variable is compiled once with an ``EquilibriumOptimizer``
    that contains only ``local_subtensor_rv_lift`` and, when the rewrite is
    expected to apply, once more with ``no_mode`` so that the two results can
    be compared.

    Parameters
    ----------
    indices
        Iterable of index values; each one is converted with
        ``as_index_constant`` and the resulting tuple is used to index the
        random variable.
    lifted
        Truthy when the optimized output is expected to be owned by a
        ``RandomVariable`` op; falsy when it is expected to remain a
        subtensor op.
    dist_op
        Callable invoked as ``dist_op(*params, size=..., rng=...)`` to build
        the random variable.
    dist_params
        Iterable of concrete parameter values.
    size
        Iterable of concrete size entries; each one becomes an ``iscalar``
        variable.
    """
    from aesara.tensor.subtensor import as_index_constant

    subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)

    def _param_variable(value):
        # Calls ``.type()`` on the tensor built from ``value`` and tags the
        # result with the concrete value.
        var = aet.as_tensor(value).type()
        var.tag.test_value = value
        return var

    def _size_variable(value):
        var = iscalar()
        var.tag.test_value = value
        return var

    def _index_variable(value):
        # Slices carry no ``tag``; every other index gets the raw value
        # attached as its test value.
        idx = as_index_constant(value)
        if not isinstance(idx, slice):
            idx.tag.test_value = value
        return idx

    rng = shared(np.random.RandomState(1233532), borrow=False)

    dist_params_aet = [_param_variable(value) for value in dist_params]
    size_aet = [_size_variable(value) for value in size]
    indices_aet = tuple(_index_variable(value) for value in indices)

    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]

    # Slices and constants are skipped when collecting the compiled
    # function's explicit inputs.
    f_inputs = [
        candidate
        for candidate in [*dist_params_aet, *size_aet, *indices_aet]
        if not isinstance(candidate, (slice, Constant))
    ]

    lift_only_optimizer = EquilibriumOptimizer(
        [local_subtensor_rv_lift], max_use_ratio=100
    )
    mode = Mode("py", lift_only_optimizer)

    f_opt = function(f_inputs, dist_st, mode=mode)

    # Exactly one output is expected in the optimized function graph.
    (new_out,) = f_opt.maker.fgraph.outputs
    out_op = new_out.owner.op

    if not lifted:
        # The rewrite must not have applied: the outermost op is still an
        # indexing op, and there is nothing further to compare.
        assert isinstance(out_op, subtensor_ops)
        return

    assert isinstance(out_op, RandomVariable)
    # Inputs from position 3 onward (presumably the distribution parameters
    # -- confirm against ``RandomVariable``) that have an owner must each be
    # produced by an indexing op.
    assert all(
        isinstance(inp.owner.op, subtensor_ops)
        for inp in new_out.owner.inputs[3:]
        if inp.owner
    )

    # Reference compilation of the same graph using ``no_mode``.
    f_base = function(f_inputs, dist_st, mode=no_mode)

    test_values = [var.get_test_value() for var in f_inputs]
    expected = f_base(*test_values)
    actual = f_opt(*test_values)

    np.testing.assert_allclose(expected, actual, rtol=1e-3)