示例#1
0
    def __init__(self, rate, *, gate=None, gate_logits=None, validate_args=None):
        """Build a zero-inflated Poisson from a Poisson rate and a zero gate.

        How ``gate`` and ``gate_logits`` are reconciled and validated is left to
        the parent class's ``__init__`` (presumably exactly one should be
        supplied — confirm against the base class).

        :param rate: rate forwarded to the inner ``Poisson``.
        :param gate: probability of an extra zero (keyword-only).
        :param gate_logits: logit-space alternative to ``gate`` (keyword-only).
        :param validate_args: validation flag for this distribution; it is also
            copied onto the inner Poisson after that one is constructed.
        """
        # The inner Poisson is created with validation disabled, so its
        # constructor skips argument checks. Afterwards the caller's flag is
        # stored on it, presumably so later checks follow ``validate_args``.
        poisson = Poisson(rate=rate, validate_args=False)
        poisson._validate_args = validate_args

        super().__init__(
            poisson,
            gate=gate,
            gate_logits=gate_logits,
            validate_args=validate_args,
        )
示例#2
0
def test_zip_0_gate(rate):
    """A zero-inflated Poisson whose gate is zero must score like a plain Poisson."""
    zero_gate = torch.zeros(1)
    zip_ = ZeroInflatedPoisson(zero_gate, torch.tensor(rate))
    pois = Poisson(torch.tensor(rate))

    draws = pois.sample((20,))
    assert_tensors_equal(zip_.log_prob(draws), pois.log_prob(draws))
示例#3
0
 def step(self, state, branch, ρ=1.0):
     """Weight one tree branch using hidden-speciation counts (vectorized form).

     The branch length is ``Δ = t_beg - t_end``. A root branch (no
     ``parent_id``) shorter than 1e-5 is skipped entirely. Otherwise:

     * the hidden speciation count is sampled from ``Poisson(λ * Δ)``;
     * ``vec_survives`` turns those counts (as numpy arrays) into a factor;
     * zero extinction events are observed under ``Poisson(μ * Δ)``;
     * a branch with children observes ``Exponential(λ)`` at 1e-40, and a
       leaf observes ``Bernoulli(ρ)`` = 1.

     :param state: mapping with per-particle rates under ``"λ"`` and ``"μ"``.
     :param branch: mapping with ``id``, ``parent_id``, ``t_beg``, ``t_end``
         and ``has_children``.
     :param ρ: sampling probability.
     """
     Δ = branch["t_beg"] - branch["t_end"]
     if branch['parent_id'] is None and Δ < 1e-5:
         return

     bid = branch['id']
     hidden_counts = sample(f"count_hs_{bid}", Poisson(state["λ"] * Δ))
     hs_factor = vec_survives(
         branch["t_end"],
         branch["t_beg"],
         hidden_counts.numpy(),
         state["λ"].numpy(),
         state["μ"].numpy(),
         ρ,
     )
     factor(f"factor_hs_{bid}", hs_factor)
     sample(f"num_ex_{bid}", Poisson(state["μ"] * Δ), obs=tensor(0))

     if not branch["has_children"]:
         sample(f"obs_{bid}", Bernoulli(ρ), obs=tensor(1.))
         return
     sample(f"spec_{bid}", Exponential(state["λ"]), obs=tensor(1e-40))
示例#4
0
def test_zip_0_gate(rate):
    """A ZIP whose gate is (numerically) zero must score like a plain Poisson.

    Both ways of specifying the gate are checked: an explicit ``gate`` of zero,
    and ``gate_logits=-99.9``. If the gate is ``sigmoid(gate_logits)``, the
    latter is effectively zero, hence ``assert_close`` rather than exact
    equality.
    """
    candidates = (
        ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1)),
        ZeroInflatedPoisson(torch.tensor(rate),
                            gate_logits=torch.tensor(-99.9)),
    )
    pois = Poisson(torch.tensor(rate))

    draws = pois.sample((20,))
    zip_log_probs = [candidate.log_prob(draws) for candidate in candidates]
    expected = pois.log_prob(draws)
    for actual in zip_log_probs:
        assert_close(actual, expected)
