Ejemplo n.º 1
0
def model_3D(data):
    """Pyro model linking four backbone dihedral angles to observed 3D coordinates.

    Samples the dihedrals psi_1, phi_2, psi_2 and phi_3 from broad von Mises
    priors (loc 0, concentration 0.1). It packs them into a (3, 1, 3)
    dihedral tensor, converts that to Cartesian coordinates with ``pnerf``,
    and scores the x/y/z coordinates of atoms 3a..9a (indices 2..8) against
    ``data`` with a tight Normal likelihood (sd 0.001).

    Args:
        data: observed coordinates indexed as ``data[atom, :, axis]`` with
            axis 0/1/2 = x/y/z. It is assumed to share the layout of the
            ``pnerf`` output (at least 9 atoms); TODO confirm against caller.
    """
    # Sample the free dihedral angles.
    psi_1 = pyro.sample("psi_1", dist.VonMises(0, 0.1))
    phi_2 = pyro.sample("phi_2", dist.VonMises(0, 0.1))
    psi_2 = pyro.sample("psi_2", dist.VonMises(0, 0.1))
    phi_3 = pyro.sample("phi_3", dist.VonMises(0, 0.1))

    # Build the (3, 1, 3) dihedral tensor with torch.stack, not torch.tensor.
    # torch.tensor copies its inputs into a new leaf tensor and so detaches
    # the sampled angles from the autograd graph. That hides the likelihood's
    # dependence on them from gradient-based inference.
    zero = torch.zeros_like(psi_1)
    pi = torch.full_like(psi_1, np.pi)
    das = torch.stack([
        torch.stack([zero, psi_1, pi]),
        torch.stack([phi_2, psi_2, pi]),
        torch.stack([phi_3, zero, zero]),
    ]).unsqueeze(1)
    coords_3d = pnerf(das)

    # Observe x, then y, then z for atoms 3a..9a (indices 2..8). The site
    # names have the form "obs_3D_<atom>a_<axis>".
    with pyro.plate('3D_coords'):
        for axis_idx, axis in enumerate("xyz"):
            for atom_idx in range(2, 9):
                pyro.sample(
                    f"obs_3D_{atom_idx + 1}a_{axis}",
                    dist.Normal(coords_3d[atom_idx, :, axis_idx], 0.001),
                    obs=data[atom_idx, :, axis_idx],
                )
Ejemplo n.º 2
0
def guide_3D(data):
    """Mean-field von Mises variational family for the four dihedral angles.

    For each angle in (psi_1, phi_2, psi_2, phi_3) this registers:

    - a learnable location ``mu_<angle>``, initialised to 0;
    - a positive concentration ``kappa_<angle>``, initialised to 0.01.

    It then samples the angle from ``VonMises(mu, 10 + kappa)``. The offset
    of 10 keeps every concentration strictly above 10.

    ``data`` is not used here. It is presumably accepted so that the guide's
    signature matches the model's.
    """
    angle_names = ("psi_1", "phi_2", "psi_2", "phi_3")

    # Register all variational parameters first, then sample.
    variational = {}
    for angle in angle_names:
        loc = pyro.param(f"mu_{angle}", torch.tensor(0.))
        conc = pyro.param(f"kappa_{angle}", torch.tensor(0.01),
                          constraint=constraints.positive)
        variational[angle] = (loc, conc)

    for angle in angle_names:
        loc, conc = variational[angle]
        pyro.sample(angle, dist.VonMises(loc, 10 + conc))
