Esempio n. 1
0
def model(detections, args):
    """Generative model assigning 1-D detections to latent objects or clutter.

    Reads the Pyro params ``'noise_scale'`` and ``'objects_loc'`` (the latter
    squeezed on its last dim, and required to be 1-D afterwards) and declares
    the sample sites ``'exists'``, ``'objects'``, ``'assign'`` and
    ``'detections'``.

    Args:
        detections: 1-D tensor of observed detection positions.
        args: object exposing ``expected_num_objects`` and
            ``num_fake_detections``; both are used as numerators of prior
            probabilities, so ``num_fake_detections`` must not exceed the
            number of detections.

    Returns:
        None. The function is used for its Pyro sample statements.
    """
    noise_scale = pyro.param('noise_scale')
    objects = pyro.param('objects_loc').squeeze(-1)
    num_detections, = detections.shape
    max_num_objects, = objects.shape

    # Existence part: each candidate object exists with prior p_exists, and
    # only existing objects contribute their Normal(0, 1) location prior.
    p_exists = args.expected_num_objects / max_num_objects
    with pyro.plate('objects_plate', max_num_objects):
        exists = pyro.sample('exists', dist.Bernoulli(p_exists))
        with poutine.mask(mask=exists.bool()):
            pyro.sample('objects', dist.Normal(0., 1.), obs=objects)

    # Assignment part: every detection picks one of the object slots or the
    # extra trailing "bogus" slot (index max_num_objects) that marks clutter.
    p_fake = args.num_fake_detections / num_detections
    with pyro.plate('detections_plate', num_detections):
        assign_probs = torch.empty(max_num_objects + 1)
        assign_probs[:-1] = (1 - p_fake) / max_num_objects
        assign_probs[-1] = p_fake
        # assign_probs already sums to one, so it must be passed as `probs`;
        # feeding it as logits would flatten the prior towards uniform.
        assign = pyro.sample('assign', dist.Categorical(probs=assign_probs))
        # Compare against the bogus slot index itself. The trailing dim of
        # `assign` is the plate size (or 1 under parallel enumeration), not
        # the number of categories, so it cannot be used to find that index.
        is_fake = (assign == max_num_objects)
        objects_plus_bogus = torch.zeros(max_num_objects + 1)
        objects_plus_bogus[:max_num_objects] = objects
        real_dist = dist.Normal(objects_plus_bogus[assign], noise_scale)
        fake_dist = dist.Normal(0., 1.)
        # MaskedMixture uses the second component wherever the mask is True,
        # i.e. fake detections are scored under fake_dist.
        pyro.sample('detections',
                    dist.MaskedMixture(is_fake, real_dist, fake_dist),
                    obs=detections)
Esempio n. 2
0
def test_masked_mixture_multivariate(sample_shape, batch_shape):
    """Check shapes and component-wise log-probs of a multivariate MaskedMixture.

    Mixes an 8-dimensional standard MultivariateNormal (component 0) with a
    unit Uniform turned into a 1-event-dim distribution (component 1) under a
    random boolean mask of shape ``batch_shape``.
    """
    dim = 8
    event_shape = torch.Size((dim,))
    gaussian = dist.MultivariateNormal(torch.zeros(event_shape), torch.eye(dim))
    box = dist.Uniform(torch.zeros(event_shape), torch.ones(event_shape)).to_event(1)
    if batch_shape:
        gaussian = gaussian.expand_by(batch_shape)
        box = box.expand_by(batch_shape)
    selector = torch.empty(batch_shape).bernoulli_(0.5).bool()
    mixture = dist.MaskedMixture(selector, gaussian, box)

    # Static shape metadata.
    assert mixture.batch_shape == batch_shape
    assert mixture.event_shape == event_shape

    # A single draw and both moments share batch_shape + event_shape.
    single_draw_shape = batch_shape + event_shape
    assert mixture.sample().shape == single_draw_shape
    assert mixture.mean.shape == single_draw_shape
    assert mixture.variance.shape == single_draw_shape
    draws = mixture.sample(sample_shape)
    assert draws.shape == sample_shape + batch_shape + event_shape

    mixture_lp = mixture.log_prob(draws)
    assert mixture_lp.shape == sample_shape + batch_shape
    assert not torch_isnan(mixture_lp)

    # Where the mask is set the mixture must agree with component 1,
    # elsewhere with component 0.
    gaussian_lp = gaussian.log_prob(draws)
    box_lp = box.log_prob(draws)
    picks_box = selector.expand(sample_shape + batch_shape)
    assert_equal(mixture_lp[picks_box], box_lp[picks_box])
    picks_gaussian = ~picks_box
    assert_equal(mixture_lp[picks_gaussian], gaussian_lp[picks_gaussian])
