Ejemplo n.º 1
0
def test_ScanArgs():
    """Exercise the core ``ScanArgs`` API on the graph from ``create_test_hmm``.

    Covers: rejection of non-``Scan`` nodes, ``get_alt_field``, ``__repr__``,
    agreement with the wrapped ``Scan`` op's ``info`` and inner inputs,
    ``find_among_fields`` (including an invalid, non-indexed list argument),
    and ``copy``/``__eq__`` semantics.
    """
    # A node that is not produced by a `Scan` must be rejected.
    with pytest.raises(TypeError):
        ScanArgs.from_node(at.ones(2).owner)

    hmm_model_env = create_test_hmm()
    scan_args = hmm_model_env["scan_args"]
    scan_op = hmm_model_env["scan_op"]

    # `get_alt_field` maps a sit-sot outer output to its counterpart in the
    # requested field ("inner_out" / "outer_in").
    test_v = scan_args.outer_out_sit_sot[0]
    alt_test_v = scan_args.get_alt_field(test_v, "inner_out")
    assert alt_test_v == scan_args.inner_out_sit_sot[0]

    alt_test_v = scan_args.get_alt_field(test_v, "outer_in")
    assert alt_test_v == scan_args.outer_in_sit_sot[0]

    # `__repr__` must not raise, and it should identify the class.
    scan_args_repr = repr(scan_args)
    assert scan_args_repr.startswith("ScanArgs")

    # These properties let `ScanArgs.var_mappings` reuse
    # `Scan.get_oinp_iinp_iout_oout_mappings` as-is.
    assert scan_args.n_nit_sot == scan_op.info.n_nit_sot
    assert scan_args.n_mit_mot == scan_op.info.n_mit_mot
    # The original `scan_args` base class clones the inner graph;
    # `ScanArgs` must not, so its inputs are the same as the op's.
    assert scan_args.inputs == scan_op.inner_inputs
    assert scan_args.info == scan_op.info

    # `find_among_fields` reports a variable's field name, its index in that
    # field, and its index in the aggregated `inner_inputs` list.
    test_v = scan_op.inner_seqs(scan_op.inner_inputs)[1]
    field_info = scan_args.find_among_fields(test_v)
    # Guard first so a lookup regression fails as a clear assertion rather
    # than an `AttributeError` on `None`.
    assert field_info is not None
    assert field_info.name == "inner_in_seqs"
    assert field_info.index == 1
    assert field_info.inner_index is None
    assert scan_args.inner_inputs[field_info.agg_index] == test_v

    test_l = scan_op.inner_non_seqs(scan_op.inner_inputs)
    # We didn't index this argument, so it's a `list` (i.e. bad input) and
    # no field should match it.
    field_info = scan_args.find_among_fields(test_l)
    assert field_info is None

    test_v = test_l[0]
    field_info = scan_args.find_among_fields(test_v)
    assert field_info is not None
    assert field_info.name == "inner_in_non_seqs"
    assert field_info.index == 0
    assert field_info.inner_index is None
    assert scan_args.inner_inputs[field_info.agg_index] == test_v

    # `copy` yields a distinct object that compares equal to the original.
    scan_args_copy = copy(scan_args)
    assert scan_args_copy is not scan_args
    assert scan_args_copy == scan_args

    # A `ScanArgs` must compare unequal to an object of another type.
    assert scan_args_copy != test_v
    # Mutating a field list of the copy must make it differ from the
    # original; this fails if the copy shares its lists with the original.
    scan_args_copy.outer_in_seqs.pop()
    assert scan_args_copy != scan_args