示例#5
0
    def model(self, x):
        """Generative model: stacked Gamma latents with a Poisson likelihood.

        Global weights ``w_top`` (``top_width * mid_width`` entries), ``w_mid``
        (``mid_width * bottom_width``) and ``w_bottom``
        (``bottom_width * image_size``) get elementwise
        ``Gamma(alpha_w, beta_w)`` priors.

        For each of the ``x.size(0)`` rows:

        * ``z_top`` is ``Gamma(alpha_z, beta_z)`` expanded to ``top_width``;
        * ``z_mid`` and ``z_bottom`` are Gamma with rate
          ``beta_z / (layer_above @ weights)``;
        * ``x`` is observed as Poisson with rate ``z_bottom @ w_bottom``.

        Written against the older ``pyro.iarange`` / ``.independent`` API.
        """
        def global_weights(site, plate_name, count):
            # Elementwise Gamma prior over a flat vector of ``count`` weights.
            with pyro.iarange(plate_name, count):
                return pyro.sample(site, Gamma(self.alpha_w, self.beta_w))

        num_rows = x.size(0)

        w_top = global_weights("w_top", "w_top_iarange",
                               self.top_width * self.mid_width)
        w_mid = global_weights("w_mid", "w_mid_iarange",
                               self.mid_width * self.bottom_width)
        w_bottom = global_weights("w_bottom", "w_bottom_iarange",
                                  self.bottom_width * self.image_size)

        # The iarange declares the per-row latents conditionally independent.
        with pyro.iarange("data", num_rows):
            top_prior = Gamma(self.alpha_z,
                              self.beta_z).expand([self.top_width])
            z_top = pyro.sample("z_top", top_prior.independent(1))

            mid_scale = torch.mm(
                z_top, w_top.reshape(self.top_width, self.mid_width))
            z_mid = pyro.sample(
                "z_mid",
                Gamma(self.alpha_z, self.beta_z / mid_scale).independent(1))

            bottom_scale = torch.mm(
                z_mid, w_mid.view(self.mid_width, self.bottom_width))
            z_bottom = pyro.sample(
                "z_bottom",
                Gamma(self.alpha_z, self.beta_z / bottom_scale).independent(1))

            obs_rate = torch.mm(
                z_bottom, w_bottom.view(self.bottom_width, self.image_size))

            # Poisson likelihood over each row of pixels.
            pyro.sample('obs', Poisson(obs_rate).independent(1), obs=x)
