Esempio n. 1
0
def test_ScanArgs():
    """Exercise the basic `ScanArgs` API on the `Scan` built by `create_test_hmm`."""
    # `from_node` is expected to reject this node with a `TypeError`
    with pytest.raises(TypeError):
        ScanArgs.from_node(at.ones(2).owner)

    env = create_test_hmm()
    args = env["scan_args"]
    op = env["scan_op"]

    # Alternate-field lookups, starting from an outer sit-sot output
    outer_out = args.outer_out_sit_sot[0]
    assert args.get_alt_field(outer_out, "inner_out") == args.inner_out_sit_sot[0]
    assert args.get_alt_field(outer_out, "outer_in") == args.outer_in_sit_sot[0]

    # `repr` must not err-out and must start with the class name
    assert repr(args).startswith("ScanArgs")

    # These properties are what allow
    # `Scan.get_oinp_iinp_iout_oout_mappings` to be used as-is to implement
    # `ScanArgs.var_mappings`
    assert args.n_nit_sot == op.info.n_nit_sot
    assert args.n_mit_mot == op.info.n_mit_mot
    # Make sure the inner-graph wasn't cloned (i.e. all the inputs are the
    # same objects as the `Op`'s)
    assert args.inputs == op.inner_inputs
    assert args.info == op.info

    def assert_found(var, expected_name, expected_index):
        # `var` must be located in the named field at the given position, and
        # its aggregate index must point back at it in `inner_inputs`
        location = args.find_among_fields(var)
        assert location.name == expected_name
        assert location.index == expected_index
        assert location.inner_index is None
        assert args.inner_inputs[location.agg_index] == var

    # Check that `ScanArgs.find_among_fields` works
    second_seq = op.inner_seqs(op.inner_inputs)[1]
    assert_found(second_seq, "inner_in_seqs", 1)

    non_seqs = op.inner_non_seqs(op.inner_inputs)
    # Passing the whole `list` instead of one of its elements is bad input
    assert args.find_among_fields(non_seqs) is None

    first_non_seq = non_seqs[0]
    assert_found(first_non_seq, "inner_in_non_seqs", 0)

    # A shallow copy is a distinct object that still compares equal
    duplicate = copy(args)
    assert duplicate is not args
    assert duplicate == args

    # Comparison against something that isn't a `ScanArgs`
    assert duplicate != first_non_seq
    # After dropping an outer sequence, the copy must no longer compare equal
    duplicate.outer_in_seqs.pop()
    assert duplicate != args
