def test_pymc_broadcastable():
    """Check PyMC3-to-Theano conversion when array broadcasting is involved.

    A PyMC3 model with three length-one Normal variables is converted with
    `model_graph` and canonicalized; the result is compared against the same
    graph built by hand from meta (`mt`) objects.
    """
    tt.config.compute_test_value = 'ignore'

    float_type = tt.config.floatX

    # Symbolic vector inputs, each tagged with a one-element test value.
    mu_X = tt.vector('mu_X')
    sd_X = tt.vector('sd_X')
    mu_Y = tt.vector('mu_Y')
    sd_Y = tt.vector('sd_Y')
    for var, value in ((mu_X, 0.), (sd_X, 1.), (mu_Y, 1.), (sd_Y, 0.5)):
        var.tag.test_value = np.array([value], dtype=float_type)

    unit_shape = (1, )

    with pm.Model() as model:
        X_rv = pm.Normal('X_rv', mu_X, sd=sd_X, shape=unit_shape)
        Y_rv = pm.Normal('Y_rv', mu_Y, sd=sd_Y, shape=unit_shape)
        Z_rv = pm.Normal('Z_rv',
                         X_rv + Y_rv,
                         sd=sd_X + sd_Y,
                         shape=unit_shape,
                         observed=[10.])

    # The conversion is expected to emit a `UserWarning` for this model.
    with pytest.warns(UserWarning):
        model_fgraph = model_graph(model)

    converted_Z = canonicalize(model_fgraph, return_graph=False)

    # Reuse the RNG input found in the converted graph; using a different
    # one would break the comparison at the end of the test.
    shared_rng = converted_Z.owner.inputs[1].owner.inputs[-1]

    # Meta counterparts of the symbolic inputs above.
    mu_X_mt = mt.vector('mu_X')
    sd_X_mt = mt.vector('sd_X')
    mu_Y_mt = mt.vector('mu_Y')
    sd_Y_mt = mt.vector('sd_Y')
    tt.config.compute_test_value = 'ignore'

    # Hand-built expected graph: each parent RV is marked broadcastable
    # along axis 0 before it is combined.
    X_mt = mt.addbroadcast(
        mt.NormalRV(mu_X_mt, sd_X_mt, (1, ), shared_rng, name='X_rv'), 0)
    Y_mt = mt.addbroadcast(
        mt.NormalRV(mu_Y_mt, sd_Y_mt, (1, ), shared_rng, name='Y_rv'), 0)

    Z_mean_mt = mt.add(X_mt, Y_mt)
    Z_sd_mt = mt.add(sd_X_mt, sd_Y_mt)
    Z_mt = mt.NormalRV(Z_mean_mt, Z_sd_mt, (1, ), shared_rng, name='Z_rv')

    observations_mt = mt(Z_rv.observations)
    Z_observed_mt = mt.observed(observations_mt, Z_mt)
    expected_Z = canonicalize(Z_observed_mt.reify(), return_graph=False)

    assert mt(converted_Z) == mt(expected_Z)
def test_pymc_normal_model():
    """Conduct a more in-depth test of PyMC3/Theano conversions for a specific model.

    The test has two parts:

    1. Convert the model with `Z_rv` as the only output, canonicalize it and
       compare it against an equivalent graph built from meta (`mt`) objects.
    2. Convert the model again with both `Y_rv` and `Z_rv` as outputs and
       check that the converted variables sit where expected in the graph
       and that each named random variable occurs exactly once in the
       `Z_rv` subgraph.
    """
    tt.config.compute_test_value = 'ignore'

    mu_X = tt.dscalar('mu_X')
    sd_X = tt.dscalar('sd_X')
    mu_X.tag.test_value = np.array(0., dtype=tt.config.floatX)
    sd_X.tag.test_value = np.array(1., dtype=tt.config.floatX)

    # We need something that uses transforms (hence the `HalfCauchy` term).
    with pm.Model() as model:
        X_rv = pm.Normal('X_rv', mu_X, sd=sd_X)
        S_rv = pm.HalfCauchy('S_rv',
                             beta=np.array(0.5, dtype=tt.config.floatX))
        Y_rv = pm.Normal('Y_rv', X_rv * S_rv, sd=S_rv)
        Z_rv = pm.Normal('Z_rv', X_rv + Y_rv, sd=sd_X, observed=10.)

    fgraph = model_graph(model, output_vars=[Z_rv])

    Z_rv_tt = canonicalize(fgraph, return_graph=False)

    # Reuse the RNG input found in the converted graph; using a different
    # one would break the comparison below.
    rng = Z_rv_tt.owner.inputs[1].owner.inputs[-1]

    mu_X_ = mt.dscalar('mu_X')
    sd_X_ = mt.dscalar('sd_X')
    tt.config.compute_test_value = 'ignore'
    X_rv_ = mt.NormalRV(mu_X_, sd_X_, None, rng, name='X_rv')
    S_rv_ = mt.HalfCauchyRV(np.array(0., dtype=tt.config.floatX),
                            np.array(0.5, dtype=tt.config.floatX),
                            None,
                            rng,
                            name='S_rv')
    Y_rv_ = mt.NormalRV(mt.mul(X_rv_, S_rv_), S_rv_, None, rng, name='Y_rv')
    # Use the meta placeholder `sd_X_` (not the raw Theano `sd_X`) so the
    # expected graph shares one `sd_X` node between `X_rv` and `Z_rv`, just
    # like the model does.
    Z_rv_ = mt.NormalRV(mt.add(X_rv_, Y_rv_), sd_X_, None, rng, name='Z_rv')
    obs_ = mt(Z_rv.observations)
    Z_rv_obs_ = mt.observed(obs_, Z_rv_)

    Z_rv_meta = mt(canonicalize(Z_rv_obs_.reify(), return_graph=False))

    assert mt(Z_rv_tt) == Z_rv_meta

    # Now, let's try that with multiple outputs.
    fgraph.disown()
    fgraph = model_graph(model, output_vars=[Y_rv, Z_rv])

    # NOTE(review): 25 is the variable count observed for this exact model;
    # it presumably needs updating whenever the conversion changes.
    assert len(fgraph.variables) == 25

    # Map the original PyMC3 variables to their converted counterparts.
    Y_new_rv = walk(Y_rv, fgraph.memo)
    S_new_rv = walk(S_rv, fgraph.memo)
    X_new_rv = walk(X_rv, fgraph.memo)
    Z_new_rv = walk(Z_rv, fgraph.memo)

    # Make sure our new vars are actually in the graph and where
    # they should be.
    assert Y_new_rv == fgraph.outputs[0]
    assert Z_new_rv == fgraph.outputs[1]
    assert X_new_rv in fgraph.variables
    assert S_new_rv in fgraph.variables
    assert isinstance(Z_new_rv.owner.op, Observed)

    # Let's only look at the variables involved in the `Z_rv` subgraph.
    Z_vars = theano.gof.graph.variables(theano.gof.graph.inputs([Z_new_rv]),
                                        [Z_new_rv])

    # Let's filter for only the `RandomVariables` with names.
    Z_vars_count = Counter(
        n.name for n in Z_vars
        if n.name and n.owner and isinstance(n.owner.op, RandomVariable)
    )

    # Each new RV should be present and only occur once.
    assert Y_new_rv.name in Z_vars_count
    assert X_new_rv.name in Z_vars_count
    assert Z_new_rv.owner.inputs[1].name in Z_vars_count
    assert all(v == 1 for v in Z_vars_count.values())