Ejemplo n.º 3
0
def model(num_hidden=50):
    """Sketch of a GP-latent model that emits (phi, psi) dihedral observations.

    A latent ``zs`` is drawn from a sparse-GP prior and decoded by a GRU. Four
    linear heads then map the GRU output to von Mises locations (Sigmoid) and
    concentrations (Softplus) for phi and psi. Those parameterise the
    likelihoods of ``observed_phis`` and ``observed_psis``.

    This block depends on module-level names defined elsewhere: VSGP,
    OUKernel, drift, scale, inducing_set, input_data, observed_phis and
    observed_psis.

    Args:
        num_hidden: hidden size of the GRU and input size of the heads.
    """
    zs = pyro.sample('zs', VSGP(OUKernel(drift, scale), inducing_set,
                                dist.MultivariateNormal, input_data=input_data))
    decoder = torch.nn.GRU(30, num_hidden)
    pyro.module('decoder', decoder)
    # Group the output heads so a single pyro.module call registers all of
    # their weights with the param store. Unregistered heads would be freshly
    # re-initialised on every call.
    param_nn = torch.nn.ModuleDict({
        'mean1': torch.nn.Sequential(torch.nn.Linear(num_hidden, 1),
                                     torch.nn.Sigmoid()),
        'kappa1': torch.nn.Sequential(torch.nn.Linear(num_hidden, 1),
                                      torch.nn.Softplus()),
        'mean2': torch.nn.Sequential(torch.nn.Linear(num_hidden, 1),
                                     torch.nn.Sigmoid()),
        'kappa2': torch.nn.Sequential(torch.nn.Linear(num_hidden, 1),
                                      torch.nn.Softplus()),
    })
    pyro.module('param_nn', param_nn)
    # nn.GRU returns (output, h_n); only the per-step output feeds the heads.
    # NOTE(review): zs must have GRU input layout (seq, batch, 30) - confirm.
    decoded, _ = decoder(zs)  # Should autoregress on zs
    means1 = param_nn['mean1'](decoded)
    kappas1 = param_nn['kappa1'](decoded)
    means2 = param_nn['mean2'](decoded)
    kappas2 = param_nn['kappa2'](decoded)
    pyro.sample('phis', dist.VonMises(means1, kappas1), obs=observed_phis)
    pyro.sample('psis', dist.VonMises(means2, kappas2), obs=observed_psis)
Ejemplo n.º 4
0
 def model(self, x):
     """Generative model p(z) p(x | z) with a von Mises likelihood.

     Args:
         x: observed batch. The first dimension is the ``"data"`` plate and
             the remaining dimension is one event (``to_event(1)``).
             Presumably a (batch, features) tensor of angles; confirm.
     """
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     with pyro.plate("data", x.shape[0]):
         # standard-normal prior p(z): one z_dim vector per data point
         z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
         z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
         # sample from prior (value will be sampled by guide when computing the ELBO)
         z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
         # Decode z into von Mises location and concentration. Call the
         # module itself rather than .forward() so registered hooks run.
         loc, k = self.decoder(z)
         # score the observations x
         pyro.sample("obs", dist.VonMises(loc, k).to_event(1), obs=x)
Ejemplo n.º 5
0
def model_3DA(data):
    """Hierarchical model for four observed backbone dihedral angles.

    For each of psi_1, phi_2, psi_2 and phi_3:

    - the location ``mu_<angle>`` has a Uniform(-pi, pi) prior;
    - ``inv_kappa_<angle>`` has a HalfNormal(1) prior;
    - the concentration is ``100 + 1 / inv_kappa``, so it always exceeds 100.

    The observed angles are then scored with a von Mises likelihood.

    Args:
        data: indexed as ``data[residue, :, column]``. Column 0 is read as
            phi and column 1 as psi (see the table below). Presumably angles
            in radians; confirm against caller.
    """
    # (angle name, residue index, column index into data)
    angle_specs = (
        ("psi_1", 0, 1),
        ("phi_2", 1, 0),
        ("psi_2", 1, 1),
        ("phi_3", 2, 0),
    )

    # Draw location and concentration for every angle before observing.
    priors = {}
    for angle, _, _ in angle_specs:
        loc = pyro.sample(f"mu_{angle}", dist.Uniform(-np.pi, np.pi))
        inv_conc = pyro.sample(f"inv_kappa_{angle}", dist.HalfNormal(1.))
        priors[angle] = (loc, 100 + 1 / inv_conc)

    # Observations are conditionally independent given the parameters.
    with pyro.plate('dihedral_angles'):
        for angle, residue, column in angle_specs:
            loc, conc = priors[angle]
            pyro.sample(f"obs_{angle}", dist.VonMises(loc, conc),
                        obs=data[residue, :, column])