def model(x, top_width=100, mid_width=40, bottom_width=15, image_size=4096,
          num_data=320):
    """Three-layer Gamma/Poisson generative model over flattened images.

    Global weight matrices of shapes ``(top_width, mid_width)``,
    ``(mid_width, bottom_width)`` and ``(bottom_width, image_size)`` are drawn
    as flat vectors with elementwise ``Gamma(alpha_w, beta_w)`` priors.

    For each of the ``num_data`` rows:

    * ``z_top`` is ``Gamma(alpha_z, beta_z)`` expanded to ``top_width``;
    * ``z_mid`` and ``z_bottom`` are Gamma with rate
      ``beta_z / (layer_above @ weights)``;
    * the row of ``x`` is observed as Poisson with rate
      ``z_bottom @ w_bottom``.

    The sizes are keyword arguments. Their defaults (100, 40, 15, 4096, 320)
    reproduce the previously hard-coded values, so the plate sizes stay
    consistent with the reshapes.

    Relies on module-level hyperparameters ``alpha_w``, ``beta_w``,
    ``alpha_z`` and ``beta_z``.

    :param x: observed counts; reshaped to ``(num_data, image_size)``.
    :param top_width: width of the top latent layer.
    :param mid_width: width of the middle latent layer.
    :param bottom_width: width of the bottom latent layer.
    :param image_size: number of observed values per row.
    :param num_data: number of rows in ``x``.
    """
    x = torch.reshape(x, [num_data, image_size])

    # Global weights, sampled flat and reshaped into matrices below.
    with pyro.plate("w_top_plate", top_width * mid_width):
        w_top = pyro.sample("w_top", Gamma(alpha_w, beta_w))
    with pyro.plate("w_mid_plate", mid_width * bottom_width):
        w_mid = pyro.sample("w_mid", Gamma(alpha_w, beta_w))
    with pyro.plate("w_bottom_plate", bottom_width * image_size):
        w_bottom = pyro.sample("w_bottom", Gamma(alpha_w, beta_w))

    with pyro.plate("data", num_data):
        z_top = pyro.sample(
            "z_top",
            Gamma(alpha_z, beta_z).expand_by([top_width]).to_event(1))

        w_top = torch.reshape(w_top, [top_width, mid_width])
        mean_mid = torch.matmul(z_top, w_top)
        z_mid = pyro.sample("z_mid",
                            Gamma(alpha_z, beta_z / mean_mid).to_event(1))

        w_mid = torch.reshape(w_mid, [mid_width, bottom_width])
        mean_bottom = torch.matmul(z_mid, w_mid)
        z_bottom = pyro.sample(
            "z_bottom",
            Gamma(alpha_z, beta_z / mean_bottom).to_event(1))

        w_bottom = torch.reshape(w_bottom, [bottom_width, image_size])
        mean_obs = torch.matmul(z_bottom, w_bottom)

        pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x)
    def model(self, x):
        """Generative model: stacked Gamma latents with a Poisson likelihood.

        Global weight vectors ``w_top``, ``w_mid`` and ``w_bottom`` get
        elementwise ``Gamma(alpha_w, beta_w)`` priors.

        For each of the ``x.size(0)`` rows:

        * ``z_top`` is ``Gamma(alpha_z, beta_z)`` expanded to ``top_width``;
        * ``z_mid`` and ``z_bottom`` are Gamma with rate
          ``beta_z / (layer_above @ weights)``;
        * ``x`` is observed as Poisson with rate ``z_bottom @ w_bottom``.

        Weight reshaping tolerates extra leading dimensions so the matrix
        products stay vectorized.
        """
        def as_matrix(flat, rows, cols):
            # A 1-D tensor is a single weight draw. Anything else is reshaped
            # with a leading batch dimension so ``torch.matmul`` broadcasts over
            # it (presumably vectorized particles — confirm with the caller).
            if flat.dim() == 1:
                return flat.reshape(rows, cols)
            return flat.reshape(-1, rows, cols)

        weight_specs = (
            ("w_top", "w_top_plate", self.top_width * self.mid_width),
            ("w_mid", "w_mid_plate", self.mid_width * self.bottom_width),
            ("w_bottom", "w_bottom_plate", self.bottom_width * self.image_size),
        )
        weights = {}
        for site, plate_name, count in weight_specs:
            with pyro.plate(plate_name, count):
                weights[site] = pyro.sample(site, Gamma(self.alpha_w, self.beta_w))

        # The plate declares the per-row latents conditionally independent.
        with pyro.plate("data", x.size(0)):
            z_top = pyro.sample(
                "z_top",
                Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1))

            mid_scale = torch.matmul(
                z_top, as_matrix(weights["w_top"], self.top_width, self.mid_width))
            z_mid = pyro.sample(
                "z_mid", Gamma(self.alpha_z, self.beta_z / mid_scale).to_event(1))

            bottom_scale = torch.matmul(
                z_mid, as_matrix(weights["w_mid"], self.mid_width, self.bottom_width))
            z_bottom = pyro.sample(
                "z_bottom", Gamma(self.alpha_z, self.beta_z / bottom_scale).to_event(1))

            obs_rate = torch.matmul(
                z_bottom,
                as_matrix(weights["w_bottom"], self.bottom_width, self.image_size))

            # Poisson likelihood over each row of pixels.
            pyro.sample('obs', Poisson(obs_rate).to_event(1), obs=x)
