def test_delta_defaults():
    """Check ``dist.Delta`` defaults and string-vs-Variable ``value`` handling.

    Constructing a Delta without a value must yield a ``dist.Delta``, and
    passing the string ``'value'`` must produce the identical object obtained
    by passing a real-valued ``Variable`` named ``'value'``.
    """
    point = Variable('v', reals())
    log_density = Variable('log_density', reals())
    without_value = dist.Delta(point, log_density)
    assert isinstance(without_value, dist.Delta)

    value_var = Variable('value', reals())
    from_name = dist.Delta(point, log_density, 'value')
    from_variable = dist.Delta(point, log_density, value_var)
    # Identity, not mere equality, is what this test pins down.
    assert from_name is from_variable
def test_delta_density(batch_shape, event_shape):
    """Compare ``dist.Delta`` with a hand-written torch reference density.

    ``batch_shape`` supplies the sizes of up to three batch inputs named
    'i', 'j', 'k'; ``event_shape`` is the shape of the real-valued event.
    The density is checked both at the point itself and at a random value.
    """
    names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(
        (name, bint(size)) for name, size in zip(names, batch_shape))

    @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape),
                           reals())
    def delta(v, log_density, value):
        matches = (v == value)
        # Fold away one trailing event dimension per entry of event_shape.
        for _ in event_shape:
            matches = matches.all(dim=-1)
        return matches.type(v.dtype).log() + log_density

    reference_inputs = {
        'v': reals(*event_shape),
        'log_density': reals(),
        'value': reals(*event_shape),
    }
    check_funsor(delta, reference_inputs, reals())

    v = Tensor(torch.randn(batch_shape + event_shape), inputs)
    log_density = Tensor(torch.randn(batch_shape).exp(), inputs)
    other_value = Tensor(torch.randn(batch_shape + event_shape), inputs)
    for value in (v, other_value):
        expected = delta(v, log_density, value)
        check_funsor(expected, inputs, reals())

        actual = dist.Delta(v, log_density, value)
        check_funsor(actual, inputs, reals())
        assert_close(actual, expected)
