def test_beta_density(batch_shape, eager):
    """Compare ``dist.Beta`` against ``torch.distributions.Beta.log_prob``.

    Args:
        batch_shape: Sizes of the batch dimensions; they are named
            ``'i'``, ``'j'``, ``'k'`` in order.
        eager: When truthy, the observed value is passed straight to
            ``dist.Beta``; otherwise the distribution is built over a free
            ``'value'`` variable and the observation is substituted afterwards.
    """
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict()
    for name, size in zip(dim_names, batch_shape):
        inputs[name] = bint(size)

    # Reference density: three real inputs mapped to one real output.
    @funsor.torch.function(reals(), reals(), reals(), reals())
    def beta(concentration1, concentration0, value):
        reference = torch.distributions.Beta(concentration1, concentration0)
        return reference.log_prob(value)

    expected_inputs = {
        'concentration1': reals(),
        'concentration0': reals(),
        'value': reals()
    }
    check_funsor(beta, expected_inputs, reals())

    # exp() of normal draws keeps both concentrations strictly positive;
    # rand() draws the observation from [0, 1).
    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
    value = Tensor(torch.rand(batch_shape), inputs)

    expected = beta(concentration1, concentration0, value)
    check_funsor(expected, inputs, reals())

    if eager:
        actual = dist.Beta(concentration1, concentration0, value)
    else:
        free_value = Variable('value', reals())
        lazy_density = dist.Beta(concentration1, concentration0, free_value)
        actual = lazy_density(value=value)

    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
Esempio n. 2
0
    def __call__(self):
        """Build the list of log-probability terms for this model.

        Returns:
            list: The combined HMM term (step + angle + omega observation
            terms, masked, plus the state term) at index 0, followed by the
            group-level and then the individual-level random-effect terms.
            A random-effect term is present only when the matching
            ``config[...]["random"]`` entry is ``"discrete"`` or
            ``"continuous"``.
        """
        # calls pyro.param so that params are exposed and constraints applied
        # should not create any new torch.Tensors after __init__
        self.initialize_params()

        sizes = self.config["sizes"]
        n_states = sizes["state"]

        # initialize gamma to uniform (all zeros before exponentiation)
        gamma = Tensor(
            torch.zeros((n_states, n_states)),
            OrderedDict([("y_prev", bint(n_states))]),
        )

        n_random = sizes["random"]
        n_groups = sizes["group"]
        effect_terms = []

        # Zero-valued tensor over input "g"; presumably added to terms so that
        # they carry the group input -- confirm against funsor semantics.
        group_plate = Tensor(
            torch.zeros(n_groups), OrderedDict([("g", bint(n_groups))])
        )

        # group-level random effects
        group_mode = self.config["group"]["random"]
        if group_mode == "discrete":
            # group-level discrete effect
            e_g = Variable("e_g", bint(n_random))
            e_g_prior = dist.Categorical(**self.params["e_g"])
            effect_terms.append(group_plate + e_g_prior(value=e_g))

            eps_g = (group_plate + self.params["eps_g"]["theta"])(e_g=e_g)
        elif group_mode == "continuous":
            eps_g = Variable("eps_g", reals(n_states))
            eps_g_prior = dist.Normal(**self.params["eps_g"])
            effect_terms.append(group_plate + eps_g_prior(value=eps_g))
        else:
            eps_g = to_funsor(0.)

        n_individuals = sizes["individual"]

        # Zero-valued tensor over input "i", the per-individual counterpart
        # of group_plate.
        individual_plate = Tensor(
            torch.zeros(n_individuals),
            OrderedDict([("i", bint(n_individuals))]),
        )
        both_plates = group_plate + individual_plate

        # individual-level random effects
        individual_mode = self.config["individual"]["random"]
        if individual_mode == "discrete":
            # individual-level discrete effect
            e_i = Variable("e_i", bint(n_random))
            e_i_prior = dist.Categorical(**self.params["e_i"])
            # The mask multiplies only the categorical term, not the plates.
            masked_prior = (
                e_i_prior(value=e_i) * self.raggedness_masks["individual"](t=0)
            )
            effect_terms.append(both_plates + masked_prior)

            eps_i = (
                individual_plate
                + group_plate
                + self.params["eps_i"]["theta"](e_i=e_i)
            )
        elif individual_mode == "continuous":
            eps_i = Variable("eps_i", reals(n_states))
            eps_i_prior = dist.Normal(**self.params["eps_i"])
            effect_terms.append(both_plates + eps_i_prior(value=eps_i))
        else:
            eps_i = to_funsor(0.)

        # add group-level and individual-level random effects to gamma
        gamma = gamma + eps_g + eps_i

        # we've accounted for all effects, now actually compute gamma_y
        gamma_y = gamma(y_prev="y(t=1)")

        y = Variable("y", bint(n_states))
        unnormalized = gamma_y.exp()
        y_prior = dist.Categorical(probs=unnormalized / unnormalized.sum())
        y_dist = both_plates + y_prior(value=y)

        def state_params(family):
            """Substitute ``y`` for ``y_curr`` in each parameter of ``family``."""
            return {
                key: param(y_curr=y)
                for key, param in self.params[family].items()
            }

        def zero_inflate(base_dist, zi_name, obs_name):
            """Mix ``base_dist`` with a point mass at the MISSING value.

            The two components are stacked along ``zi_name``, weighted by the
            categorical term over ``zi_name``, and reduced with logaddexp.
            """
            zi_prior = dist.Categorical(
                probs=self.params[zi_name]["zi_param"](y_curr=y)
            )
            zi_weight = zi_prior(value=zi_name)
            point_mass = both_plates + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations[obs_name]
            )
            mixture = zi_weight + Stack(zi_name, (base_dist, point_mass))
            return mixture.reduce(ops.logaddexp, zi_name)

        # observation 1: step size
        step_dist = both_plates + dist.Gamma(**state_params("step"))(
            value=self.observations["step"]
        )

        # step size zero-inflation
        use_zero_inflation = self.config["zeroinflation"]
        if use_zero_inflation:
            step_dist = zero_inflate(step_dist, "zi_step", "step")

        # observation 2: step angle
        angle_dist = both_plates + dist.VonMises(**state_params("angle"))(
            value=self.observations["angle"]
        )

        # observation 3: dive activity
        omega_dist = both_plates + dist.Beta(**state_params("omega"))(
            value=self.observations["omega"]
        )

        # dive activity zero-inflation
        if use_zero_inflation:
            omega_dist = zero_inflate(omega_dist, "zi_omega", "omega")

        # finally, construct the term for parallel scan reduction
        masks = self.raggedness_masks
        observation_term = step_dist + angle_dist + omega_dist
        masked_term = observation_term * masks["individual"] * masks["timestep"]
        # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
        hmm_factor = masked_term + y_dist

        return [hmm_factor, *effect_terms]