def pyro_model(x, y):
    priors = {
        'covar_module.base_kernel.raw_lengthscale': Normal(0, 2).expand([1, 1]),
        'covar_module.raw_outputscale': Normal(0, 2),
        'likelihood.noise_covar.raw_noise': Normal(0, 2).expand([1]),
        'mean_module.constant': Normal(0, 2),
    }
    fn = pyro.random_module("model", model, prior=priors)
    sampled_model = fn()

    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)
Example #2
        def model():
            p = torch.tensor([0.5])
            loc = torch.zeros(1)
            scale = torch.ones(1)

            x = pyro.sample("x",
                            Normal(loc,
                                   scale))  # Before the discrete variable.
            y = pyro.sample("y", Bernoulli(p))
            z = pyro.sample("z", Normal(loc,
                                        scale))  # After the discrete variable.
            return dict(x=x, y=y, z=z)
Example #3
def bayes_logistic(X, y):
    n, k = X.shape
    w_prior = Normal(torch.zeros(1, k), torch.ones(1, k)).to_event(1)
    b_prior = Normal(torch.tensor([[0.]]), torch.tensor([[10.]])).to_event(1)
    priors = {"linear.weight": w_prior, "linear.bias": b_prior}
    lifted_module = \
        pyro.random_module("bayes_logistic", frequentist_model, priors)
    lifted_model = lifted_module()
    with pyro.plate("customers", n):
        y_pred = lifted_model(X).squeeze(1)
        pyro.sample("obs", Bernoulli(y_pred, validate_args=True), obs=y)
        return y_pred
Example #4
def model(is_cont_africa, ruggedness, log_gdp):
    # WL: edited. =====
    # a = pyro.sample("a", Normal(8., 1000.))
    a = pyro.sample("a", Normal(0., 10.))
    # =================
    b_a = pyro.sample("bA", Normal(0., 1.))
    b_r = pyro.sample("bR", Normal(0., 1.))
    b_ar = pyro.sample("bAR", Normal(0., 1.))
    sigma = pyro.sample("sigma", Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", 170):
        pyro.sample("obs", Normal(mean, sigma), obs=log_gdp)
Example #5
def guide(data):

    w_mu = Variable(torch.randn(second_layer, first_layer).type_as(data.data),
                    requires_grad=True)
    w_log_sig = Variable(
        0.1 * torch.ones(second_layer, first_layer).type_as(data.data),
        requires_grad=True)
    b_mu = Variable(torch.randn(second_layer).type_as(data.data),
                    requires_grad=True)
    b_log_sig = Variable(0.1 * torch.ones(second_layer).type_as(data.data),
                         requires_grad=True)

    # register learnable params in the param store
    mw_param = pyro.param("guide_mean_weight", w_mu)
    sw_param = softplus(pyro.param("guide_log_sigma_weight", w_log_sig))
    mb_param = pyro.param("guide_mean_bias", b_mu)
    sb_param = softplus(pyro.param("guide_log_sigma_bias", b_log_sig))

    # gaussian guide distributions for w and b
    w_dist = Normal(mw_param, sw_param)
    b_dist = Normal(mb_param, sb_param)

    w_mu2 = Variable(torch.randn(1, second_layer).type_as(data.data),
                     requires_grad=True)
    w_log_sig2 = Variable(0.1 *
                          torch.randn(1, second_layer).type_as(data.data),
                          requires_grad=True)
    b_mu2 = Variable(torch.randn(1).type_as(data.data), requires_grad=True)
    b_log_sig2 = Variable(0.1 * torch.ones(1).type_as(data.data),
                          requires_grad=True)

    # register learnable params in the param store
    mw_param2 = pyro.param("guide_mean_weight2", w_mu2)
    sw_param2 = softplus(pyro.param("guide_log_sigma_weight2", w_log_sig2))
    mb_param2 = pyro.param("guide_mean_bias2", b_mu2)
    sb_param2 = softplus(pyro.param("guide_log_sigma_bias2", b_log_sig2))

    # gaussian guide distributions for w and b
    w_dist2 = Normal(mw_param2, sw_param2)
    b_dist2 = Normal(mb_param2, sb_param2)

    dists = {
        'hidden.weight': w_dist,
        'hidden.bias': b_dist,
        'predict.weight': w_dist2,
        'predict.bias': b_dist2
    }

    # overloading the parameters in the module with random samples from the guide distributions
    lifted_module = pyro.random_module("module", regression_model, dists)
    # sample a regressor
    return lifted_module()
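
One possible way to use this guide after SVI training (a sketch under the assumption that the guide above has been optimised and x_test is a test batch): draw several regressors from the lifted module and average their predictions.

# Hypothetical prediction sketch: Monte Carlo average over regressors sampled from the guide.
# guide and x_test are assumptions; nothing here comes from the original source.
import torch

def predict(x_test, num_samples=10):
    sampled_models = [guide(x_test) for _ in range(num_samples)]
    preds = torch.stack([m(x_test).detach() for m in sampled_models])
    return preds.mean(dim=0)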
Example #6
    def generate(self, num_to_sample: int = 1):
        """Generate samples from prior."""
        cuda_device = self._get_prediction_device()
        prior_mean = nn_util.move_to_device(
            torch.zeros((num_to_sample, self._latent_dim)),
            cuda_device,
        )
        prior_stddev = torch.ones_like(prior_mean)
        prior = Normal(prior_mean, prior_stddev)
        latent = prior.sample()
        generated = self._decoder.generate(latent)

        return self.make_output_human_readable(generated)
Example #7
        def guide():
            mu1 = pyro.param("mu1", Variable(torch.randn(2),
                                             requires_grad=True))
            sigma1 = pyro.param("sigma1",
                                Variable(torch.ones(2), requires_grad=True))
            pyro.sample("latent1", Normal(mu1, sigma1))

            mu2 = pyro.param("mu2", Variable(torch.randn(2),
                                             requires_grad=True))
            sigma2 = pyro.param("sigma2",
                                Variable(torch.ones(2), requires_grad=True))
            latent2 = pyro.sample("latent2", Normal(mu2, sigma2))
            return latent2
Example #8
    def __init__(self,
                 size_in,
                 prior_factor=1.0,
                 weight_prior_std=1.0,
                 bias_prior_std=3.0,
                 **kwargs):

        self._params = OrderedDict()
        self._param_dists = OrderedDict()

        self.prior_factor = prior_factor
        self.gp = VectorizedGP(size_in, **kwargs)

        for name, shape in self.gp.parameter_shapes().items():

            if name == 'constant_mean':
                mean_p_loc = torch.zeros(1).to(device)
                mean_p_scale = torch.ones(1).to(device)
                self._param_dist(name,
                                 Normal(mean_p_loc, mean_p_scale).to_event(1))

            if name == 'lengthscale_raw':
                lengthscale_p_loc = torch.zeros(shape[-1]).to(device)
                lengthscale_p_scale = torch.ones(shape[-1]).to(device)
                self._param_dist(
                    name,
                    Normal(lengthscale_p_loc, lengthscale_p_scale).to_event(1))

            if name == 'noise_raw':
                noise_p_loc = -1. * torch.ones(1).to(device)
                noise_p_scale = torch.ones(1).to(device)
                self._param_dist(
                    name,
                    Normal(noise_p_loc, noise_p_scale).to_event(1))

            if 'mean_nn' in name or 'kernel_nn' in name:
                mean = torch.zeros(shape).to(device)
                if "weight" in name:
                    std = weight_prior_std * torch.ones(shape).to(device)
                elif "bias" in name:
                    std = bias_prior_std * torch.ones(shape).to(device)
                else:
                    raise NotImplementedError
                self._param_dist(name, Normal(mean, std).to_event(1))

        # check that parameters in prior and gp modules are aligned
        for param_name_gp, param_name_prior in zip(
                self.gp.named_parameters().keys(), self._param_dists.keys()):
            assert param_name_gp == param_name_prior

        self.hyper_prior = CatDist(self._param_dists.values())
Example #9
    def __init__(self, use_affine_ex=True, **kwargs):
        super().__init__(**kwargs)

        self.num_scales = 2

        self.register_buffer("glasses_base_loc",
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        self.x_transforms = self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample("x", cond_x_dist)

        return x, glasses
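
For reference, a stand-alone sketch of the flow pattern used above (a Spline transform applied to a Normal base distribution); the shapes and names are illustrative, not taken from the source.

# Illustrative sketch: a spline normalizing flow over a standard Normal base.
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import Spline

base_dist = dist.Normal(torch.zeros(1), torch.ones(1)).to_event(1)
spline_transform = Spline(1)
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])
z = flow_dist.rsample()          # sample through the flow
log_p = flow_dist.log_prob(z)    # density accounts for the spline Jacobian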
Example #10
def model_2(x_data, y_data):

  conv1w_prior = Normal(loc=torch.zeros_like(net.conv1.weight), scale=torch.ones_like(net.conv1.weight))
  conv1b_prior = Normal(loc=torch.zeros_like(net.conv1.bias), scale=torch.ones_like(net.conv1.bias))

  conv2w_prior = Normal(loc=torch.zeros_like(net.conv2.weight), scale=torch.ones_like(net.conv2.weight))
  conv2b_prior = Normal(loc=torch.zeros_like(net.conv2.bias), scale=torch.ones_like(net.conv2.bias))

  fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
  fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))

  fc2w_prior = Normal(loc=torch.zeros_like(net.fc2.weight), scale=torch.ones_like(net.fc2.weight))
  fc2b_prior = Normal(loc=torch.zeros_like(net.fc2.bias), scale=torch.ones_like(net.fc2.bias))

  priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
            'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
            'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
            'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}

  lifted_module = pyro.random_module("module", net, priors)

  lifted_reg_model = lifted_module()

  lhat = log_softmax(lifted_reg_model(x_data))

  pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
Example #11
def partially_pooled(at_bats):
    """
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player is normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :return: Number of hits predicted by the model.
    """
    num_players = at_bats.shape[0]
    loc = pyro.sample("loc", Normal(at_bats.new_tensor(-1), at_bats.new_tensor(1)))
    scale = pyro.sample("scale", HalfCauchy(at_bats.new_tensor(0), at_bats.new_tensor(1)))
    alpha = pyro.sample("alpha", Normal(loc, scale).expand_by([num_players]).independent(1))
    return pyro.sample("obs", Binomial(at_bats, logits=alpha))
Example #12
def guide_2(x_data, y_data):
    
    conv1w_mu = torch.randn_like(net.conv1.weight)
    conv1w_sigma = torch.randn_like(net.conv1.weight)
    conv1w_mu_param = pyro.param("conv1w_mu", conv1w_mu)
    conv1w_sigma_param = softplus(pyro.param("conv1w_sigma", conv1w_sigma))
    conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)
    
    conv1b_mu = torch.randn_like(net.conv1.bias)
    conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)  

    conv2w_mu = torch.randn_like(net.conv2.weight)
    conv2w_sigma = torch.randn_like(net.conv2.weight)
    conv2w_mu_param = pyro.param("conv2w_mu", conv2w_mu)
    conv2w_sigma_param = softplus(pyro.param("conv2w_sigma", conv2w_sigma))
    conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)
    
    conv2b_mu = torch.randn_like(net.conv2.bias)
    conv2b_sigma = torch.randn_like(net.conv2.bias)
    conv2b_mu_param = pyro.param("conv2b_mu", conv2b_mu)
    conv2b_sigma_param = softplus(pyro.param("conv2b_sigma", conv2b_sigma))
    conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)

    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    fc2w_mu = torch.randn_like(net.fc2.weight)
    fc2w_sigma = torch.randn_like(net.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    
    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
    
    


    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior, 'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}
    
    lifted_module = pyro.random_module("module", net, priors)
    
    return lifted_module()
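
A rough training sketch for the model_2/guide_2 pair above; train_loader and the optimiser settings are assumptions, not from the source.

# Hypothetical SVI training loop for model_2 and guide_2.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(model_2, guide_2, Adam({"lr": 0.01}), loss=Trace_ELBO())
for epoch in range(5):
    epoch_loss = 0.0
    for x_batch, y_batch in train_loader:  # assumed DataLoader of images/labels
        epoch_loss += svi.step(x_batch, y_batch)
    print(epoch, epoch_loss)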
Example #13
    def pgm_model(self):
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)

        _ = self.sex_logits

        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)

        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo call to age_flow_components to register them with pyro
        _ = self.age_flow_components

        brain_context = torch.cat([sex, age_], 1)

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(brain_context)

        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_components to register them with pyro
        _ = self.brain_volume_flow_components

        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)

        ventricle_context = torch.cat([age_, brain_volume_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(
                ventricle_context)  # noqa: E501

        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_components to register them with pyro
        _ = self.ventricle_volume_flow_components

        return age, sex, ventricle_volume, brain_volume
Example #14
def guide(x_data, y_data):

    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    # Second layer weight distribution priors
    fc2w_mu = torch.randn_like(net.fc2.weight)
    fc2w_sigma = torch.randn_like(net.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    # Second layer bias distribution priors
    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

    # Third layer weight distribution priors
    fc3w_mu = torch.randn_like(net.fc3.weight)
    fc3w_sigma = torch.randn_like(net.fc3.weight)
    fc3w_mu_param = pyro.param("fc3w_mu", fc3w_mu)
    fc3w_sigma_param = softplus(pyro.param("fc3w_sigma", fc3w_sigma))
    fc3w_prior = Normal(loc=fc3w_mu_param, scale=fc3w_sigma_param)
    # Third layer bias distribution priors
    fc3b_mu = torch.randn_like(net.fc3.bias)
    fc3b_sigma = torch.randn_like(net.fc3.bias)
    fc3b_mu_param = pyro.param("fc3b_mu", fc3b_mu)
    fc3b_sigma_param = softplus(pyro.param("fc3b_sigma", fc3b_sigma))
    fc3b_prior = Normal(loc=fc3b_mu_param, scale=fc3b_sigma_param)

    # Output layer weight distribution priors
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param,
                        scale=outw_sigma_param).independent(1)
    # Output layer bias distribution priors
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,\
              'fc3.weight': fc3w_prior, 'fc3.bias': fc3b_prior, 'out.weight': outw_prior, 'out.bias': outb_prior}

    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
Example #15
def guide(fc_network: BNN, x_data, y_data):
    """
    Approximation of the posterior p(w | x_data, y_data); the likelihood is p(y_data | w, x_data).
    :param fc_network:
    :param x_data:
    :param y_data:
    :return:
    """
    # create weight distribution parameters priors
    priors = {}
    for i, layer in enumerate(fc_network.fc):
        if not hasattr(layer, 'weight'):
            continue
        # print("guide: ",i,layer)
        # print('guide_shapes',layer.weight.shape, layer.bias.shape)

        fciw_mu = Variable(torch.randn_like(layer.weight).type_as(x_data),
                           requires_grad=True)
        fcib_mu = Variable(torch.randn_like(layer.bias).type_as(x_data),
                           requires_grad=True)
        fciw_sigma = Variable(0.1 *
                              torch.randn_like(layer.weight).type_as(x_data),
                              requires_grad=True)
        fcib_sigma = Variable(0.1 *
                              torch.randn_like(layer.bias).type_as(x_data),
                              requires_grad=True)

        fciw_mu_param = pyro.param("guide.{}.w_mu".format(str(i)), fciw_mu)
        fcib_mu_param = pyro.param("guide.{}.b_mu".format(str(i)), fcib_mu)
        fciw_sigma_param = softplus(
            pyro.param("guide.{}.w_sigma".format(str(i)), fciw_sigma))
        fcib_sigma_param = softplus(
            pyro.param("guide.{}.b_sigma".format(str(i)), fcib_sigma))

        fciw_prior = Normal(fciw_mu_param, fciw_sigma_param)
        fcib_prior = Normal(fcib_mu_param, fcib_sigma_param)
        # TODO prior should have the same weight as in for name, _ in fc_network.named_parameters(),
        #  according to https://forum.pyro.ai/t/how-does-pyro-random-module-match-priors-with-regressionmodel-parameters/528/7
        priors['model.{}.weight'.format(str(i))] = fciw_prior
        priors['model.{}.bias'.format(str(i))] = fcib_prior
    #    lifted_module = pyro.module("module", fc_network, priors)
    # print('guide: ',priors)
    # for name, _ in fc_network.named_parameters():
    #    print(name)
    # exit(0)
    lifted_module = pyro.random_module("module", fc_network, priors)
    random_model = lifted_module()
    # print('lifted_module', random_model)
    return random_model
Example #16
def guide(is_cont_africa, ruggedness, log_gdp):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.rand(3))
    weights_scale = pyro.param('weights_scale', torch.ones(3),
                               constraint=constraints.positive)
    a = pyro.sample("a", Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", Normal(weights_loc[1], weights_scale[1]))
    b_ar = pyro.sample("bAR", Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample("sigma", Normal(sigma_loc, torch.tensor(0.05)))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
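
This guide matches the model from Example #4; once SVI has been run, posterior-predictive draws can be obtained roughly as below. The trained guide and the input tensors is_cont_africa and ruggedness are assumed to exist.

# Hypothetical posterior-predictive sketch for the Example #4 model and this guide.
from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=800)
samples = predictive(is_cont_africa, ruggedness, log_gdp=None)
pred_mean = samples["obs"].mean(0)  # predictive mean of log GDP per data point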
Example #17
def linear():
    x_data = [1, 2, 3, 4, 5, 6]
    y_data = torch.tensor([2.2, 4.2, 5.5, 8.3, 9.9, 12.1])
    k = sample('k', pyro.distributions.Normal(0, 1))
    if k < 0:
        slope = sample('slope', Normal(0, 5))
    else:
        slope = sample('slope', pyro.distributions.Bernoulli(0.5))

    bias = sample('bias', Normal(0, 5))

    for i in range(len(x_data)):
        x = x_data[i]
        mu = x * slope + bias
        y = sample(f"y_{i}", Normal(mu, 1), obs=y_data[i])
Example #18
    def update_noise_svi(self, observed_steady_state, initial_noise):
        def guide(noise):
            noise_terms = list(noise.keys())
            mu_constraints = constraints.interval(-3., 3.)
            sigma_constraints = constraints.interval(.0001, 3)
            mu = {
                k: pyro.param('{}_mu'.format(k),
                              tensor(0.),
                              constraint=mu_constraints)
                for k in noise_terms
            }
            sigma = {
                k: pyro.param('{}_sigma'.format(k),
                              tensor(1.),
                              constraint=sigma_constraints)
                for k in noise_terms
            }
            for noise in noise_terms:
                sample(noise, Normal(mu[noise], sigma[noise]))

        observation_model = condition(self.noisy_model, observed_steady_state)
        pyro.clear_param_store()
        svi = SVI(model=observation_model,
                  guide=guide,
                  optim=SGD({
                      "lr": 0.001,
                      "momentum": 0.1
                  }),
                  loss=Trace_ELBO())

        losses = []
        num_steps = 1000
        samples = defaultdict(list)
        for t in range(num_steps):
            losses.append(svi.step(initial_noise))
            for noise in initial_noise.keys():
                mu = '{}_mu'.format(noise)
                sigma = '{}_sigma'.format(noise)
                samples[mu].append(pyro.param(mu).item())
                samples[sigma].append(pyro.param(sigma).item())
        means = {k: statistics.mean(v) for k, v in samples.items()}
        updated_noise = {
            'N_Raf': Normal(means['N_Raf_mu'], means['N_Raf_sigma']),
            'N_Mek': Normal(means['N_Mek_mu'], means['N_Mek_sigma']),
            'N_Erk': Normal(means['N_Erk_mu'], means['N_Erk_sigma'])
        }

        return updated_noise, losses
Example #19
def guide_step(t, n, prev, inputs):
    rnn_input = torch.cat(
        (inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))

    #===== predict
    out = predict_l2(F.relu(predict_l1(h)))
    z_pres_p = torch.sigmoid(out[:, 0:1])
    z_where_loc = out[:, 1:4]
    z_where_scale = F.softplus(out[:, 4:])
    #===== predict

    infer_dict, bl_h, bl_c = baseline_step(prev, inputs)

    z_pres =\
        pyro.sample('z_pres_{}'.format(t),
                    Bernoulli(z_pres_p * prev.z_pres).to_event(1),
                    infer=infer_dict)

    sample_mask = z_pres if use_masking else torch.tensor(1.0)
    z_where =\
        pyro.sample('z_where_{}'.format(t),
                    Normal(z_where_loc + z_where_loc_prior,
                           z_where_scale * z_where_scale_prior)
                    .mask(sample_mask)
                    .to_event(1))

    x_att = image_to_window(z_where, window_size, x_size, inputs['raw'])

    #===== encode
    a = encode_l2(F.relu(encode_l1(x_att)))
    z_what_loc = a[:, 0:50]
    z_what_scale = F.softplus(a[:, 50:])
    #===== encode

    z_what =\
        pyro.sample('z_what_{}'.format(t),
                    Normal(z_what_loc, z_what_scale)
                    .mask(sample_mask)
                    .to_event(1))

    return GuideState(h=h,
                      c=c,
                      bl_h=bl_h,
                      bl_c=bl_c,
                      z_pres=z_pres,
                      z_where=z_where,
                      z_what=z_what)
Example #20
    def guide(
        self,
        x: torch.Tensor,
        x_packed_reversed: nn.utils.rnn.PackedSequence,
        seq_mask: torch.Tensor,
        seq_lengths: torch.Tensor,
        annealing=1.0,
    ) -> Tensor:

        pyro.module("dmm", self)
        batch_dim, time_steps, _ = x.shape
        h0 = self.h0.expand(self.h0.size(0), batch_dim,
                            self.h0.size(-1)).contiguous()
        h_packed_reversed = self.encode(x_packed_reversed, h0)[0]
        h_reversed, _ = pad_packed_sequence(h_packed_reversed,
                                            batch_first=True)
        h = self.reverse_sequences(h_reversed, seq_lengths)
        z = self.qz0.expand(batch_dim, self.qz0.size(-1))
        with pyro.plate("data", batch_dim):
            for t in range(time_steps):
                z_params = self.combine(h[:, t, :], z)
                with poutine.scale(None, annealing):
                    z = pyro.sample(
                        f"z_{t+1}",
                        Normal(*z_params)
                        .mask(seq_mask[:, t:t + 1])
                        .to_event(1),
                    )
        return z
Example #21
def pyromodel(x, y):
    priors = {}
    for name, par in model.named_parameters():
        priors[name] = dist.Normal(torch.zeros(*par.shape),
                                   50 * torch.ones(*par.shape)).independent(
                                       par.dim())

        #print("batch shape:", priors[name].batch_shape)
        #print("event shape:", priors[name].event_shape)
        #print("event dim:", priors[name].event_dim)

    bayesian_model = pyro.random_module('bayesian_model', model, priors)
    sampled_model = bayesian_model()
    sigma = pyro.sample('sigma', Uniform(0, 50))
    with pyro.iarange("map", len(x)):
        prediction_mean = sampled_model(x)
        logging.debug(f"prediction_mean: {prediction_mean.shape}")

        if y is not None:
            logging.debug(f"y_data: {y.shape}")

        d_dist = Normal(prediction_mean, sigma).to_event(1)

        logging.debug(f"batch shape: {d_dist.batch_shape}")
        logging.debug(f"event shape: {d_dist.event_shape}")
        logging.debug(f"event dim: {d_dist.event_dim}")

        pyro.sample("obs", d_dist, obs=y)

        return prediction_mean
Example #22
    def reparam_dist(self, mu, sigma):
        if self.post_approx == 'diag':
            dist = Independent(Normal(mu, sigma), 1)
        elif self.post_approx == 'low_rank':
            if sigma.dim() == 2:
                W = sigma[..., self.dim_stochastic:].view(
                    sigma.shape[0], self.dim_stochastic, self.rank)
            elif sigma.dim() == 3:
                W = sigma[..., self.dim_stochastic:].view(
                    sigma.shape[0], sigma.shape[1], self.dim_stochastic,
                    self.rank)
            else:
                raise NotImplementedError()
            D = sigma[..., :self.dim_stochastic]
            dist = LowRankMultivariateNormal(mu, W, D)
        else:
            raise ValueError('should not be here')
        sample = torch.squeeze(dist.rsample((1, )))
        if len(sample.shape) == 1:
            sample = sample[None, ...]
        return sample, dist
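
For reference, a small stand-alone sketch of the low-rank parameterisation used in the 'low_rank' branch above: the covariance is W @ W.T + diag(D). Shapes are illustrative, not taken from the source.

# Illustrative sketch of the low-rank covariance parameterisation.
import torch
from torch.distributions import LowRankMultivariateNormal

dim_stochastic, rank = 4, 2
mu = torch.zeros(dim_stochastic)
W = torch.randn(dim_stochastic, rank)   # cov_factor
D = torch.ones(dim_stochastic)          # cov_diag, strictly positive
low_rank_dist = LowRankMultivariateNormal(mu, W, D)
sample = low_rank_dist.rsample()        # reparameterised sample of shape [4]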
Example #23
def guide(observations={'x1': 0, 'x2': 0}):
    pyro.module("first", first)
    pyro.module("second", second)
    pyro.module("third", third)
    pyro.module("fourth", fourth)
    pyro.module("fifth", fifth)

    obs = torch.tensor([float(observations['x1']), float(observations['x2'])])
    # x1 = observations['x1']
    # x2 = observations['x2']
    x1 = obs[0]
    x2 = obs[1]
    # v = torch.cat((x1.view(1, 1), x2.view(1, 1)), 1)
    v = torch.cat(
        (torch.Tensor.view(x1, [1, 1]), torch.Tensor.view(x2, [1, 1])), 1)

    h1 = relu(first(v))
    h2 = relu(second(h1))
    h3 = relu(third(h2))
    h4 = relu(fourth(h3))
    out = fifth(h4)

    mean = out[0, 0]
    # std = out[0, 1].exp()
    std = torch.exp(out[0, 1])
    pyro.sample("z", Normal(mean, std))
Example #24
 def _sample(name, y):
     z_mu = get_module(
         f"{name}-mu",
         lambda: torch.nn.Sequential(
             torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
             torch.nn.BatchNorm2d(y.shape[1], momentum=0.05),
             torch.nn.LeakyReLU(0.2, inplace=True),
             torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
         ),
         checkpoint=True,
     )
     z_sd = get_module(
         f"{name}-sd",
         lambda: torch.nn.Sequential(
             torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
             torch.nn.BatchNorm2d(y.shape[1], momentum=0.05),
             torch.nn.LeakyReLU(0.2, inplace=True),
             torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
             torch.nn.Softplus(),
         ),
         checkpoint=True,
     )
     return p.sample(
         name, Normal(z_mu(y), 1e-8 + z_sd(y)).to_event(3)
     )
Example #25
def vnormal(name, target):
    softplus = nn.Softplus()
    return Normal(loc=pyro.param(name + '_m', torch.randn_like(target)),
                  scale=softplus(pyro.param(name + '_s',
                                            torch.randn_like(target))))
Example #26
    def encode(self,
               src,
               src_mask,
               src_lengths,
               pad_pack=True,
               calc_z=True,
               deterministic=True):
        # TODO: add an option to use a surrogate; not that important at the moment
        X = self.src_embed(src)
        mu_x, sig_x, latent_input = self.inference_network(X,
                                                           src_mask,
                                                           src_lengths,
                                                           pad_pack_x=True)

        if self.use_latent:
            if deterministic:
                self.z = mu_x
            else:
                self.z = (Normal(mu_x, sig_x).to_event(1)).sample()
            # z is otherwise only used as additional input, whereas project initializes the hidden states, so it has to align with the RNN hidden state
            self.z = self.applyFlows(self.z, cond_inp=latent_input)
            self.z_hid = self.project(self.z)
        else:
            self.z = torch.zeros_like(mu_x)
            self.z_hid = self.project(self.z)

        z_hid = self.resize_z(self.z_hid, 2 * self.num_layers)
        hidden_states, encoder_final = super(GenerativeEncoderDecoder,
                                             self).encode(src,
                                                          src_mask,
                                                          src_lengths,
                                                          pad_pack=pad_pack,
                                                          hidden=z_hid)
        return hidden_states, encoder_final
Example #27
File: st.py  Project: ludvb/xfuse
 def _sample_metagene(name, metagene):
     if metagene.profile is None:
         metagene = MetageneDefault(
             metagene.scale, torch.zeros(len(self._allocated_genes))
         )
     mu = get_param(
         f"{_encode_metagene_name(name)}_mu",
         # pylint: disable=unnecessary-lambda
         lambda: metagene.profile.float(),
         lr_multiplier=2.0,
     )
     sd = get_param(
         f"{_encode_metagene_name(name)}_sd",
         lambda: 1e-2 * torch.ones_like(mu),
         constraint=constraints.positive,
         lr_multiplier=2.0,
     )
     if len(self.__metagenes) < 2:
         mu = mu.detach()
         sd = sd.detach()
     pyro.sample(
         _encode_metagene_name(name),
         Normal(mu, 1e-8 + sd),
         infer={"is_global": True},
     )
Example #28
    def p_Zt_Ztm1(self, Zg, Zt_1T, A, B, Xt):
        mu0 = self.pre_t_mu(Zg)[:, None, :]
        sig0 = torch.nn.functional.softplus(self.pre_t_sigma(Zg))[:, None, :]
        Tmax = Zt_1T.shape[1]
        Z_rep = Zg[:, None, :].repeat(1, Tmax - 1, 1)
        if self.augmented:
            Zinp = torch.cat([Zt_1T, Xt], -1)
        else:
            Zinp = Zt_1T
        inp = torch.cat([Zinp[:, :-1, :], A[:, 1:Tmax, :], Z_rep], -1)

        if self.include_baseline:
            Aval = A[:, 1:Tmax, :]
            # include baseline in both control and input signals
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            inp = torch.cat([B[:, None, :].repeat(1, Aval.shape[1], 1), inp],
                            -1)
            mu1T, sig1T = self.transition_fxn(inp, Acat)
        else:
            mu1T, sig1T = self.transition_fxn(inp, A[:, 1:Tmax, :])

        mu, sig = torch.cat([mu0, mu1T], 1), torch.cat([sig0, sig1T], 1)
        return Independent(Normal(mu, sig), 1)
Example #29
def partially_pooled_with_logit(at_bats, hits):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player are normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    num_players = at_bats.shape[0]
    loc = pyro.sample("loc", Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1)))
    scale = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1)))
    with pyro.plate("num_players", num_players):
        alpha = pyro.sample("alpha", Normal(loc, scale))
        return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits)
Example #30
    def p_Zt_Ztm1(self, Zt, A, B, X, A0, Am, eps=0.):
        X0 = X[:, 0, :]
        Xt = X[:, 1:, :]
        inp_cat = torch.cat([B, X0, A0], -1)
        mu1 = self.prior_W(inp_cat)[:, None, :]
        sig1 = torch.nn.functional.softplus(
            self.prior_sigma(inp_cat))[:, None, :]

        Tmax = Zt.shape[1]
        if self.hparams['augmented']:
            Zinp = torch.cat([Zt[:, :-1, :], Xt[:, :-1, :]], -1)
        else:
            Zinp = Zt[:, :-1, :]
        Aval = A[:, 1:Tmax, :]
        sub_mask = np.triu(np.ones(
            (Aval.shape[0], Aval.shape[1], Aval.shape[1])),
                           k=1).astype('uint8')
        Zm = (torch.from_numpy(sub_mask) == 0).to(Am.device)
        res = self.attn(self.attn_lin(torch.cat([Xt[:, :-1, :], Aval], -1)),
                        Zinp,
                        Zinp,
                        mask=Zm,
                        use_matmul=True)
        if self.hparams['include_baseline']:
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            mu2T, sig2T = self.transition_fxn(res, Acat, eps=eps)
        else:
            mu2T, sig2T = self.transition_fxn(res, A[:, 1:Tmax, :], eps=eps)
        mu, sig = torch.cat([mu1, mu2T], 1), torch.cat([sig1, sig2T], 1)
        return Independent(Normal(mu, sig), 1)