Ejemplo n.º 2
0
def create_test_hmm() -> dict:
    """Build a small hidden-Markov-model ``Scan`` graph for ``ScanArgs`` tests.

    An initial state ``S_0`` is drawn from a categorical whose probabilities
    ``pi_0`` are Dirichlet-distributed. ``aesara.scan`` then iterates over
    ``mus`` and ``sigmas``. At each step it draws a new state ``S_t`` from
    the row of the Dirichlet-distributed matrix ``Gamma`` selected by the
    previous state. It then draws an observation ``Y_t`` from a normal with
    location ``mus_t[S_t]`` and scale ``sigma_t``.

    Returns
    -------
    dict
        ``locals()`` of this function. Every local name below (e.g.
        ``"scan_args"``, ``"scan_op"``, ``"Y_rv"``, ``"S_in"``,
        ``"mus_in"``, ``"rng_in"``) is therefore a key that callers index.
        Renaming or removing a local changes the returned mapping.
    """
    srng = at.random.RandomStream()

    # Symbolic sizes: N time steps and M hidden states, with test values.
    N_tt = at.iscalar("N")
    N_tt.tag.test_value = 10
    M_tt = at.iscalar("M")
    M_tt.tag.test_value = 2

    # Test value has shape (10, 2), matching N=10 and M=2; `mus_t[S_t]`
    # below indexes each per-step row by state.
    mus_tt = at.matrix("mus")
    mus_tt.tag.test_value = np.stack(
        [np.arange(0.0, 10), np.arange(0.0, -10, -1)],
        axis=-1).astype(aesara.config.floatX)

    sigmas_tt = at.ones((N_tt, ))
    sigmas_tt.name = "sigmas"

    # NOTE: the order of the `srng` draws below (and inside `scan_fn`) is
    # deliberate; reordering them changes which RNG each variable uses.
    pi_0_rv = srng.dirichlet(at.ones((M_tt, )), name="pi_0")
    Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")

    S_0_rv = srng.categorical(pi_0_rv, name="S_0")

    # Inner step function. The argument order matches the `sequences`, the
    # single `outputs_info` tap (-1) and the `non_sequences` passed below.
    def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t):
        S_t = srng.categorical(Gamma_t[S_tm1], name="S_t")
        Y_t = srng.normal(mus_t[S_t], sigma_t, name="Y_t")
        return S_t, Y_t

    # `S` is seeded with `S_0_rv` and uses tap -1 (it shows up as a sit-sot
    # below). `Y` has no initial state (`{}`) and shows up as a nit-sot.
    # `scan_updates` is presumably the RNG update mapping for the draws in
    # `scan_fn`. It is unused here but exposed through `locals()`.
    (S_rv, Y_rv), scan_updates = aesara.scan(
        fn=scan_fn,
        sequences=[mus_tt, sigmas_tt],
        non_sequences=[Gamma_rv],
        outputs_info=[{
            "initial": S_0_rv,
            "taps": [-1]
        }, {}],
        strict=True,
        name="scan_rv",
    )
    Y_rv.name = "Y_rv"

    scan_op = Y_rv.owner.op
    scan_args = ScanArgs.from_node(Y_rv.owner)

    # Inner-graph variables, pulled out of `scan_args` for the tests.
    Gamma_in = scan_args.inner_in_non_seqs[0]
    Y_t = scan_args.inner_out_nit_sot[0]
    mus_t = scan_args.inner_in_seqs[0]
    sigmas_t = scan_args.inner_in_seqs[1]
    S_t = scan_args.inner_out_sit_sot[0]
    # NOTE(review): taken from `inner_out_shared`, not `inner_in_shared`,
    # despite the `_in` suffix. Confirm this is intended. The name is a
    # returned key, so do not rename it without updating callers.
    rng_in = scan_args.inner_out_shared[0]

    # Outer `Scan` inputs. Index 0 is presumably the number of steps,
    # followed by the sequences (`mus`, `sigmas`) in order. TODO confirm.
    mus_in = Y_rv.owner.inputs[1]
    mus_in.name = "mus_in"
    sigmas_in = Y_rv.owner.inputs[2]
    sigmas_in.name = "sigmas_in"

    # The `S_rv` returned by `aesara.scan` is really a `[1:]` slice of the
    # `Scan` node's sit-sot output, which also holds the initial state.
    # Take the un-sliced `Scan` output from the slice's owner's inputs.
    S_in = S_rv.owner.inputs[0]
    S_in.name = "S_in"

    return locals()