示例#8
0
 def model():
     """Gamma prior on a Poisson rate, with ``self.data`` observed via ``map_data``.

     ``self`` comes from the enclosing scope. This uses the legacy
     ``pyro.observe`` / ``pyro.map_data`` API with ``batch_size=3``.

     :returns: the sampled rate ``lambda_latent``.
     """
     rate = pyro.sample("lambda_latent",
                        Gamma(self.alpha0, self.beta0))
     likelihood = Poisson(rate)

     def observe_point(i, x):
         # ``i`` is the index supplied by map_data and is unused. Every point
         # is observed under the same site name "obs".
         return pyro.observe("obs", likelihood, x)

     pyro.map_data(self.data, observe_point, batch_size=3)
     return rate
示例#9
0
 def step(self, state, branch, ρ=1.0):
     """Weight one tree branch, simulating hidden speciations per particle.

     The branch length is ``Δ = t_beg - t_end``. A root branch (no
     ``parent_id``) of exactly zero length is skipped.

     Otherwise the hidden speciation count is sampled from ``Poisson(λ * Δ)``.
     For every particle ``n``, each hidden event is placed uniformly on the
     branch and ``self.survives`` is asked about that side lineage:

     * the first one that survives sets the particle's factor to ``-inf``
       and stops;
     * every one that does not survive adds ``log 2``.

     Then zero extinction events are observed under ``Poisson(μ * Δ)``.
     Finally, a branch with children observes ``Exponential(λ)`` at 1e-40,
     and a leaf observes ``Bernoulli(ρ)`` = 1.

     :param state: per-particle state with rates under ``"λ"`` / ``"μ"`` and a
         ``_num_particles`` count.
     :param branch: mapping with ``id``, ``parent_id``, ``t_beg``, ``t_end``
         and ``has_children``.
     :param ρ: sampling probability.
     """
     Δ = branch["t_beg"] - branch["t_end"]
     if branch['parent_id'] is None and Δ == 0:
         return

     bid = branch['id']
     hidden_counts = sample(f"count_hs_{bid}", Poisson(state["λ"] * Δ))

     n_particles = state._num_particles
     hs_factor = zeros(n_particles)
     for n in range(n_particles):
         for _ in range(int(hidden_counts[n])):
             t = Uniform(branch["t_end"], branch["t_beg"]).sample()
             if self.survives(t, state["λ"][n], state["μ"][n], ρ):
                 hs_factor[n] = -float('inf')
                 break
             hs_factor[n] += log(tensor(2))
     factor(f"factor_hs_{bid}", hs_factor)

     sample(f"num_ex_{bid}", Poisson(state["μ"] * Δ), obs=tensor(0))
     if not branch["has_children"]:
         sample(f"obs_{bid}", Bernoulli(ρ), obs=tensor(1.))
         return
     sample(f"spec_{bid}", Exponential(state["λ"]), obs=tensor(1e-40))
    def sample(self, sample_shape=torch.Size()):
        """Draw counts from a Gamma–Poisson mixture.

        Rates are drawn from ``self._gamma()`` with ``sample_shape`` and capped
        at 1e8. Each capped rate then drives one ``Poisson`` count draw.
        """
        rates = self._gamma().sample(sample_shape)
        # Cap the rates: distribution objects can misbehave when their
        # parameters are too high.
        capped = torch.clamp(rates, max=1e8)
        # Presumably shaped (n_samples, n_cells_batch, n_genes) — confirm.
        return Poisson(capped).sample()
