def test_ScanArgs():
    """Exercise the basic `ScanArgs` API on the HMM graph from `create_test_hmm`.

    Covers construction from a non-`Scan` node, alternate-field lookup,
    `repr`, the count properties, `find_among_fields`, and copy/equality.
    """
    # The owner of `at.ones(2)` is not a `Scan` node
    with pytest.raises(TypeError):
        ScanArgs.from_node(at.ones(2).owner)

    env = create_test_hmm()
    scan_op = env["scan_op"]
    scan_args = env["scan_args"]

    # An outer-output can be mapped to its inner-output and outer-input
    # counterparts
    outer_out = scan_args.outer_out_sit_sot[0]
    assert (
        scan_args.get_alt_field(outer_out, "inner_out")
        == scan_args.inner_out_sit_sot[0]
    )
    assert (
        scan_args.get_alt_field(outer_out, "outer_in")
        == scan_args.outer_in_sit_sot[0]
    )

    # `repr` only needs to succeed and identify the class
    assert repr(scan_args).startswith("ScanArgs")

    # These properties are what let `Scan.get_oinp_iinp_iout_oout_mappings`
    # be reused as-is to implement `ScanArgs.var_mappings`
    for count_name in ("n_nit_sot", "n_mit_mot"):
        assert getattr(scan_args, count_name) == getattr(scan_op.info, count_name)

    # The inner-graph must not have been cloned: the inputs and `info` are
    # expected to equal the `Op`'s own
    assert scan_args.inputs == scan_op.inner_inputs
    assert scan_args.info == scan_op.info

    def assert_found(var, expected_name, expected_index):
        # `var` must be located in the flat (non-nested) field
        # `expected_name`, and its aggregate index must point back at it
        found = scan_args.find_among_fields(var)
        assert found.name == expected_name
        assert found.index == expected_index
        assert found.inner_index is None
        assert scan_args.inner_inputs[found.agg_index] == var

    # Check that `ScanArgs.find_among_fields` works
    second_seq = scan_op.inner_seqs(scan_op.inner_inputs)[1]
    assert_found(second_seq, "inner_in_seqs", 1)

    non_seqs = scan_op.inner_non_seqs(scan_op.inner_inputs)
    # We didn't index this argument, so it's a `list` (i.e. bad input)
    assert scan_args.find_among_fields(non_seqs) is None

    first_non_seq = non_seqs[0]
    assert_found(first_non_seq, "inner_in_non_seqs", 0)

    # A copy is a distinct but equal object...
    duplicate = copy(scan_args)
    assert duplicate is not scan_args
    assert duplicate == scan_args

    # ...that is unequal to unrelated objects, and stops comparing equal to
    # the original once one of its fields has been mutated
    assert duplicate != first_non_seq
    duplicate.outer_in_seqs.pop()
    assert duplicate != scan_args
def create_test_hmm():
    """Build a small HMM-style `Scan` graph and return every local as a dict.

    The function ends with ``return locals()``, so each local variable name
    below is a key of the returned dict (e.g. ``"scan_args"`` and
    ``"scan_op"``).  Renaming or removing a local therefore changes what
    callers receive.
    """
    srng = at.random.RandomStream()

    # Symbolic sizes with test values attached
    N_tt = at.iscalar("N")
    N_tt.tag.test_value = 10
    M_tt = at.iscalar("M")
    M_tt.tag.test_value = 2

    # The test value is a (10, 2) array: one increasing and one decreasing
    # column
    mus_tt = at.matrix("mus")
    mus_tt.tag.test_value = np.stack(
        [np.arange(0.0, 10), np.arange(0.0, -10, -1)],
        axis=-1).astype(aesara.config.floatX)

    sigmas_tt = at.ones((N_tt, ))
    sigmas_tt.name = "sigmas"

    pi_0_rv = srng.dirichlet(at.ones((M_tt, )), name="pi_0")
    Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")

    S_0_rv = srng.categorical(pi_0_rv, name="S_0")

    # Inner step function; its arguments correspond to the two sequences, the
    # single `-1` tap of the first output, and the non-sequence, in that order
    def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t):
        S_t = srng.categorical(Gamma_t[S_tm1], name="S_t")
        Y_t = srng.normal(mus_t[S_t], sigma_t, name="Y_t")
        return S_t, Y_t

    # The first output is recurrent (initial value `S_0_rv`, tap `-1`); the
    # second has an empty `outputs_info` entry, i.e. no initial value or taps
    (S_rv, Y_rv), scan_updates = aesara.scan(
        fn=scan_fn,
        sequences=[mus_tt, sigmas_tt],
        non_sequences=[Gamma_rv],
        outputs_info=[{
            "initial": S_0_rv,
            "taps": [-1]
        }, {}],
        strict=True,
        name="scan_rv",
    )
    Y_rv.name = "Y_rv"

    scan_op = Y_rv.owner.op
    scan_args = ScanArgs.from_node(Y_rv.owner)

    # Handles on individual inner-graph variables
    Gamma_in = scan_args.inner_in_non_seqs[0]
    Y_t = scan_args.inner_out_nit_sot[0]
    mus_t = scan_args.inner_in_seqs[0]
    sigmas_t = scan_args.inner_in_seqs[1]
    S_t = scan_args.inner_out_sit_sot[0]
    # NOTE(review): named `rng_in` but read from `inner_out_shared` -- confirm
    # that the inner *output* is what is intended here
    rng_in = scan_args.inner_out_shared[0]

    # Name the outer inputs of the `Scan` node (input 0 is skipped)
    mus_in = Y_rv.owner.inputs[1]
    mus_in.name = "mus_in"
    sigmas_in = Y_rv.owner.inputs[2]
    sigmas_in.name = "sigmas_in"

    # The output `S_rv` is really `S_rv[1:]`, so we have to extract the actual
    # `Scan` output, which is the first input of the node that produced `S_rv`.
    S_in = S_rv.owner.inputs[0]
    S_in.name = "S_in"

    return locals()