Esempio n. 2
0
def create_test_hmm():
    """Build a small HMM-like `Scan` graph and return every local as a `dict`.

    The graph draws an initial state `S_0` and then scans over `mus` and
    `sigmas`, producing a state output (one `-1` tap, i.e. a sit-sot) and an
    observation output without taps.

    Returns
    -------
    dict
        The result of `locals()`: every local name in this function mapped to
        its value (e.g. ``"scan_args"``, ``"scan_op"``, ``"Y_rv"``).  Callers
        look entries up by these names, so renaming a local changes the
        returned keys.
    """
    srng = at.random.RandomStream()

    N_tt = at.iscalar("N")
    N_tt.tag.test_value = 10
    M_tt = at.iscalar("M")
    M_tt.tag.test_value = 2

    mus_tt = at.matrix("mus")
    # Two columns of ten values each: 0..9 and 0..-9
    mus_tt.tag.test_value = np.stack(
        [np.arange(0.0, 10), np.arange(0.0, -10, -1)],
        axis=-1).astype(aesara.config.floatX)

    sigmas_tt = at.ones((N_tt, ))
    sigmas_tt.name = "sigmas"

    pi_0_rv = srng.dirichlet(at.ones((M_tt, )), name="pi_0")
    Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")

    S_0_rv = srng.categorical(pi_0_rv, name="S_0")

    def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t):
        # Draw the next state from the row of `Gamma_t` selected by the
        # previous state, then draw the observation for that state
        S_t = srng.categorical(Gamma_t[S_tm1], name="S_t")
        Y_t = srng.normal(mus_t[S_t], sigma_t, name="Y_t")
        return S_t, Y_t

    (S_rv, Y_rv), scan_updates = aesara.scan(
        fn=scan_fn,
        sequences=[mus_tt, sigmas_tt],
        non_sequences=[Gamma_rv],
        outputs_info=[{
            "initial": S_0_rv,
            "taps": [-1]
        }, {}],
        strict=True,
        name="scan_rv",
    )
    Y_rv.name = "Y_rv"

    # `Y_rv` is taken to be a direct output of the `Scan` node here
    scan_op = Y_rv.owner.op
    scan_args = ScanArgs.from_node(Y_rv.owner)

    # Pull out individual inner-graph variables by their `ScanArgs` field
    Gamma_in = scan_args.inner_in_non_seqs[0]
    Y_t = scan_args.inner_out_nit_sot[0]
    mus_t = scan_args.inner_in_seqs[0]
    sigmas_t = scan_args.inner_in_seqs[1]
    S_t = scan_args.inner_out_sit_sot[0]
    # NOTE(review): named `rng_in` but read from `inner_out_shared` -- confirm
    # this is intended
    rng_in = scan_args.inner_out_shared[0]

    # Name the outer sequence inputs of the `Scan` node (input 0 is skipped)
    mus_in = Y_rv.owner.inputs[1]
    mus_in.name = "mus_in"
    sigmas_in = Y_rv.owner.inputs[2]
    sigmas_in.name = "sigmas_in"

    # The `S_rv` returned by `aesara.scan` is really `S_rv[1:]` of the raw
    # `Scan` output, so we take that expression's first input to get the
    # actual `Scan` output.
    S_in = S_rv.owner.inputs[0]
    S_in.name = "S_in"

    # Every local above (including `srng`, `scan_fn` and `scan_updates`)
    # becomes a key of the returned `dict`
    return locals()