Esempio n. 3
0
def test_expand(sample_shape, batch_shape, event_shape):
    """Check that MaskedMixture.expand yields the expected shapes.

    The mixture is built with an all-ones batch shape (same rank as
    ``batch_shape``) and then expanded to ``sample_shape + batch_shape``.
    """
    unit_batch = torch.Size((1,) * len(batch_shape))
    selector = torch.empty(unit_batch).bernoulli_(0.5).bool()
    origin = torch.zeros(unit_batch + event_shape)
    event_dims = len(event_shape)
    left = dist.Uniform(origin - 2, origin + 1).to_event(event_dims)
    right = dist.Uniform(origin - 1, origin + 2).to_event(event_dims)
    mixture = dist.MaskedMixture(selector, left, right)

    # Shapes before expansion.
    assert mixture.sample().shape == unit_batch + event_shape
    assert mixture.mean.shape == unit_batch + event_shape
    assert mixture.variance.shape == unit_batch + event_shape
    assert mixture.sample(sample_shape).shape == sample_shape + unit_batch + event_shape

    def expanded():
        # A fresh expansion per assertion, as each check is independent.
        return mixture.expand(sample_shape + batch_shape)

    # Shapes after expansion.
    assert expanded().batch_shape == sample_shape + batch_shape
    assert expanded().sample().shape == sample_shape + batch_shape + event_shape
    assert expanded().mean.shape == sample_shape + batch_shape + event_shape
    assert expanded().variance.shape == sample_shape + batch_shape + event_shape
Esempio n. 4
0
def test_broadcast(mask_shape, component0_shape, component1_shape, value_shape):
    """Check that mask, components and value broadcast together.

    The mixture's batch shape must be the broadcast of the mask and both
    component shapes, and ``log_prob`` must further broadcast against the
    value's shape.
    """
    selector = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool()
    normal = dist.Normal(torch.zeros(component0_shape), 1.0)
    exponential = dist.Exponential(torch.ones(component1_shape))
    mixture = dist.MaskedMixture(selector, normal, exponential)

    expected_batch = broadcast_shape(mask_shape, component0_shape, component1_shape)
    assert mixture.batch_shape == expected_batch

    point = torch.ones(value_shape)
    expected_log_prob_shape = broadcast_shape(expected_batch, value_shape)
    assert mixture.log_prob(point).shape == expected_log_prob_shape
Esempio n. 5
0
def test_masked_mixture_univariate(component0, component1, sample_shape, batch_shape):
    """Check shapes and component-wise log-probs of a univariate MaskedMixture.

    Both components are expanded to ``batch_shape`` (when it is non-empty)
    and mixed under a random boolean mask of that shape.
    """
    if batch_shape:
        component0 = component0.expand_by(batch_shape)
        component1 = component1.expand_by(batch_shape)
    selector = torch.empty(batch_shape).bernoulli_(0.5).bool()
    mixture = dist.MaskedMixture(selector, component0, component1)

    # Static shape metadata: scalar events.
    assert mixture.batch_shape == batch_shape
    assert mixture.event_shape == ()

    # A single draw and both moments have exactly the batch shape.
    assert mixture.sample().shape == batch_shape
    assert mixture.mean.shape == batch_shape
    assert mixture.variance.shape == batch_shape
    draws = mixture.sample(sample_shape)
    full_shape = sample_shape + batch_shape
    assert draws.shape == full_shape

    mixture_lp = mixture.log_prob(draws)
    assert mixture_lp.shape == full_shape
    assert not torch_isnan(mixture_lp)

    # Where the mask is set the mixture must agree with component1,
    # elsewhere with component0.
    lp0 = component0.log_prob(draws)
    lp1 = component1.log_prob(draws)
    uses_component1 = selector.expand(sample_shape + batch_shape)
    assert_equal(mixture_lp[uses_component1], lp1[uses_component1])
    uses_component0 = ~uses_component1
    assert_equal(mixture_lp[uses_component0], lp0[uses_component0])