def test_ScanArgs_basics_mit_sot():
    """Check `ScanArgs` field lookup and removal on a two-tap (mit-sot) `Scan`.

    The first output of the `Scan` built here uses taps ``[-2, -1]``, so its
    inner inputs are nested one level deeper than in the single-tap case.
    """
    rng_stream = at.random.RandomStream()

    num_steps = at.iscalar("N")
    num_steps.tag.test_value = 10
    num_states = at.iscalar("M")
    num_states.tag.test_value = 2

    means = at.matrix("mus")
    mean_columns = [np.arange(0.0, 10), np.arange(0.0, -10, -1)]
    means.tag.test_value = np.stack(mean_columns, axis=-1).astype(
        aesara.config.floatX
    )

    scales = at.ones((num_steps,))
    scales.name = "sigmas"

    init_probs = rng_stream.dirichlet(at.ones((num_states,)), name="pi_0")
    trans_probs = rng_stream.dirichlet(
        at.ones((num_states, num_states)), name="Gamma"
    )
    init_state = rng_stream.categorical(init_probs, name="S_0")

    def step(mean_t, scale_t, state_tm2, state_tm1, trans_t):
        # The new state depends on the `-2` tap, the observation on the `-1`
        # tap
        state_t = rng_stream.categorical(trans_t[state_tm2], name="S_t")
        obs_t = rng_stream.normal(mean_t[state_tm1], scale_t, name="Y_t")
        return state_t, obs_t

    # Two initial values are needed to serve both taps
    state_info = {
        "initial": at.stack([init_state, init_state]),
        "taps": [-2, -1],
    }
    outputs, _ = aesara.scan(
        fn=step,
        sequences=[means, scales],
        non_sequences=[trans_probs],
        outputs_info=[state_info, {}],
        strict=True,
        name="scan_rv",
    )
    states, obs = outputs

    # Adding names should make output easier to read
    obs.name = "Y_rv"
    # The state outer-output is actually a `Subtensor` of the "real" output
    raw_states = states.owner.inputs[0]
    raw_states.name = "S_rv"
    obs.owner.inputs[1].name = "mus_in"
    obs.owner.inputs[2].name = "sigmas_in"

    scan_args = ScanArgs.from_node(obs.owner)

    def assert_is_second_tap(info):
        # The variable is the second tap of the first mit-sot entry
        assert info.name == "inner_in_mit_sot"
        assert info.index == 0
        assert info.inner_index == 1
        assert info.agg_index == 3

    second_tap = scan_args.inner_in_mit_sot[0][1]
    assert_is_second_tap(scan_args.find_among_fields(second_tap))

    # Removing a variable that isn't in any field reports nothing
    assert scan_args._remove_from_fields(at.ones(2)) is None

    # Removing a real variable reports where it was found
    assert_is_second_tap(scan_args._remove_from_fields(second_tap))