Esempio n. 3
0
def test_ScanArgs_basics_mit_sot():
    """Check `ScanArgs` field lookup and removal on a `Scan` with a mit-sot."""
    srng = at.random.RandomStream()

    N_var = at.iscalar("N")
    N_var.tag.test_value = 10
    M_var = at.iscalar("M")
    M_var.tag.test_value = 2

    mus = at.matrix("mus")
    mus.tag.test_value = np.stack(
        [np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1
    ).astype(aesara.config.floatX)

    sigmas = at.ones((N_var,))
    sigmas.name = "sigmas"

    pi_0 = srng.dirichlet(at.ones((M_var,)), name="pi_0")
    Gamma = srng.dirichlet(at.ones((M_var, M_var)), name="Gamma")

    S_0 = srng.categorical(pi_0, name="S_0")

    def scan_fn(mus_t, sigma_t, S_tm2, S_tm1, Gamma_t):
        # The state output has two taps, so both lagged states are passed in
        S_t = srng.categorical(Gamma_t[S_tm2], name="S_t")
        Y_t = srng.normal(mus_t[S_tm1], sigma_t, name="Y_t")
        return S_t, Y_t

    scan_outs, _ = aesara.scan(
        fn=scan_fn,
        sequences=[mus, sigmas],
        non_sequences=[Gamma],
        outputs_info=[dict(initial=at.stack([S_0, S_0]), taps=[-2, -1]), {}],
        strict=True,
        name="scan_rv",
    )
    S_sub, Y_rv = scan_outs

    # Names make the output easier to read
    Y_rv.name = "Y_rv"
    # The returned state output is actually a `Subtensor` of the "real"
    # outer-output, so name that one instead
    S_full = S_sub.owner.inputs[0]
    S_full.name = "S_rv"
    scan_node = Y_rv.owner
    scan_node.inputs[1].name = "mus_in"
    scan_node.inputs[2].name = "sigmas_in"

    scan_args = ScanArgs.from_node(scan_node)

    # Second tap of the first (and only) mit-sot
    second_tap = scan_args.inner_in_mit_sot[0][1]

    def assert_second_tap_location(location):
        # Both the lookup and the removal must report the same position
        assert location.name == "inner_in_mit_sot"
        assert location.index == 0
        assert location.inner_index == 1
        assert location.agg_index == 3

    assert_second_tap_location(scan_args.find_among_fields(second_tap))

    # A freshly created variable isn't in any field, so nothing is removed
    assert scan_args._remove_from_fields(at.ones(2)) is None

    assert_second_tap_location(scan_args._remove_from_fields(second_tap))
Esempio n. 4
0
    def scan(*outer_inputs):
        """Evaluate the enclosing `Scan` `Op` using `jax.lax.scan`.

        Parameters
        ----------
        *outer_inputs
            The outer inputs of the `Scan` node, in the composite order
            listed at the top of the body (``n_steps`` first).

        Returns
        -------
        The outer outputs rebuilt from the initial values and the per-step
        results: a list with one entry per mit-sot input (or, when there are
        no mit-sots, per sit-sot input).  A single output is returned bare
        rather than in a one-element list.

        Raises
        ------
        NotImplementedError
            If the `Scan` has neither mit-sot nor sit-sot inputs, since no
            outputs are assembled for that case.
        """
        scan_args = ScanArgs(
            list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        # TODO: mit_mots
        mit_mot_in_slices = []

        # For each mit-sot, keep the leading entries of its outer input: as
        # many as the largest negative tap magnitude plus the largest
        # positive tap
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        # A sit-sot only needs the first entry of its outer input to start
        sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

        # Everything that is threaded from one step to the next
        init_carry = (
            mit_mot_in_slices,
            mit_sot_in_slices,
            sit_sot_in_slices,
            scan_args.outer_in_shared,
            scan_args.outer_in_non_seqs,
        )

        def jax_args_to_inner_scan(op, carry, x):
            """Flatten the carry and the current `x` into inner-function inputs.

            `op` is accepted but not used.
            """
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            # Select the tapped entries of each mit-sot buffer and pass them
            # as separate inputs
            inner_in_mit_sot_flatten = []
            for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
                inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

            # Concatenate the groups (all of them lists) in order
            inner_scan_inputs = sum(
                [
                    inner_in_seqs,
                    inner_in_mit_mot,
                    inner_in_mit_sot_flatten,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_seqs,
                ],
                [],
            )

            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            """Build the next carry from the old one and this step's outputs.

            `op` is accepted but not used.
            """
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = old_carry

            def update_mit_sot(mit_sot, new_val):
                # Drop the first entry and append the new value along axis 0
                return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

            # Each mit-sot buffer is paired positionally with the leading
            # entries of `inner_scan_outs`
            inner_out_mit_sot = [
                update_mit_sot(mit_sot, new_val)
                for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
            ]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            if not inner_in_sit_sot:
                inner_out_sit_sot = []
            else:
                # NOTE(review): all of `inner_scan_outs` is carried as the
                # sit-sot state here -- confirm for mixed output kinds
                inner_out_sit_sot = inner_scan_outs
            # mit-mot, shared and non-sequence terms are carried unchanged
            new_carry = (
                inner_in_mit_mot,
                inner_out_mit_sot,
                inner_out_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            )

            return new_carry

        def jax_inner_func(carry, x):
            """Step function in the ``(carry, x) -> (carry, y)`` form."""
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = [fn(*inner_args) for fn in jax_aet_inner_func]
            new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
            return new_carry, inner_scan_outs

        _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

        # We need to prepend the initial values so that the JAX output will
        # match the raw `Scan` `Op` output and, thus, work with a downstream
        # `Subtensor` `Op` introduced by the `scan` helper function.
        def append_scan_out(scan_in_part, scan_out_part):
            # Keep everything in the outer input except its last `n_steps`
            # entries, then append the per-step results
            return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

        # Only one kind of output is assembled: mit-sots take precedence over
        # sit-sots
        if scan_args.outer_in_mit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
            ]
        elif scan_args.outer_in_sit_sot:
            scan_out_final = [
                append_scan_out(init, out)
                for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
            ]
        else:
            # Without this branch `scan_out_final` would be unbound below and
            # the failure would surface as an opaque `UnboundLocalError`
            raise NotImplementedError(
                "JAX conversion of `Scan` requires at least one mit-sot or "
                "sit-sot input"
            )

        if len(scan_out_final) == 1:
            scan_out_final = scan_out_final[0]
        return scan_out_final