Example #3
0
    def _forward_funsor(self, features, trip_counts):
        """Construct the prior and likelihood funsors over ``gate_rate_t``.

        Args:
            features: per-hour features; ``len(features)`` must equal the
                leading size of ``trip_counts``.
            trip_counts: object with a three-element ``.shape`` whose last
                two sizes must both equal ``self.num_stations``.

        Returns:
            A ``(prior, likelihood)`` pair; each is asserted to have
            ``"gate_rate_t"`` as its only input.
        """
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours == total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        num_stations = self.num_stations
        # Gate and rate are packed side by side: two (n, n) grids per hour.
        packed_size = 2 * num_stations * num_stations
        gate_rate_t = funsor.Variable("gate_rate_t",
                                      reals(observed_hours, packed_size))
        gate_rate = gate_rate_t["time"]

        @funsor.torch.function(
            reals(packed_size),
            (reals(num_stations, num_stations, 2),
             reals(num_stations, num_stations)))
        def unpack_gate_rate(gate_rate):
            leading_shape = gate_rate.shape[:-1]
            unpacked = gate_rate.reshape(
                leading_shape + (2, num_stations, num_stations))
            raw_gate, raw_rate = unpacked.unbind(-3)
            gate_prob = raw_gate.sigmoid().clamp(min=0.01, max=0.99)
            rate = bounded_exp(raw_rate, bound=1e4)
            # Last dimension holds (1 - gate, gate) for the Categorical below.
            gate = torch.stack((1 - gate_prob, gate_prob), dim=-1)
            return gate, rate

        # Gaussian latent dynamical system, converted piecewise to funsors.
        dynamics = self._dynamics(features[:observed_hours])
        init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = dynamics
        init = dist_to_funsor(init_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(
            trans_matrix, trans_dist, ("time", ), "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(
            obs_matrix, obs_dist, ("time", ), "state(time=1)", "gate_rate")

        # Dynamic prior over gate_rate: chain the per-step factors along
        # "time", add the initial factor, then eliminate both state names.
        prior = trans + obs(gate_rate=gate_rate)
        prior = MarkovProduct(ops.logaddexp, ops.add, prior, "time",
                              {"state": "state(time=1)"})
        prior += init
        prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

        # Zero-inflated Poisson likelihood: a Categorical over "gated" picks
        # between the Poisson term and the point mass at zero.
        gate, rate = unpack_gate_rate(gate_rate)
        likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
        trip_counts = tensor_to_funsor(trip_counts,
                                       ("time", "origin", "destin"))
        poisson_term = fdist.Poisson(rate["origin", "destin"],
                                     value=trip_counts)
        zero_term = fdist.Delta(0, value=trip_counts)
        likelihood += funsor.Stack("gated", (poisson_term, zero_term))
        likelihood = likelihood.reduce(ops.logaddexp, "gated")
        likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

        assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
        assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
        return prior, likelihood
def test_delta_delta():
    """``dist.Delta`` with a Variable value must be the ``Delta`` funsor on 'v'.

    The two constructions are required to be the identical object, not just
    equal ones.
    """
    value_var = Variable('v', reals(2))
    location = Tensor(torch.randn(2))
    log_mass = Tensor(torch.tensor(0.5))

    via_distribution = dist.Delta(location, log_mass, value_var)
    direct = Delta('v', location, log_mass)
    assert via_distribution is direct
Example #5
0
    def __call__(self):
        """Assemble the model's log-probability factors as a list of funsors.

        Reads sizes and switches from ``self.config``, parameters from
        ``self.params``, data from ``self.observations`` and masks from
        ``self.raggedness_masks``.

        Returns:
            list: ``log_prob``, whose first element is the combined HMM factor
            (step + angle + omega observation terms, masked, plus the state
            transition term). Any group-level and individual-level random
            effect terms follow, in that order, when the corresponding
            ``config[...]["random"]`` is "discrete" or "continuous".
        """

        # calls pyro.param so that params are exposed and constraints applied
        # should not create any new torch.Tensors after __init__
        self.initialize_params()

        N_state = self.config["sizes"]["state"]

        # initialize gamma to uniform
        # (an all-zero (N_state, N_state) tensor indexed by "y_prev")
        gamma = Tensor(
            torch.zeros((N_state, N_state)),
            OrderedDict([("y_prev", bint(N_state))]),
        )

        N_v = self.config["sizes"]["random"]
        N_c = self.config["sizes"]["group"]
        log_prob = []

        # all-zero tensor over input "g"; adding it to a term introduces the
        # "g" input without changing values
        plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))

        # group-level random effects
        if self.config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = Variable("e_g", bint(N_v))
            e_g_dist = plate_g + dist.Categorical(**self.params["e_g"])(value=e_g)

            log_prob.append(e_g_dist)

            eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)

        elif self.config["group"]["random"] == "continuous":
            eps_g = Variable("eps_g", reals(N_state))
            eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(value=eps_g)

            log_prob.append(eps_g_dist)
        else:
            # no group-level random effect: contributes a constant zero
            eps_g = to_funsor(0.)

        N_s = self.config["sizes"]["individual"]

        # all-zero tensor over input "i", used like plate_g above
        plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))
        # individual-level random effects
        if self.config["individual"]["random"] == "discrete":
            # individual-level discrete effect
            # NOTE: `*` binds tighter than `+`, so the individual mask at t=0
            # multiplies only the Categorical term, not the plate tensors.
            e_i = Variable("e_i", bint(N_v))
            e_i_dist = plate_g + plate_i + dist.Categorical(
                **self.params["e_i"]
            )(value=e_i) * self.raggedness_masks["individual"](t=0)

            log_prob.append(e_i_dist)

            eps_i = (plate_i + plate_g + self.params["eps_i"]["theta"](e_i=e_i))

        elif self.config["individual"]["random"] == "continuous":
            eps_i = Variable("eps_i", reals(N_state))
            eps_i_dist = plate_g + plate_i + dist.Normal(**self.params["eps_i"])(value=eps_i)

            log_prob.append(eps_i_dist)
        else:
            # no individual-level random effect: contributes a constant zero
            eps_i = to_funsor(0.)

        # add group-level and individual-level random effects to gamma
        gamma = gamma + eps_g + eps_i

        # NOTE(review): N_state was already read from the same config key at
        # the top of this method; this second read looks redundant - confirm.
        N_state = self.config["sizes"]["state"]

        # we've accounted for all effects, now actually compute gamma_y
        # (renames the "y_prev" input of gamma to "y(t=1)")
        gamma_y = gamma(y_prev="y(t=1)")

        # NOTE(review): gamma_y.exp() is evaluated twice below, and which
        # dimensions the argument-less .sum() normalizes over is not visible
        # from this block - confirm against the funsor version in use.
        y = Variable("y", bint(N_state))
        y_dist = plate_g + plate_i + dist.Categorical(
            probs=gamma_y.exp() / gamma_y.exp().sum()
        )(value=y)

        # observation 1: step size
        # (each parameter funsor is evaluated at y_curr=y before use)
        step_dist = plate_g + plate_i + dist.Gamma(
            **{k: v(y_curr=y) for k, v in self.params["step"].items()}
        )(value=self.observations["step"])

        # step size zero-inflation
        # (mix step_dist with a Delta at config["MISSING"], weighted by a
        # Categorical over "zi_step", then sum "zi_step" out with logaddexp)
        if self.config["zeroinflation"]:
            step_zi = dist.Categorical(probs=self.params["zi_step"]["zi_param"](y_curr=y))(
                value="zi_step")
            step_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["step"])
            step_dist = (step_zi + Stack("zi_step", (step_dist, step_zi_dist))).reduce(ops.logaddexp, "zi_step")

        # observation 2: step angle
        angle_dist = plate_g + plate_i + dist.VonMises(
            **{k: v(y_curr=y) for k, v in self.params["angle"].items()}
        )(value=self.observations["angle"])

        # observation 3: dive activity
        omega_dist = plate_g + plate_i + dist.Beta(
            **{k: v(y_curr=y) for k, v in self.params["omega"].items()}
        )(value=self.observations["omega"])

        # dive activity zero-inflation
        # (same construction as the step-size zero-inflation, over "zi_omega")
        if self.config["zeroinflation"]:
            omega_zi = dist.Categorical(probs=self.params["zi_omega"]["zi_param"](y_curr=y))(
                value="zi_omega")
            omega_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["omega"])
            omega_dist = (omega_zi + Stack("zi_omega", (omega_dist, omega_zi_dist))).reduce(ops.logaddexp, "zi_omega")

        # finally, construct the term for parallel scan reduction
        hmm_factor = step_dist + angle_dist + omega_dist
        hmm_factor = hmm_factor * self.raggedness_masks["individual"]
        hmm_factor = hmm_factor * self.raggedness_masks["timestep"]
        # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
        # (the masks are applied before y_dist is added, so the transition
        # term itself is not multiplied by them)
        hmm_factor = hmm_factor + y_dist
        # the HMM factor goes first; random-effect terms appended above follow
        log_prob.insert(0, hmm_factor)

        return log_prob