Esempio n. 6
0
def model_generic(config):
    """Hierarchical mixed-effects hidden Markov model.

    Declares the Pyro parameters and sample sites of an HMM whose transition
    logits start at zero and receive an optional group-level and an optional
    individual-level random effect. At every timestep a hidden state ``y_t``
    is sampled and three observations are scored: step size (Gamma), step
    angle (VonMises) and dive activity (Beta). Step size and dive activity
    are each combined, through ``MaskedMixture``, with a ``Delta`` placed at
    the ``MISSING`` value.

    Args:
        config: nested mapping. Keys read by this function:

            * ``config["MISSING"]``: value marking a missing observation.
            * ``config["sizes"]``: ``"random"``, ``"state"``, ``"group"``,
              ``"individual"`` and ``"timesteps"``.
            * ``config["group"]["random"]`` and
              ``config["individual"]["random"]``: ``"discrete"`` or
              ``"continuous"``; any other value disables that effect.
            * ``config["individual"]["mask"]`` and
              ``config["timestep"]["mask"]`` (indexed by ``[..., t]``).
            * ``config["observations"]``: ``"step"``, ``"angle"`` and
              ``"omega"``, each indexed by ``[..., t]``.

    Returns:
        None. The function is used for its Pyro param/sample statements.
    """

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    if config["group"]["random"] == "discrete":
        # mixture weights over the N_v discrete effect values
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v, )).abs(),
                               constraint=constraints.simplex)
        # one flattened (N_state**2) offset per discrete effect value
        theta_g = pyro.param("theta_group", lambda: torch.randn(
            (N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        # fixed (non-learned) location and scale of the Normal effect
        loc_g = torch.zeros((N_state**2, ))
        scale_g = torch.ones((N_state**2, ))

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((
                                   N_c,
                                   N_v,
                               )).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual", lambda: torch.randn(
            (N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        # fixed (non-learned) location and scale, one row per group
        loc_i = torch.zeros((
            N_c,
            N_state**2,
        ))
        scale_i = torch.ones((
            N_c,
            N_state**2,
        ))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    # per-state logits of the 2-way missing/non-missing indicator
    step_zi_param = pyro.param("step_zi_param", lambda: torch.ones(
        (N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state, )).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state, )).abs(),
                           constraint=constraints.positive)

    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state, )).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc", lambda: torch.randn(
        (N_state, )).abs())

    # observation 3: dive activity (omega ~ Beta)
    # per-state logits of the 2-way missing/non-missing indicator
    omega_zi_param = pyro.param("omega_zi_param", lambda: torch.ones(
        (N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform: gamma is used below as Categorical logits,
    # so all-zero logits give equal transition probabilities
    gamma = torch.zeros((N_state**2, ))

    # same value as the N_c read above
    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            # pick the offset row selected by e_g
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            # no group-level effect
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        # NOTE(review): the individual mask presumably marks which
        # (individual, group) entries hold real data -- confirm with caller
        with pyro.plate(
                "individual", N_s,
                dim=-2), poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                # no individual-level effect
                eps_i = 0.

            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            # hidden state before the first timestep is fixed to index 0
            y = torch.tensor(0).long()

            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                # every site of timestep t is masked by the t-th slice of the
                # timestep mask
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(
                        tuple(gamma_t.shape[:-1]) + (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    # (the row of transition logits for the previous state y)
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    # the indicator is observed: it is set exactly where the
                    # step observation equals MISSING
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][...,
                                                            t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    # mask is passed first, then the Gamma, then the Delta at
                    # MISSING; the mask chooses between the two components
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)

                    pyro.sample("step_{}".format(t),
                                step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t),
                                angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    # the indicator is observed: it is set exactly where the
                    # omega observation equals MISSING
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][...,
                                                             t] == MISSING))

                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    # same construction as for the step size above
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask,
                                                       omega_dist,
                                                       omega_zi_zero_dist)

                    pyro.sample("omega_{}".format(t),
                                omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])