Ejemplo n.º 6
0
 def __init__(self, von_loc, von_conc, skewness):
     """Build the parent distribution from a von Mises base and a skewness.

     Every dimension of ``von_loc`` is treated as an event dimension of the
     base distribution. NOTE(review): the parent class is not visible here
     (presumably a sine-skewed wrapper); confirm.
     """
     event_dims = von_loc.ndim
     base = dist.VonMises(von_loc, von_conc)
     super().__init__(base.to_event(event_dims), skewness)
Ejemplo n.º 7
0
def random_pose_transform(transforms: transforms.TransformSequence,
                          device=torch.device("cpu")):
    """
    Sample random coordinates for every transform in a sequence and build the pose.

    For each transform, a "<ClassName>_u" and/or "<ClassName>_v" site is
    sampled, depending on ``has_u`` / ``has_v``. The prior is chosen as follows:

    - periodic coordinates get a very diffuse VonMises(0, 0.01) prior;
    - all others get a standard Normal prior.

    Returns the inverse transformation built from those parameters, which can
    be applied to an input image. For an empty sequence it returns an identity
    ProjectiveGridTransform.

    NOTE(review): site names come only from the transform's class name, so a
    sequence containing two transforms of the same class would reuse a site
    name; confirm that this never happens.
    TODO: atm these classes have a cnn attached; this doesn't feel like a super clean abstraction, but we'll see...
    """
    def sample_coordinate(site_name, periodic):
        # Draw one coordinate with the prior that matches its topology.
        loc = torch.tensor(0.0, device=device)
        if periodic:
            prior = D.VonMises(loc, torch.tensor(1e-2, device=device))
        else:
            prior = D.Normal(loc, torch.tensor(1.0, device=device))
        return pyro.sample(site_name, prior)

    params = []
    for transform in transforms:
        # e.g. "<class 'pkg.mod.Rotation'>" -> "Rotation"
        prefix = repr(type(transform)).split(".")[-1].strip("'>")
        coords = []
        if transform.has_u:
            coords.append(sample_coordinate(prefix + "_u", transform.periodic_u))
        if transform.has_v:
            coords.append(sample_coordinate(prefix + "_v", transform.periodic_v))
        params.append(coords)

    if not params:
        # Edge case: an empty list of transformations maps to the identity.
        return ProjectiveGridTransform(torch.eye(3, device=device))
    # Take the inverse transformation because the spatial transformer network
    # uses the opposite convention to the one that makes sense from a
    # generative perspective.
    return transforms.inverse_transform_from_params(params)
