def test_dirichlet_multinomial(hybridize: bool) -> None:
    """Check that SGD maximum likelihood recovers the concentration ``alpha``.

    Samples are drawn from a ``DirichletMultinomial`` with a known ``alpha``,
    ``maximum_likelihood_estimate_sgd`` is fitted to them through
    ``DirichletMultinomialOutput``, and both the estimated ``alpha`` and the
    ``variance`` of the distribution rebuilt from that estimate are compared
    with the reference values using ``np.allclose`` (``atol=0.1``,
    ``rtol=0.1``).

    Parameters
    ----------
    hybridize
        Forwarded unchanged to ``maximum_likelihood_estimate_sgd``.
    """
    num_samples = 8000
    n_trials = 500

    alpha = np.array([1.0, 2.0, 3.0])
    # Single source of truth for the dimension: derive it from ``alpha`` so
    # the reference distribution, the output head and the re-fitted
    # distribution can never disagree (previously the literal 3 was repeated).
    dim = len(alpha)

    true_distr = DirichletMultinomial(
        dim=dim,
        n_trials=n_trials,
        alpha=mx.nd.array(alpha),
    )
    # NOTE(review): this reads the ``variance`` property although the
    # assertion message below calls it "Covariance" -- confirm which is meant.
    cov = true_distr.variance.asnumpy()

    samples = true_distr.sample(num_samples)

    alpha_hat = maximum_likelihood_estimate_sgd(
        DirichletMultinomialOutput(dim=dim, n_trials=n_trials),
        samples,
        init_biases=None,
        hybridize=hybridize,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    # Rebuild the distribution from the estimate to compare second moments.
    fitted_distr = DirichletMultinomial(
        dim=dim,
        n_trials=n_trials,
        alpha=mx.nd.array(alpha_hat),
    )
    cov_hat = fitted_distr.variance.asnumpy()

    assert np.allclose(
        alpha_hat, alpha, atol=0.1, rtol=0.1
    ), f"alpha did not match: alpha = {alpha}, alpha_hat = {alpha_hat}"
    assert np.allclose(
        cov_hat, cov, atol=0.1, rtol=0.1
    ), f"Covariance did not match: cov = {cov}, cov_hat = {cov_hat}"
 (
     LowrankMultivariateGaussianOutput(dim=5, rank=4),
     mx.nd.random.normal(shape=(3, 4, 10)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4),
     (5, ),
 ),
 (
     DirichletOutput(dim=5),
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     (3, 4),
     (5, ),
 ),
 (
     DirichletMultinomialOutput(dim=5, n_trials=10),
     mx.nd.random.gamma(shape=(3, 4, 5)),
     [None],
     (3, 4),
     (5, ),
 ),
 (
     LaplaceOutput(),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),
     [None, mx.nd.ones(shape=(3, 4, 5))],
     (3, 4, 5),
     (),
 ),
 (
     NegativeBinomialOutput(),
     mx.nd.random.normal(shape=(3, 4, 5, 6)),