示例#11
0
def train_model(data, n_steps, pmf):
    """Fit a point estimate of a Poisson rate to ``data`` with SVI.

    Model: the rate ``latent`` has a ``Uniform(low, high)`` prior, with
    ``low, high = 2.8 -/+ 4 * 0.3``. Each element of ``data`` is observed as
    ``Poisson(latent)`` at its own site ``obs_{i}``.

    Guide: a ``Delta`` at the learnable parameter ``lam``. ``lam`` is
    constrained to the open interval ``(low, high)`` so it always lies inside
    the prior's support; with ``Trace_ELBO`` this drives ``lam`` towards the
    MAP rate.

    :param data: sequence (or tensor) of observed counts.
    :param n_steps: number of SVI steps to run.
    :param pmf: unused; kept for interface compatibility.
    :returns: the learned rate ``lam`` as a Python float.
    """
    mu = 2.8
    num_sigmas = 4
    sigma = 0.3
    low = mu - num_sigmas * sigma
    high = mu + num_sigmas * sigma

    def model(data):
        f = pyro.sample("latent", dist.Uniform(low, high))
        logging.debug("latent rate sample: %s", f)
        for i, x in enumerate(data):
            pyro.sample("obs_{}".format(i), Poisson(f), obs=x)

    def guide(data):
        # Outside (low, high) the Uniform prior has zero density, so the
        # parameter is constrained to that interval. The guide must actually
        # use ``lam`` for it to receive gradients; a Delta gives a
        # reparameterized point estimate of the continuous rate.
        lam = pyro.param("lam", torch.tensor(2.0),
                         constraint=constraints.interval(low, high))
        pyro.sample("latent", dist.Delta(lam))

    adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
    optimizer = ClippedAdam(adam_params)

    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    for step in range(n_steps):
        loss = svi.step(data)
        if step % 100 == 0:
            logging.info(".")
            logging.info("Elbo loss: {}".format(loss))

    # Grab the learned variational parameter.
    lam = pyro.param("lam").item()
    logging.info("Learned rate: %s", lam)
    posterior = Poisson(lam)
    logging.info("Sampling:{}".format(posterior.sample()))
    return lam
示例#12
0
 def survives(self, t, λ, μ, ρ):
     """Simulate whether a lineage starting at time ``t`` leaves a sampled descendant.

     Time appears to count down towards 0. The lineage ends at
     ``t - Exponential(μ)``. If that end is at or before 0, the lineage is
     sampled with probability ``ρ`` (returning True); if it is not sampled,
     its end is clamped to 0.

     Next, ``Poisson(λ * (t - end))`` daughter lineages start uniformly on
     ``(end, t)``. The result is True as soon as one of them recursively
     survives, and False otherwise.
     """
     end = t - Exponential(μ).sample()
     if end <= 0:
         if Bernoulli(ρ).sample():
             return True
         end = 0
     n_daughters = int(Poisson(λ * (t - end)).sample())
     # ``any`` short-circuits, so sampling stops at the first survivor.
     return any(
         self.survives(Uniform(end, t).sample(), λ, μ, ρ)
         for _ in range(n_daughters)
     )
示例#13
0
 def model():
     """Gamma–Poisson model whose Gamma parameters are learned in log space.

     ``alpha_p_log`` and ``beta_p_log`` are registered as Pyro params,
     initialised from ``self.alpha_p_log_0`` / ``self.beta_p_log_0``, and
     exponentiated to form the Gamma prior on ``lambda_latent``. ``self.data``
     is then observed under ``Poisson(lambda_latent)`` with the legacy
     ``pyro.observe`` API.

     :returns: the sampled rate ``lambda_latent``.
     """
     log_alpha = pyro.param(
         "alpha_p_log",
         Variable(self.alpha_p_log_0, requires_grad=True))
     log_beta = pyro.param(
         "beta_p_log",
         Variable(self.beta_p_log_0, requires_grad=True))
     rate = pyro.sample(
         "lambda_latent",
         Gamma(torch.exp(log_alpha), torch.exp(log_beta)))
     pyro.observe("obs", Poisson(rate), self.data)
     return rate