Ejemplo n.º 8
0
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    """Hidden Markov model over sequences of (phi, psi) dihedral angle pairs.

    Each hidden state in ``0..num_states-1`` has its own von Mises location
    and concentration for phi and for psi. Per-sequence lengths are also
    modelled, and are observed when ``lengths`` is given.

    Args:
        phis, psis: optional observed angles, indexed ``[sequence, t]``
            (presumably shape (num_sequences, max_length), in radians;
            confirm against caller). When None, the sites are sampled.
        lengths: optional per-sequence lengths. If given, ``num_sequences``
            must be None and is taken from ``lengths.shape[0]``.
        num_sequences: number of sequences; required when ``lengths`` is None.
        num_states: number of hidden states.
        prior_conc, prior_loc: von Mises prior on the per-state locations.
        prior_length_shape, prior_length_rate: scales of the HalfCauchy
            priors on the length model's shape and rate.
        prior_kappa_min, prior_kappa_max: bounds of the Uniform prior on the
            per-state concentrations.
    """
    # From https://pyro.ai/examples/hmm.html
    # Exactly one of `lengths` / `num_sequences` must be supplied.
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    # One transition row per state; each row has a symmetric Dirichlet prior
    # whose concentration equals num_states.
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    # Hyperparameters of the GammaPoisson sequence-length model.
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    # Per-state emission parameters: von Mises locations and Uniform-bounded
    # concentrations for phi and psi.
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    # Size-1 plate on dim -1 so the per-timestep emissions carry an explicit
    # trailing batch dim next to the sequence plate on dim -2.
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        # Every chain starts from state 0, so step 0 uses transition_probs[0].
        state = 0
        # AffineTransform(0., 1.) is the identity map. NOTE(review): the
        # reason for wrapping GammaPoisson this way is not visible here; confirm.
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(
                                          length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        # pyro.markov marks the chain structure for enumeration (as in the
        # Pyro HMM example referenced above).
        for t in pyro.markov(range(lengths.max())):
            # Timesteps past a sequence's own length are masked out.
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                # Discrete hidden state, enumerated in parallel; this
                # presumably requires an enumeration-aware ELBO (e.g.
                # TraceEnum_ELBO).
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                # Emit (phi, psi) from the current state's von Mises params.
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)
Ejemplo n.º 9
0
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model.

    Hidden-state transitions are driven by logits ``gamma``: a flattened
    N_state x N_state matrix, optionally shifted by a group-level and an
    individual-level random effect. Each timestep emits three observations:

    - step size: zero-inflated Gamma;
    - step angle: von Mises;
    - dive activity: zero-inflated Beta.

    Args:
        config: dict with at least the following keys (all read below):

            - ``"MISSING"``: sentinel value marking missing step/omega entries.
            - ``"sizes"``: dict with ``"random"``, ``"state"``, ``"group"``,
              ``"individual"`` and ``"timesteps"``.
            - ``"group"``: dict with ``"random"``, one of ``"discrete"`` or
              ``"continuous"``; any other value disables the group effect.
            - ``"individual"``: dict with ``"random"`` (same options) and
              ``"mask"``.
            - ``"timestep"``: dict with ``"mask"``, indexed ``[..., t]``.
            - ``"observations"``: dict with ``"step"``, ``"angle"`` and
              ``"omega"``, each indexed ``[..., t]``.
    """

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    # discrete: N_v learnable effect vectors chosen by a categorical;
    # continuous: standard-normal prior over the N_state**2 logits.
    if config["group"]["random"] == "discrete":
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v, )).abs(),
                               constraint=constraints.simplex)
        theta_g = pyro.param("theta_group", lambda: torch.randn(
            (N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        loc_g = torch.zeros((N_state**2, ))
        scale_g = torch.ones((N_state**2, ))

    # initialize individual-level random effect parameters
    # (one set per group, hence the leading N_c dimension)
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((
                                   N_c,
                                   N_v,
                               )).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual", lambda: torch.randn(
            (N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        loc_i = torch.zeros((
            N_c,
            N_state**2,
        ))
        scale_i = torch.ones((
            N_c,
            N_state**2,
        ))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    # step_zi_param holds per-state logits for the zero-inflation indicator.
    step_zi_param = pyro.param("step_zi_param", lambda: torch.ones(
        (N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state, )).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state, )).abs(),
                           constraint=constraints.positive)

    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state, )).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc", lambda: torch.randn(
        (N_state, )).abs())

    # observation 3: dive activity (omega ~ Beta)
    omega_zi_param = pyro.param("omega_zi_param", lambda: torch.ones(
        (N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform
    # (all-zero logits give a uniform Categorical over next states)
    gamma = torch.zeros((N_state**2, ))

    # NOTE: same value as the N_c computed above; re-read here unchanged.
    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        with pyro.plate(
                "individual", N_s,
                dim=-2), poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                eps_i = 0.

            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            # Initial hidden state is fixed to 0.
            y = torch.tensor(0).long()

            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(
                        tuple(gamma_t.shape[:-1]) + (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    # (the row of transition logits for the previous state y)
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    # The indicator is observed as (step == MISSING). Where
                    # it is True the Delta(MISSING) component scores the
                    # entry; elsewhere the Gamma does.
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][...,
                                                            t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)

                    pyro.sample("step_{}".format(t),
                                step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t),
                                angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    # (same scheme as the step size: MISSING entries go to
                    # the Delta component)
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][...,
                                                             t] == MISSING))

                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask,
                                                       omega_dist,
                                                       omega_zi_zero_dist)

                    pyro.sample("omega_{}".format(t),
                                omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])