Example #1
def test_jax_scan_tap_output():

    a_aet = scalar("a")

    def input_step_fn(y_tm1, y_tm3, a):
        y_tm1.name = "y_tm1"
        y_tm3.name = "y_tm3"
        res = (y_tm1 + y_tm3) * a
        res.name = "y_t"
        return res

    y_scan_aet, _ = scan(
        fn=input_step_fn,
        outputs_info=[
            {
                "initial":
                aet.as_tensor_variable(np.r_[-1.0, 1.3,
                                             0.0].astype(config.floatX)),
                "taps": [-1, -3],
            },
        ],
        non_sequences=[a_aet],
        n_steps=10,
        name="y_scan",
    )
    y_scan_aet.name = "y"
    y_scan_aet.owner.inputs[0].name = "y_all"

    out_fg = FunctionGraph([a_aet], [y_scan_aet])

    test_input_vals = [np.array(10.0).astype(config.floatX)]
    compare_jax_and_py(out_fg, test_input_vals)
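
For reference, taps of [-1, -3] mean that input_step_fn receives y[t-1] and y[t-3] at every step, with the three initial values seeding the history. A minimal plain-NumPy sketch of the recurrence this scan computes (the helper name step_reference is hypothetical, not part of the test):

import numpy as np

def step_reference(init, a, n_steps):
    # init holds [y[-3], y[-2], y[-1]]; each step appends
    # y[t] = (y[t-1] + y[t-3]) * a, mirroring input_step_fn.
    y = list(init)
    for _ in range(n_steps):
        y.append((y[-1] + y[-3]) * a)
    return np.array(y[3:])  # drop the seed values, as scan does

# step_reference([-1.0, 1.3, 0.0], 10.0, 10) matches the scan output above.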
Example #2
    def outer_step(*args):
        # Separate the received arguments into their respective (seq, outputs
        # from previous iterations, nonseqs) categories
        i_sequences = list(args[:len(o_sequences)])
        # Slice from ``len(args) - len(o_nonsequences)`` rather than a
        # negative index so the split stays correct when there are no
        # non-sequences (``args[-0:]`` would return the whole tuple).
        i_prev_outputs = list(
            args[len(o_sequences):len(args) - len(o_nonsequences)])
        i_non_sequences = list(args[len(args) - len(o_nonsequences):])
        i_outputs_infos = i_prev_outputs + [None] * len(new_nitsots)

        # Call the user-provided function with the proper arguments
        results, updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
        )
        if not isinstance(results, list):
            results = [results]

        # Keep only the last timestep of every output but keep all the updates
        return [r[-1] for r in results], updates
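
The slicing above relies on the outer scan passing its arguments in a fixed order: sequence slices first, then the previous-iteration outputs, then non-sequences. A tiny standalone sketch with dummy values (all names here are illustrative only):

# Three sequence slices (the last being the inner step count),
# two recurrent outputs, and two non-sequences.
o_sequences = ["seq_a", "seq_b", "i_n_steps"]
o_nonsequences = ["W", "b"]
args = ("seq_a", "seq_b", "i_n_steps", "prev_1", "prev_2", "W", "b")

i_sequences = list(args[:len(o_sequences)])
i_prev_outputs = list(args[len(o_sequences):len(args) - len(o_nonsequences)])
i_non_sequences = list(args[len(args) - len(o_nonsequences):])
assert i_prev_outputs == ["prev_1", "prev_2"]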
Example #3
    def test_one_sequence_one_output_weights_gpu2(self):
        def f_rnn(u_t, x_tm1, W_in, W):
            return u_t * W_in + x_tm1 * W

        u = fvector("u")
        x0 = fscalar("x0")
        W_in = fscalar("win")
        W = fscalar("w")
        output, updates = scan(
            f_rnn,
            u,
            x0,
            [W_in, W],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
            mode=mode_with_gpu,
        )

        f2 = aesara.function(
            [u, x0, W_in, W],
            output,
            updates=updates,
            allow_input_downcast=True,
            mode=mode_with_gpu,
        )

        # get random initial values; use fresh names so the symbolic
        # variables W and W_in are not shadowed by numpy values
        rng = np.random.default_rng(utt.fetch_seed())
        v_u = rng.uniform(size=(4,), low=-5.0, high=5.0)
        v_x0 = rng.uniform()
        v_W = rng.uniform()
        v_W_in = rng.uniform()

        # compute the output in numpy
        v_out = np.zeros((4,))
        v_out[0] = v_u[0] * v_W_in + v_x0 * v_W
        for step in range(1, 4):
            v_out[step] = v_u[step] * v_W_in + v_out[step - 1] * v_W

        aesara_values = f2(v_u, v_x0, v_W_in, v_W)
        utt.assert_allclose(aesara_values, v_out)

        topo = f2.maker.fgraph.toposort()
        assert sum([isinstance(node.op, HostFromGpu) for node in topo]) == 1
        assert sum([isinstance(node.op, GpuFromHost) for node in topo]) == 4

        scan_node = [
            node for node in topo if isinstance(node.op, scan.op.Scan)
        ]
        assert len(scan_node) == 1
        scan_node = scan_node[0]
        scan_node_topo = scan_node.op.fn.maker.fgraph.toposort()

        # check that the inner loop runs on the GPU and that it contains
        # no gpu transfers.
        assert any(isinstance(node.op, GpuElemwise) for node in scan_node_topo)
        assert not any(
            isinstance(node.op, HostFromGpu) for node in scan_node_topo)
        assert not any(
            isinstance(node.op, GpuFromHost) for node in scan_node_topo)
Example #4
    def test_gpu3_mixture_dtype_outputs(self):
        def f_rnn(u_t, x_tm1, W_in, W):
            return (u_t * W_in + x_tm1 * W, aet.cast(u_t + x_tm1, "int64"))

        u = fvector("u")
        x0 = fscalar("x0")
        W_in = fscalar("win")
        W = fscalar("w")
        output, updates = scan(
            f_rnn,
            u,
            [x0, None],
            [W_in, W],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
            mode=self.mode_with_gpu,
        )

        f2 = aesara.function(
            [u, x0, W_in, W],
            output,
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode_with_gpu,
        )

        # get random initial values; use fresh names so the symbolic
        # variables W and W_in are not shadowed by numpy values
        rng = np.random.RandomState(utt.fetch_seed())
        v_u = rng.uniform(size=(4,), low=-5.0, high=5.0)
        v_x0 = rng.uniform()
        v_W = rng.uniform()
        v_W_in = rng.uniform()

        # compute the output in numpy
        v_out1 = np.zeros((4,))
        v_out2 = np.zeros((4,), dtype="int64")
        v_out1[0] = v_u[0] * v_W_in + v_x0 * v_W
        v_out2[0] = v_u[0] + v_x0
        for step in range(1, 4):
            v_out1[step] = v_u[step] * v_W_in + v_out1[step - 1] * v_W
            v_out2[step] = np.int64(v_u[step] + v_out1[step - 1])

        aesara_out1, aesara_out2 = f2(v_u, v_x0, v_W_in, v_W)
        utt.assert_allclose(aesara_out1, v_out1)
        utt.assert_allclose(aesara_out2, v_out2)

        topo = f2.maker.fgraph.toposort()
        scan_node = [node for node in topo if isinstance(node.op, Scan)]
        assert len(scan_node) == 1
        scan_node = scan_node[0]
        assert self.is_scan_on_gpu(scan_node)
Example #5
def test_gradient_scan():
    # Test for a crash when using MRG inside scan and taking the gradient
    # See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
    aesara_rng = MRG_RandomStream(10)
    w = shared(np.ones(1, dtype="float32"))

    def one_step(x):
        return x + aesara_rng.uniform((1, ), dtype="float32") * w

    x = vector(dtype="float32")
    values, updates = scan(one_step, outputs_info=x, n_steps=10)
    gw = grad(aet_sum(values[-1]), w)
    f = function([x], gw)
    f(np.arange(1, dtype="float32"))
Example #6
def test_gradient_scan(mode):
    aesara_rng = MRG_RandomStream(10)
    w = shared(np.ones(1, dtype="float32"))

    def one_step(x):
        return x + aesara_rng.uniform((1, ), dtype="float32") * w

    x = vector(dtype="float32")
    values, updates = scan(one_step, outputs_info=x, n_steps=10)
    gw = grad(at_sum(values[-1]), w)
    f = function([x], gw, mode=mode)
    assert np.allclose(
        f(np.arange(1, dtype=np.float32)),
        np.array([0.13928187], dtype=np.float32),
        rtol=1e-6,
    )
Example #7
def test_simple_shared_mrg_random():
    aesara_rng = MRG_RandomStream(10)

    values, updates = scan(
        lambda: aesara_rng.uniform((2, ), -1, 1),
        [],
        [],
        [],
        n_steps=5,
        truncate_gradient=-1,
        go_backwards=False,
    )
    my_f = function([], values, updates=updates, allow_input_downcast=True)

    # Just check for run-time errors
    my_f()
    my_f()
Example #8
    def setup_method(self):
        self.k = iscalar("k")
        self.A = vector("A")
        result, _ = scan(
            fn=lambda prior_result, A: prior_result * A,
            outputs_info=aet.ones_like(self.A),
            non_sequences=self.A,
            n_steps=self.k,
        )
        result_check, _ = scan_checkpoints(
            fn=lambda prior_result, A: prior_result * A,
            outputs_info=aet.ones_like(self.A),
            non_sequences=self.A,
            n_steps=self.k,
            save_every_N=100,
        )
        self.result = result[-1]
        self.result_check = result_check[-1]
        self.grad_A = aesara.grad(self.result.sum(), self.A)
        self.grad_A_check = aesara.grad(self.result_check.sum(), self.A)
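
This fixture builds matched graphs so the tests can check that scan_checkpoints agrees with a plain scan in both the forward pass and the gradient. A hedged sketch of how such a fixture is typically exercised (the test name and the concrete inputs are hypothetical):

    def test_forward_and_gradient_match(self):
        # Compile both variants and compare them on concrete inputs.
        f = aesara.function(
            [self.A, self.k],
            [self.result, self.result_check, self.grad_A, self.grad_A_check],
        )
        out, out_check, g, g_check = f(np.ones(10, dtype=config.floatX), 100)
        utt.assert_allclose(out, out_check)
        utt.assert_allclose(g, g_check)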
Example #9
    def test_gibbs_chain(self):
        rng = np.random.RandomState(utt.fetch_seed())
        v_vsample = np.array(
            rng.binomial(
                1,
                0.5,
                size=(3, 20),
            ),
            dtype="float32",
        )
        vsample = aesara.shared(v_vsample)
        trng = aesara.sandbox.rng_mrg.MRG_RandomStream(utt.fetch_seed())

        def f(vsample_tm1):
            return (
                trng.binomial(vsample_tm1.shape, n=1, p=0.3, dtype="float32") *
                vsample_tm1)

        aesara_vsamples, updates = scan(
            f,
            [],
            vsample,
            [],
            n_steps=10,
            truncate_gradient=-1,
            go_backwards=False,
            mode=self.mode_with_gpu,
        )
        my_f = aesara.function(
            [],
            aesara_vsamples[-1],
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode_with_gpu,
        )

        # Just check that the graph compiles and runs; numerical validation
        # is left to DebugMode.
        my_f()
Example #10
    def test_memory_reuse_gpudimshuffle(self):
        # Test the memory pre-allocation feature in scan when one output is
        # the result of a GpuDimshuffle (because an optimization in
        # GpuDimshuffle can cause issues with the memory pre-allocation
        # where it falsely thinks that a pre-allocated memory region has
        # been used when it hasn't).
        def inner_fn(seq1, recurrent_out):
            temp = seq1 + recurrent_out.sum()
            output1 = temp.dimshuffle(1, 0)
            output2 = temp.sum() + recurrent_out
            return output1, output2

        input1 = ftensor3()
        init = ftensor3()
        outputs_info = [None, init]

        out, _ = scan(
            inner_fn,
            sequences=[input1],
            outputs_info=outputs_info,
            mode=self.mode_with_gpu,
        )

        out1 = out[0].flatten()
        out2 = out[1].flatten()

        fct = aesara.function([input1, init], [out1, out2],
                              mode=self.mode_with_gpu)

        output = fct(np.ones((2, 1, 1), dtype="float32"),
                     np.ones((1, 1, 1), dtype="float32"))

        expected_output = (
            np.array([2, 4], dtype="float32"),
            np.array([3, 7], dtype="float32"),
        )
        utt.assert_allclose(output, expected_output)
Example #11
    def test_gpu_memory_usage(self):
        # This test validates that the memory usage of the defined aesara
        # function is reasonable when executed on the GPU. It checks for
        # a bug in which one of scan's optimizations was not applied, which
        # made the scan node compute large and unnecessary outputs and
        # brought memory usage on the GPU to ~12G.

        # Dimensionality of input and output data (not one-hot coded)
        n_in = 100
        n_out = 100
        # Number of neurons in hidden layer
        n_hid = 4000

        # Number of minibatches
        mb_size = 2
        # Time steps in minibatch
        mb_length = 200

        # Define input variables
        xin = ftensor3(name="xin")
        yout = ftensor3(name="yout")

        # Initialize the network parameters
        U = aesara.shared(np.zeros((n_in, n_hid), dtype="float32"),
                          name="W_xin_to_l1")
        V = aesara.shared(np.zeros((n_hid, n_hid), dtype="float32"),
                          name="W_l1_to_l1")
        W = aesara.shared(np.zeros((n_hid, n_out), dtype="float32"),
                          name="W_l1_to_l2")
        nparams = [U, V, W]

        # Build the forward pass
        l1_base = dot(xin, U)

        def scan_l(baseline, last_step):
            return baseline + dot(last_step, V)

        zero_output = aet.alloc(np.asarray(0.0, dtype="float32"), mb_size,
                                n_hid)

        l1_out, _ = scan(
            scan_l,
            sequences=[l1_base],
            outputs_info=[zero_output],
            mode=self.mode_with_gpu_nodebug,
        )

        l2_out = dot(l1_out, W)

        # Compute the cost and take the gradient wrt params
        cost = tt_sum((l2_out - yout)**2)
        grads = aesara.grad(cost, nparams)
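        # Plain gradient-descent update of every parameter with an implicit
        # learning rate of 1.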
        updates = list(zip(nparams, (n - g for n, g in zip(nparams, grads))))

        # Compile the aesara function
        feval_backprop = aesara.function([xin, yout],
                                         cost,
                                         updates=updates,
                                         mode=self.mode_with_gpu_nodebug)

        # Validate that the PushOutScanOutput optimization has been applied
        # by checking the number of outputs of the grad Scan node in the
        # compiled function.
        nodes = feval_backprop.maker.fgraph.toposort()
        scan_nodes = [n for n in nodes if isinstance(n.op, Scan)]

        # The grad scan is always the 2nd one according to toposort. If the
        # optimization has been applied, it has 2 outputs, otherwise 3.
        grad_scan_node = scan_nodes[1]
        assert len(grad_scan_node.outputs) == 2, len(grad_scan_node.outputs)

        # Call the aesara function to ensure the absence of a memory error
        feval_backprop(
            np.zeros((mb_length, mb_size, n_in), dtype="float32"),
            np.zeros((mb_length, mb_size, n_out), dtype="float32"),
        )
Example #12
    def test_one_sequence_one_output_weights_gpu1(self):
        def f_rnn(u_t, x_tm1, W_in, W):
            return u_t * W_in + x_tm1 * W

        u = fvector("u")
        x0 = fscalar("x0")
        W_in = fscalar("win")
        W = fscalar("w")

        # The following line is needed to exercise the first case;
        # otherwise, it is the second one that is tested.
        mode = self.mode_with_gpu.excluding("InputToGpuOptimizer")
        output, updates = scan(
            f_rnn,
            u,
            x0,
            [W_in, W],
            n_steps=None,
            truncate_gradient=-1,
            go_backwards=False,
            mode=mode,
        )

        output = self.gpu_backend.gpu_from_host(output)
        f2 = aesara.function(
            [u, x0, W_in, W],
            output,
            updates=updates,
            allow_input_downcast=True,
            mode=self.mode_with_gpu,
        )

        # get random initial values; use fresh names so the symbolic
        # variables W and W_in are not shadowed by numpy values
        rng = np.random.RandomState(utt.fetch_seed())
        v_u = rng.uniform(size=(4,), low=-5.0, high=5.0)
        v_x0 = rng.uniform()
        v_W = rng.uniform()
        v_W_in = rng.uniform()

        v_u = np.asarray(v_u, dtype="float32")
        v_x0 = np.asarray(v_x0, dtype="float32")
        v_W = np.asarray(v_W, dtype="float32")
        v_W_in = np.asarray(v_W_in, dtype="float32")

        # compute the output in numpy
        v_out = np.zeros((4,))
        v_out[0] = v_u[0] * v_W_in + v_x0 * v_W
        for step in range(1, 4):
            v_out[step] = v_u[step] * v_W_in + v_out[step - 1] * v_W
        aesara_values = f2(v_u, v_x0, v_W_in, v_W)
        utt.assert_allclose(aesara_values, v_out)

        topo = f2.maker.fgraph.toposort()
        assert (sum([
            isinstance(node.op, self.gpu_backend.HostFromGpu) for node in topo
        ]) == 0)
        assert (sum([
            isinstance(node.op, self.gpu_backend.GpuFromHost) for node in topo
        ]) == 4)

        scan_node = [node for node in topo if isinstance(node.op, Scan)]
        assert len(scan_node) == 1
        scan_node = scan_node[0]
        scan_node_topo = scan_node.op.fn.maker.fgraph.toposort()

        # check that the inner loop runs on the GPU and that it contains
        # no gpu transfers.
        assert any([
            isinstance(node.op, self.gpu_backend.GpuElemwise)
            for node in scan_node_topo
        ])
        assert not any([
            isinstance(node.op, self.gpu_backend.HostFromGpu)
            for node in scan_node_topo
        ])
        assert not any([
            isinstance(node.op, self.gpu_backend.GpuFromHost)
            for node in scan_node_topo
        ])
Example #13
def scan_checkpoints(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    name="checkpointscan_fn",
    n_steps=None,
    save_every_N=10,
    padding=True,
):
    """Scan function that uses less memory, but is more restrictive.

    In :func:`~aesara.scan`, if you compute the gradient of the output
    with respect to the input, you will have to store the intermediate
    results at each time step, which can be prohibitively huge. This
    function lets you run ``save_every_N`` steps of forward computation
    without storing the intermediate results, and recomputes them during
    the gradient computation instead.

    Notes
    -----
    Current assumptions:

    * Every sequence has the same length.
    * If ``n_steps`` is specified, it has the same value as the length of
      any sequence.
    * The value of ``save_every_N`` divides the number of steps the scan
      will run without remainder.
    * Only singly-recurrent and non-recurrent outputs are used.
      No multiple recurrences.
    * Only the last timestep of any output will ever be used.

    Parameters
    ----------
    fn
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. See the documentation of :func:`~aesara.scan`
        for more information.

    sequences
        ``sequences`` is the list of Aesara variables or dictionaries
        describing the sequences ``scan`` has to iterate over. All
        sequences must be the same length in this version of ``scan``.

    outputs_info
        ``outputs_info`` is the list of Aesara variables or dictionaries
        describing the initial state of the outputs computed
        recurrently.

    non_sequences
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list, as long as they are part of the
        computational graph, though for clarity we encourage not to do so.

    n_steps
        ``n_steps`` is the number of steps to iterate given as an int
        or Aesara scalar (> 0). If any of the input sequences do not have
        enough elements, scan will raise an error. If ``n_steps`` is not
        provided, ``scan`` will infer the number of steps it should run
        from its input sequences.

    save_every_N
        ``save_every_N`` is the number of steps to go without storing
        the computations of ``scan`` (i.e., they will have to be recomputed
        during the gradient computation).

    padding
        If the length of the sequences is not a multiple of ``save_every_N``,
        the sequences will be zero padded to make this version of ``scan``
        work properly, but will also result in a memory copy. It can be
        avoided by setting ``padding`` to False, but you need to make
        sure the length of the sequences is a multiple of ``save_every_N``.

    Returns
    -------
    tuple
        Tuple of the form ``(outputs, updates)`` as in :func:`~aesara.scan`,
        but with a small change: it only contains the output at each
        ``save_every_N`` step. The time steps that are not returned by
        this function will be recomputed during the gradient computation
        (if any).

    See Also
    --------
    :func:`~aesara.scan`: Looping in Aesara.

    """
    # Standardize the format of input arguments
    if sequences is None:
        sequences = []
    elif not isinstance(sequences, list):
        sequences = [sequences]

    if not isinstance(outputs_info, list):
        outputs_info = [outputs_info]

    if non_sequences is None:
        non_sequences = []
    elif not isinstance(non_sequences, list):
        non_sequences = [non_sequences]

    # Check that outputs_info has no taps:
    for element in outputs_info:
        if isinstance(element, dict) and "taps" in element:
            raise RuntimeError("scan_checkpoints doesn't work with taps.")

    # Determine how many steps the original scan would run
    if n_steps is None:
        n_steps = sequences[0].shape[0]

    # Compute the number of steps of the outer scan
    o_n_steps = at.cast(ceil(n_steps / save_every_N), "int64")

    # Compute the number of steps of the inner scan
    i_n_steps = save_every_N * at.ones((o_n_steps, ), "int64")
    mod = n_steps % save_every_N
    last_n_steps = at.switch(eq(mod, 0), save_every_N, mod)
    i_n_steps = set_subtensor(i_n_steps[-1], last_n_steps)

    # Pad the sequences if needed
    if padding:
        # Since padding could be an empty tensor, Join returns a view of s.
        join = Join(view=0)
        for i, s in enumerate(sequences):
            # Pad with enough zeros to make the length a multiple of
            # ``save_every_N``; every entry of the shape must be a scalar.
            n = (save_every_N - s.shape[0] % save_every_N) % save_every_N
            z = at.zeros([n] + [s.shape[j] for j in range(1, s.ndim)],
                         dtype=s.dtype)
            sequences[i] = join(0, [s, z])

    # Establish the input variables of the outer scan
    o_sequences = [
        s.reshape(
            [s.shape[0] // save_every_N, save_every_N] +
            [s.shape[i] for i in range(1, s.ndim)],
            s.ndim + 1,
        ) for s in sequences
    ]
    o_sequences.append(i_n_steps)
    new_nitsots = [i for i in outputs_info if i is None]
    o_nonsequences = non_sequences

    def outer_step(*args):
        # Separate the received arguments into their respective (seq, outputs
        # from previous iterations, nonseqs) categories
        i_sequences = list(args[:len(o_sequences)])
        # Slice from ``len(args) - len(o_nonsequences)`` rather than a
        # negative index so the split stays correct when there are no
        # non-sequences (``args[-0:]`` would return the whole tuple).
        i_prev_outputs = list(
            args[len(o_sequences):len(args) - len(o_nonsequences)])
        i_non_sequences = list(args[len(args) - len(o_nonsequences):])
        i_outputs_infos = i_prev_outputs + [None] * len(new_nitsots)

        # Call the user-provided function with the proper arguments
        results, updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
        )
        if not isinstance(results, list):
            results = [results]

        # Keep only the last timestep of every output but keep all the updates
        return [r[-1] for r in results], updates

    results, updates = scan(
        fn=outer_step,
        sequences=o_sequences,
        outputs_info=outputs_info,
        non_sequences=o_nonsequences,
        name=name + "_outer",
        n_steps=o_n_steps,
        allow_gc=True,
    )

    return results, updates
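
As a quick illustration of the API documented above, a minimal usage sketch (the cumulative-sum step function and the shapes are illustrative assumptions, not from the source):

import numpy as np
import aesara
import aesara.tensor as at

x = at.matrix("x")
out, updates = scan_checkpoints(
    fn=lambda x_t, acc: acc + x_t,  # one singly-recurrent output, no taps
    sequences=[x],
    outputs_info=[at.zeros_like(x[0])],
    save_every_N=10,
)
# Only every ``save_every_N``-th state is kept; the gradient below triggers
# recomputation of the intermediate steps.
g = aesara.grad(out[-1].sum(), x)
f = aesara.function([x], [out[-1], g])
f(np.ones((100, 3), dtype=aesara.config.floatX))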
Example #14
def test_jax_scan_multiple_output():
    """Test a scan implementation of a SEIR model.

    SEIR model definition:
    S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]

    B[t] ~ Binom(S[t], beta)
    C[t] ~ Binom(E[t], gamma)
    D[t] ~ Binom(I[t], delta)
    """
    def binomln(n, k):
        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)

    def binom_log_prob(n, p, value):
        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)

    # sequences
    aet_C = ivector("C_t")
    aet_D = ivector("D_t")
    # outputs_info (initial conditions)
    st0 = lscalar("s_t0")
    et0 = lscalar("e_t0")
    it0 = lscalar("i_t0")
    logp_c = scalar("logp_c")
    logp_d = scalar("logp_d")
    # non_sequences
    beta = scalar("beta")
    gamma = scalar("gamma")
    delta = scalar("delta")

    # TODO: Use random streams when their JAX conversions are implemented.
    # trng = aesara.tensor.random.RandomStream(1234)

    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma,
                      delta):
        # bt0 = trng.binomial(n=st0, p=beta)
        bt0 = st0 * beta
        bt0 = bt0.astype(st0.dtype)

        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)

        st1 = st0 - bt0
        et1 = et0 + bt0 - ct0
        it1 = it0 + ct0 - dt0
        return st1, et1, it1, logp_c1, logp_d1

    (st, et, it, logp_c_all, logp_d_all), _ = scan(
        fn=seir_one_step,
        sequences=[aet_C, aet_D],
        outputs_info=[st0, et0, it0, logp_c, logp_d],
        non_sequences=[beta, gamma, delta],
    )
    st.name = "S_t"
    et.name = "E_t"
    it.name = "I_t"
    logp_c_all.name = "C_t_logp"
    logp_d_all.name = "D_t_logp"

    out_fg = FunctionGraph(
        [aet_C, aet_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
        [st, et, it, logp_c_all, logp_d_all],
    )

    s0, e0, i0 = 100, 50, 25
    logp_c0 = np.array(0.0, dtype=config.floatX)
    logp_d0 = np.array(0.0, dtype=config.floatX)
    beta_val, gamma_val, delta_val = [
        np.array(val, dtype=config.floatX)
        for val in [0.277792, 0.135330, 0.108753]
    ]
    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)

    test_input_vals = [
        C,
        D,
        s0,
        e0,
        i0,
        logp_c0,
        logp_d0,
        beta_val,
        gamma_val,
        delta_val,
    ]
    compare_jax_and_py(out_fg, test_input_vals)
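
For intuition, a plain-NumPy sketch of the same deterministic recursion (a reference implementation under the deterministic stand-in for B[t] used above; the log-probability terms are omitted and the helper name is hypothetical):

def seir_reference(C, D, s0, e0, i0, beta):
    # int(s * beta) mirrors ``bt0.astype(st0.dtype)`` in seir_one_step.
    s, e, i = s0, e0, i0
    S, E, I = [], [], []
    for ct, dt in zip(C, D):
        b = int(s * beta)
        s, e, i = s - b, e + b - ct, i + ct - dt
        S.append(s)
        E.append(e)
        I.append(i)
    return S, E, I

# seir_reference(C, D, 100, 50, 25, 0.277792) should track the S_t, E_t and
# I_t outputs above, modulo floatX rounding in the truncation of B[t].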