Ejemplo n.º 3
0
def test_ScanArgs_basics_mit_sot():
    """Check ``ScanArgs`` field lookup and removal for a mit-sot inner input.

    The test builds an HMM-style ``Scan`` whose state output ``S`` is seeded
    with two initial values and uses taps ``[-2, -1]``, which makes it a
    mit-sot output. It then verifies that ``find_among_fields`` and
    ``_remove_from_fields`` both report the location of the ``S_tm1`` inner
    input: field name, field index, tap index and aggregated index.
    """
    srng = at.random.RandomStream()

    N_tt = at.iscalar("N")
    N_tt.tag.test_value = 10
    M_tt = at.iscalar("M")
    M_tt.tag.test_value = 2

    # Test value has shape (10, 2), matching N=10 and M=2.
    mus_tt = at.matrix("mus")
    mus_tt.tag.test_value = np.stack(
        [np.arange(0.0, 10), np.arange(0.0, -10, -1)],
        axis=-1).astype(aesara.config.floatX)

    sigmas_tt = at.ones((N_tt, ))
    sigmas_tt.name = "sigmas"

    pi_0_rv = srng.dirichlet(at.ones((M_tt, )), name="pi_0")
    Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")

    S_0_rv = srng.categorical(pi_0_rv, name="S_0")

    # `S_tm2` and `S_tm1` are the two taps of the mit-sot output `S`.
    def scan_fn(mus_t, sigma_t, S_tm2, S_tm1, Gamma_t):
        S_t = srng.categorical(Gamma_t[S_tm2], name="S_t")
        Y_t = srng.normal(mus_t[S_tm1], sigma_t, name="Y_t")
        return S_t, Y_t

    (S_rv, Y_rv), scan_updates = aesara.scan(
        fn=scan_fn,
        sequences=[mus_tt, sigmas_tt],
        non_sequences=[Gamma_rv],
        outputs_info=[{
            "initial": at.stack([S_0_rv, S_0_rv]),
            "taps": [-2, -1]
        }, {}],
        strict=True,
        name="scan_rv",
    )
    # Adding names should make output easier to read
    Y_rv.name = "Y_rv"
    # This `S_rv` outer-output is actually a `Subtensor` of the "real" output
    S_rv = S_rv.owner.inputs[0]
    S_rv.name = "S_rv"
    mus_in = Y_rv.owner.inputs[1]
    mus_in.name = "mus_in"
    sigmas_in = Y_rv.owner.inputs[2]
    sigmas_in.name = "sigmas_in"

    scan_args = ScanArgs.from_node(Y_rv.owner)

    # `S_tm1` is tap 1 of the first mit-sot entry. Its aggregated index (3)
    # matches its position in `scan_fn`'s signature: two sequences first,
    # then `S_tm2`.
    test_v = scan_args.inner_in_mit_sot[0][1]
    field_info = scan_args.find_among_fields(test_v)

    # Guard first so a lookup regression fails as a clear assertion rather
    # than an `AttributeError` on `None`.
    assert field_info is not None
    assert field_info.name == "inner_in_mit_sot"
    assert field_info.index == 0
    assert field_info.inner_index == 1
    assert field_info.agg_index == 3

    # A variable that belongs to no field cannot be removed.
    rm_info = scan_args._remove_from_fields(at.ones(2))
    assert rm_info is None

    # Removing `S_tm1` must report the same location found above.
    rm_info = scan_args._remove_from_fields(test_v)

    assert rm_info is not None
    assert rm_info.name == "inner_in_mit_sot"
    assert rm_info.index == 0
    assert rm_info.inner_index == 1
    assert rm_info.agg_index == 3