def scan(*outer_inputs):
    """Run the enclosing `Scan` `Op` on `outer_inputs` using `jax.lax.scan`.

    Parameters
    ----------
    *outer_inputs
        The outer inputs of the `Scan` node, in the composite order listed in
        the body below (``n_steps`` first).

    Returns
    -------
    A list of arrays, or the bare array when the list has exactly one
    element.  When mit-sot (or, failing that, sit-sot) outer inputs are
    present, the leading part of each such input is prepended to the
    matching stacked output.  When neither is present, the stacked outputs
    of `jax.lax.scan` are returned unchanged.
    """
    scan_args = ScanArgs(
        list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    )

    # `outer_inputs` is a list with the following composite form:
    # [n_steps]
    # + outer_in_seqs
    # + outer_in_mit_mot
    # + outer_in_mit_sot
    # + outer_in_sit_sot
    # + outer_in_shared
    # + outer_in_nit_sot
    # + outer_in_non_seqs
    n_steps = scan_args.n_steps
    seqs = scan_args.outer_in_seqs

    # TODO: mit_mots
    mit_mot_in_slices = []

    # For every mit-sot, keep only the leading rows that hold its taps
    mit_sot_in_slices = []
    for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
        max_neg = max((abs(t) for t in tap if t < 0), default=0)
        max_pos = max((abs(t) for t in tap if t > 0), default=0)
        mit_sot_in_slices.append(seq[: max_neg + max_pos])

    # A sit-sot carries a single previous value: its first row
    sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]

    init_carry = (
        mit_mot_in_slices,
        mit_sot_in_slices,
        sit_sot_in_slices,
        scan_args.outer_in_shared,
        scan_args.outer_in_non_seqs,
    )

    def jax_args_to_inner_scan(carry, x):
        """Arrange `carry` and `x` into the inner function's argument list."""
        # `carry` contains all inner-output taps, non_seqs, and shared
        # terms
        (
            inner_in_mit_mot,
            inner_in_mit_sot,
            inner_in_sit_sot,
            inner_in_shared,
            inner_in_non_seqs,
        ) = carry

        # `x` contains the in_seqs
        inner_in_seqs = x

        # `inner_scan_inputs` is a list with the following composite form:
        # inner_in_seqs
        # + sum(inner_in_mit_mot, [])
        # + sum(inner_in_mit_sot, [])
        # + inner_in_sit_sot
        # + inner_in_shared
        # + inner_in_non_seqs
        inner_in_mit_sot_flatten = []
        for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
            # Pick the tapped rows and pass each one as its own argument
            inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

        inner_scan_inputs = sum(
            [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot_flatten,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ],
            [],
        )

        return inner_scan_inputs

    def inner_scan_outs_to_jax_outs(old_carry, inner_scan_outs):
        """Build the next carry from the previous one and the step outputs."""
        (
            inner_in_mit_mot,
            inner_in_mit_sot,
            inner_in_sit_sot,
            inner_in_shared,
            inner_in_non_seqs,
        ) = old_carry

        def update_mit_sot(mit_sot, new_val):
            # Drop the oldest row and append the newest value
            return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

        inner_out_mit_sot = [
            update_mit_sot(mit_sot, new_val)
            for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
        ]

        # This should contain all inner-output taps, non_seqs, and shared
        # terms
        # NOTE(review): with sit-sots present, *all* step outputs become the
        # sit-sot carry, which looks like it assumes every output is a
        # sit-sot -- confirm before relying on mixed output kinds.
        if not inner_in_sit_sot:
            inner_out_sit_sot = []
        else:
            inner_out_sit_sot = inner_scan_outs
        new_carry = (
            inner_in_mit_mot,
            inner_out_mit_sot,
            inner_out_sit_sot,
            inner_in_shared,
            inner_in_non_seqs,
        )

        return new_carry

    def jax_inner_func(carry, x):
        """Single step in the form `jax.lax.scan` expects."""
        inner_args = jax_args_to_inner_scan(carry, x)
        inner_scan_outs = [fn(*inner_args) for fn in jax_aet_inner_func]
        new_carry = inner_scan_outs_to_jax_outs(carry, inner_scan_outs)
        return new_carry, inner_scan_outs

    _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    # We need to prepend the initial values so that the JAX output will
    # match the raw `Scan` `Op` output and, thus, work with a downstream
    # `Subtensor` `Op` introduced by the `scan` helper function.
    def append_scan_out(scan_in_part, scan_out_part):
        # NOTE(review): `[:-n_steps]` is empty when `n_steps == 0`, which
        # would drop the initial values -- confirm whether that can occur.
        return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

    if scan_args.outer_in_mit_sot:
        scan_out_final = [
            append_scan_out(init, out)
            for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
        ]
    elif scan_args.outer_in_sit_sot:
        scan_out_final = [
            append_scan_out(init, out)
            for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
        ]
    else:
        # No mit-sot or sit-sot inputs means there are no initial values to
        # prepend.  Previously this path left `scan_out_final` unassigned and
        # failed with `UnboundLocalError` below.
        scan_out_final = list(scan_out)

    if len(scan_out_final) == 1:
        scan_out_final = scan_out_final[0]

    return scan_out_final