Example #6
0
def test_bart(analytic_kl):
    """Evaluate an ELBO from hard-coded guide, prior and likelihood funsors.

    The three expressions ``q``, ``p_prior`` and ``p_likelihood`` are built as
    literal funsor terms under the ``reflect`` interpretation, then integrated
    against ``q`` over ``"gate_rate_t"``.

    Args:
        analytic_kl: if truthy, ``p_prior - q`` is integrated outside the
            ``monte_carlo`` interpretation and only ``p_likelihood`` is
            integrated under it; otherwise ``p_prior + p_likelihood - q`` is
            integrated entirely under ``monte_carlo``.

    The test passes when the resulting ``elbo`` is a ``Tensor`` and the
    module-level ``call_count`` equals 1.
    """
    # Reset the module-level counter checked at the end of the test.
    # NOTE(review): presumably incremented by the unpack_gate_rate_* helpers
    # defined elsewhere in this file - confirm.
    global call_count
    call_count = 0

    with interpretation(reflect):
        # Guide q: a Tensor plus a Gaussian term, wrapped in two nested
        # Independent terms so the result is a function of 'gate_rate_t'.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # Prior: a MarkovProduct over 'time_b9' combined with a
        # MultivariateNormal on 'state_b10', with both state names
        # logaddexp-reduced by the outer Contraction.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        # Likelihood: a Categorical over 'gated_b14' plus a Stack of a Poisson
        # and a Delta term, logaddexp-reduced over 'gated_b14' and then
        # add-reduced over time, origin and destination.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        # Prior-minus-guide term is integrated outside monte_carlo; only the
        # likelihood term is integrated under the monte_carlo interpretation.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole integrand is integrated under the monte_carlo interpretation.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    # The failure message pretty-prints whatever non-Tensor term was returned.
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1