示例#14
0
    def model(data):
        """Uniform prior on a Poisson rate; each data point is observed separately.

        The rate ``latent`` is drawn from ``Uniform(mu - 4*sigma, mu + 4*sigma)``
        with ``mu = 2.8`` and ``sigma = 0.3``. ``data[i]`` is observed as
        ``Poisson(latent)`` at site ``obs_{i}``.
        """
        mu = 2.8
        num_sigmas = 4
        sigma = 0.3
        low = mu - num_sigmas * sigma
        high = mu + num_sigmas * sigma

        # Sample the rate from the prior (no debug printing on every execution).
        f = pyro.sample("latent", dist.Uniform(low, high))
        for i, x in enumerate(data):
            pyro.sample("obs_{}".format(i), Poisson(f), obs=x)
def model(n_ice, n_obs, floe_size, cover_subp):
    """Hierarchical count model for ice floes and their detections.

    Six scalar coefficients with ``Normal(1, 1)`` / ``Normal(0, 1)`` priors
    define, through ``sigmoid`` of linear functions of the covariates:

    * ``lambda_ice``: the Poisson rate of ``N_ice`` (from ``floe_size``);
    * ``alpha_det`` / ``beta_det``: Beta parameters of the detection
      probability ``phi_det`` (from ``cover_subp``).

    Inside the ``subp`` plate:

    * ``N_ice`` is observed as ``n_ice``;
    * ``phi_det`` is zeroed wherever ``N_ice`` is 0;
    * ``n_obs`` is observed under ``Binomial(N_ice, phi_det)``.

    NOTE(review): ``sigmoid`` bounds the Poisson rate ``lambda_ice`` to (0, 1);
    confirm that is intended rather than e.g. ``exp``.
    """
    def coef(name, loc):
        return pyro.sample(name, dist.Normal(loc, 1.))

    a_floe, b_floe = coef("a_floe", 1.), coef("b_floe", 0.)
    a_cover, b_cover = coef("a_cover", 1.), coef("b_cover", 0.)
    a_cover_b, b_cover_b = coef("a_cover_b", 1.), coef("b_cover_b", 0.)

    lambda_ice = sigmoid(a_floe * floe_size + b_floe)
    alpha_det = sigmoid(a_cover * cover_subp + b_cover)
    beta_det = sigmoid(a_cover_b * cover_subp + b_cover_b)

    with pyro.plate('subp', size=len(floe_size)):
        N_ice = pyro.sample('N_ice', Poisson(lambda_ice), obs=n_ice)
        detection = pyro.sample('phi_det', Beta(alpha_det, beta_det))
        # No ice means nothing can be detected.
        phi_det = detection * (N_ice > 0).float()
        pyro.sample('N_obs', Binomial(N_ice, phi_det), obs=n_obs)
示例#16
0
    def __init__(self, gate, rate, validate_args=None):
        """Zero-inflated Poisson built from a gate and a Poisson rate.

        :param gate: zero-inflation gate, passed positionally to the parent.
        :param rate: rate of the inner ``Poisson``.
        :param validate_args: forwarded unchanged to both the inner Poisson
            and the parent constructor.
        """
        poisson = Poisson(rate=rate, validate_args=validate_args)
        super().__init__(gate, poisson, validate_args=validate_args)
示例#17
0
    def __init__(self, gate, rate, validate_args=None):
        """Zero-inflated Poisson built from a gate and a Poisson rate.

        :param gate: zero-inflation gate, passed positionally to the parent.
        :param rate: rate of the inner ``Poisson``.
        :param validate_args: validation flag for this distribution; it is also
            copied onto the inner Poisson after that one is constructed.
        """
        # The inner Poisson is created with validation disabled, so its
        # constructor skips argument checks. The caller's flag is then stored
        # on it, presumably for any later checks.
        poisson = Poisson(rate=rate, validate_args=False)
        poisson._validate_args = validate_args
        super().__init__(gate, poisson, validate_args=validate_args)