def pyro_model(x, y):
    priors = {
        'covar_module.base_kernel.raw_lengthscale': Normal(0, 2).expand([1, 1]),
        'covar_module.raw_outputscale': Normal(0, 2),
        'likelihood.noise_covar.raw_noise': Normal(0, 2).expand([1]),
        'mean_module.constant': Normal(0, 2),
    }
    fn = pyro.random_module("model", model, prior=priors)
    sampled_model = fn()
    output = sampled_model.likelihood(sampled_model(x))
    pyro.sample("obs", output, obs=y)
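# A minimal sketch of sampling the GP hyperparameters above with NUTS. It
# assumes `model` is a gpytorch.models.ExactGP and that the deprecated
# (pre-1.0) pyro.random_module API used above is available; names and sample
# counts are illustrative, not from the source.
from pyro.infer.mcmc import MCMC, NUTS

nuts_kernel = NUTS(pyro_model)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=100)
mcmc.run(x, y)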
def model():
    p = torch.tensor([0.5])
    loc = torch.zeros(1)
    scale = torch.ones(1)
    x = pyro.sample("x", Normal(loc, scale))  # Before the discrete variable.
    y = pyro.sample("y", Bernoulli(p))
    z = pyro.sample("z", Normal(loc, scale))  # After the discrete variable.
    return dict(x=x, y=y, z=z)
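# A quick way to inspect the joint density of the mixed continuous/discrete
# model above via poutine tracing (standard Pyro API; this is a sketch, not
# from the source):
import pyro.poutine as poutine

trace = poutine.trace(model).get_trace()
print(trace.log_prob_sum())  # log-joint of the sampled x, y, z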
def bayes_logistic(X, y):
    n, k = X.shape
    w_prior = Normal(torch.zeros(1, k), torch.ones(1, k)).to_event(1)
    b_prior = Normal(torch.tensor([[0.]]), torch.tensor([[10.]])).to_event(1)
    priors = {"linear.weight": w_prior, "linear.bias": b_prior}
    lifted_module = pyro.random_module("bayes_logistic", frequentist_model,
                                       priors)
    lifted_model = lifted_module()
    with pyro.plate("customers", n):
        y_pred = lifted_model(X).squeeze(1)
        pyro.sample("obs", Bernoulli(y_pred, validate_args=True), obs=y)
    return y_pred
def model(is_cont_africa, ruggedness, log_gdp):
    # WL: edited. =====
    # a = pyro.sample("a", Normal(8., 1000.))
    a = pyro.sample("a", Normal(0., 10.))
    # =================
    b_a = pyro.sample("bA", Normal(0., 1.))
    b_r = pyro.sample("bR", Normal(0., 1.))
    b_ar = pyro.sample("bAR", Normal(0., 1.))
    sigma = pyro.sample("sigma", Uniform(0., 10.))
    mean = (a + b_a * is_cont_africa + b_r * ruggedness
            + b_ar * is_cont_africa * ruggedness)
    with pyro.plate("data", 170):
        pyro.sample("obs", Normal(mean, sigma), obs=log_gdp)
def guide(data):
    w_mu = Variable(torch.randn(second_layer, first_layer).type_as(data.data),
                    requires_grad=True)
    w_log_sig = Variable(
        0.1 * torch.ones(second_layer, first_layer).type_as(data.data),
        requires_grad=True)
    b_mu = Variable(torch.randn(second_layer).type_as(data.data),
                    requires_grad=True)
    b_log_sig = Variable(0.1 * torch.ones(second_layer).type_as(data.data),
                         requires_grad=True)
    # register learnable params in the param store
    mw_param = pyro.param("guide_mean_weight", w_mu)
    sw_param = softplus(pyro.param("guide_log_sigma_weight", w_log_sig))
    mb_param = pyro.param("guide_mean_bias", b_mu)
    sb_param = softplus(pyro.param("guide_log_sigma_bias", b_log_sig))
    # gaussian guide distributions for w and b
    w_dist = Normal(mw_param, sw_param)
    b_dist = Normal(mb_param, sb_param)

    w_mu2 = Variable(torch.randn(1, second_layer).type_as(data.data),
                     requires_grad=True)
    w_log_sig2 = Variable(
        0.1 * torch.randn(1, second_layer).type_as(data.data),
        requires_grad=True)
    b_mu2 = Variable(torch.randn(1).type_as(data.data), requires_grad=True)
    b_log_sig2 = Variable(0.1 * torch.ones(1).type_as(data.data),
                          requires_grad=True)
    # register learnable params in the param store
    mw_param2 = pyro.param("guide_mean_weight2", w_mu2)
    sw_param2 = softplus(pyro.param("guide_log_sigma_weight2", w_log_sig2))
    mb_param2 = pyro.param("guide_mean_bias2", b_mu2)
    sb_param2 = softplus(pyro.param("guide_log_sigma_bias2", b_log_sig2))
    # gaussian guide distributions for w and b
    w_dist2 = Normal(mw_param2, sw_param2)
    b_dist2 = Normal(mb_param2, sb_param2)

    dists = {
        'hidden.weight': w_dist,
        'hidden.bias': b_dist,
        'predict.weight': w_dist2,
        'predict.bias': b_dist2
    }
    # overloading the parameters in the module with random samples
    # from the guide distributions
    lifted_module = pyro.random_module("module", regression_model, dists)
    # sample a regressor
    return lifted_module()
def generate(self, num_to_sample: int = 1):
    """Generate samples from the prior."""
    cuda_device = self._get_prediction_device()
    prior_mean = nn_util.move_to_device(
        torch.zeros((num_to_sample, self._latent_dim)),
        cuda_device,
    )
    prior_stddev = torch.ones_like(prior_mean)
    prior = Normal(prior_mean, prior_stddev)
    latent = prior.sample()
    generated = self._decoder.generate(latent)
    return self.make_output_human_readable(generated)
def guide():
    mu1 = pyro.param("mu1", Variable(torch.randn(2), requires_grad=True))
    sigma1 = pyro.param("sigma1", Variable(torch.ones(2), requires_grad=True))
    pyro.sample("latent1", Normal(mu1, sigma1))
    mu2 = pyro.param("mu2", Variable(torch.randn(2), requires_grad=True))
    sigma2 = pyro.param("sigma2", Variable(torch.ones(2), requires_grad=True))
    latent2 = pyro.sample("latent2", Normal(mu2, sigma2))
    return latent2
def __init__(self, size_in, prior_factor=1.0, weight_prior_std=1.0,
             bias_prior_std=3.0, **kwargs):
    self._params = OrderedDict()
    self._param_dists = OrderedDict()
    self.prior_factor = prior_factor
    self.gp = VectorizedGP(size_in, **kwargs)
    for name, shape in self.gp.parameter_shapes().items():
        if name == 'constant_mean':
            mean_p_loc = torch.zeros(1).to(device)
            mean_p_scale = torch.ones(1).to(device)
            self._param_dist(name,
                             Normal(mean_p_loc, mean_p_scale).to_event(1))
        if name == 'lengthscale_raw':
            lengthscale_p_loc = torch.zeros(shape[-1]).to(device)
            lengthscale_p_scale = torch.ones(shape[-1]).to(device)
            self._param_dist(
                name,
                Normal(lengthscale_p_loc, lengthscale_p_scale).to_event(1))
        if name == 'noise_raw':
            noise_p_loc = -1. * torch.ones(1).to(device)
            noise_p_scale = torch.ones(1).to(device)
            self._param_dist(
                name, Normal(noise_p_loc, noise_p_scale).to_event(1))
        if 'mean_nn' in name or 'kernel_nn' in name:
            mean = torch.zeros(shape).to(device)
            if "weight" in name:
                std = weight_prior_std * torch.ones(shape).to(device)
            elif "bias" in name:
                std = bias_prior_std * torch.ones(shape).to(device)
            else:
                raise NotImplementedError
            self._param_dist(name, Normal(mean, std).to_event(1))
    # check that parameters in prior and gp modules are aligned
    for param_name_gp, param_name_prior in zip(
            self.gp.named_parameters().keys(), self._param_dists.keys()):
        assert param_name_gp == param_name_prior
    self.hyper_prior = CatDist(self._param_dists.values())
def __init__(self, use_affine_ex=True, **kwargs):
    super().__init__(**kwargs)
    self.num_scales = 2
    self.register_buffer("glasses_base_loc",
                         torch.zeros([1], requires_grad=False))
    self.register_buffer("glasses_base_scale",
                         torch.ones([1], requires_grad=False))
    self.register_buffer("glasses_flow_lognorm_loc",
                         torch.zeros([], requires_grad=False))
    self.register_buffer("glasses_flow_lognorm_scale",
                         torch.ones([], requires_grad=False))
    self.glasses_flow_components = ComposeTransformModule([Spline(1)])
    self.glasses_flow_constraint_transforms = ComposeTransform(
        [self.glasses_flow_lognorm, SigmoidTransform()])
    self.glasses_flow_transforms = ComposeTransform([
        self.glasses_flow_components,
        self.glasses_flow_constraint_transforms
    ])
    glasses_base_dist = Normal(self.glasses_base_loc,
                               self.glasses_base_scale).to_event(1)
    self.glasses_dist = TransformedDistribution(
        glasses_base_dist, self.glasses_flow_transforms)
    glasses_ = pyro.sample("glasses_", self.glasses_dist)
    glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
    glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

    self.x_transforms = self._build_image_flow()
    self.register_buffer("x_base_loc",
                         torch.zeros([1, 64, 64], requires_grad=False))
    self.register_buffer("x_base_scale",
                         torch.ones([1, 64, 64], requires_grad=False))
    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    cond_x_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            x_base_dist, self.x_transforms
        ).condition(glasses_context).transforms
    ).inv
    cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)
    x = pyro.sample("x", cond_x_dist)
    return x, glasses
def model_2(x_data, y_data):
    conv1w_prior = Normal(loc=torch.zeros_like(net.conv1.weight),
                          scale=torch.ones_like(net.conv1.weight))
    conv1b_prior = Normal(loc=torch.zeros_like(net.conv1.bias),
                          scale=torch.ones_like(net.conv1.bias))
    conv2w_prior = Normal(loc=torch.zeros_like(net.conv2.weight),
                          scale=torch.ones_like(net.conv2.weight))
    conv2b_prior = Normal(loc=torch.zeros_like(net.conv2.bias),
                          scale=torch.ones_like(net.conv2.bias))
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight),
                        scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias),
                        scale=torch.ones_like(net.fc1.bias))
    fc2w_prior = Normal(loc=torch.zeros_like(net.fc2.weight),
                        scale=torch.ones_like(net.fc2.weight))
    fc2b_prior = Normal(loc=torch.zeros_like(net.fc2.bias),
                        scale=torch.ones_like(net.fc2.bias))
    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(x_data))
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
def partially_pooled(at_bats):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player are normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :return: Number of hits predicted by the model.
    """
    num_players = at_bats.shape[0]
    loc = pyro.sample("loc", Normal(at_bats.new_tensor(-1),
                                    at_bats.new_tensor(1)))
    scale = pyro.sample("scale", HalfCauchy(at_bats.new_tensor(0),
                                            at_bats.new_tensor(1)))
    alpha = pyro.sample(
        "alpha",
        Normal(loc, scale).expand_by([num_players]).independent(1))
    return pyro.sample("obs", Binomial(at_bats, logits=alpha))
def guide_2(x_data, y_data):
    conv1w_mu = torch.randn_like(net.conv1.weight)
    conv1w_sigma = torch.randn_like(net.conv1.weight)
    conv1w_mu_param = pyro.param("conv1w_mu", conv1w_mu)
    conv1w_sigma_param = softplus(pyro.param("conv1w_sigma", conv1w_sigma))
    conv1w_prior = Normal(loc=conv1w_mu_param, scale=conv1w_sigma_param)

    conv1b_mu = torch.randn_like(net.conv1.bias)
    conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)

    conv2w_mu = torch.randn_like(net.conv2.weight)
    conv2w_sigma = torch.randn_like(net.conv2.weight)
    conv2w_mu_param = pyro.param("conv2w_mu", conv2w_mu)
    conv2w_sigma_param = softplus(pyro.param("conv2w_sigma", conv2w_sigma))
    conv2w_prior = Normal(loc=conv2w_mu_param, scale=conv2w_sigma_param)

    conv2b_mu = torch.randn_like(net.conv2.bias)
    conv2b_sigma = torch.randn_like(net.conv2.bias)
    conv2b_mu_param = pyro.param("conv2b_mu", conv2b_mu)
    conv2b_sigma_param = softplus(pyro.param("conv2b_sigma", conv2b_sigma))
    conv2b_prior = Normal(loc=conv2b_mu_param, scale=conv2b_sigma_param)

    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    fc2w_mu = torch.randn_like(net.fc2.weight)
    fc2w_sigma = torch.randn_like(net.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)

    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

    priors = {'conv1.weight': conv1w_prior, 'conv1.bias': conv1b_prior,
              'conv2.weight': conv2w_prior, 'conv2.bias': conv2b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
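# A minimal training-loop sketch pairing model_2 with guide_2. `train_loader`
# is a hypothetical DataLoader yielding (x_data, y_data) batches; the learning
# rate and epoch count are illustrative.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(model_2, guide_2, Adam({"lr": 0.01}), loss=Trace_ELBO())
for epoch in range(10):
    for x_data, y_data in train_loader:
        loss = svi.step(x_data, y_data)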
def pgm_model(self):
    sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
    _ = self.sex_logits
    sex = pyro.sample('sex', sex_dist)

    age_base_dist = Normal(self.age_base_loc, self.age_base_scale).to_event(1)
    age_dist = TransformedDistribution(age_base_dist,
                                       self.age_flow_transforms)
    age = pyro.sample('age', age_dist)
    age_ = self.age_flow_constraint_transforms.inv(age)
    # pseudo call to age_flow_components to register it with pyro
    _ = self.age_flow_components

    brain_context = torch.cat([sex, age_], 1)
    brain_volume_base_dist = Normal(
        self.brain_volume_base_loc, self.brain_volume_base_scale).to_event(1)
    brain_volume_dist = ConditionalTransformedDistribution(
        brain_volume_base_dist,
        self.brain_volume_flow_transforms).condition(brain_context)
    brain_volume = pyro.sample('brain_volume', brain_volume_dist)
    # pseudo call to brain_volume_flow_components to register it with pyro
    _ = self.brain_volume_flow_components
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        brain_volume)

    ventricle_context = torch.cat([age_, brain_volume_], 1)
    ventricle_volume_base_dist = Normal(
        self.ventricle_volume_base_loc,
        self.ventricle_volume_base_scale).to_event(1)
    ventricle_volume_dist = ConditionalTransformedDistribution(
        ventricle_volume_base_dist,
        self.ventricle_volume_flow_transforms).condition(ventricle_context)
    ventricle_volume = pyro.sample('ventricle_volume', ventricle_volume_dist)
    # pseudo call to ventricle_volume_flow_components to register it with pyro
    _ = self.ventricle_volume_flow_components

    return age, sex, ventricle_volume, brain_volume
def guide(x_data, y_data):
    # First layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # First layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    # Second layer weight distribution priors
    fc2w_mu = torch.randn_like(net.fc2.weight)
    fc2w_sigma = torch.randn_like(net.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    # Second layer bias distribution priors
    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
    # Third layer weight distribution priors
    fc3w_mu = torch.randn_like(net.fc3.weight)
    fc3w_sigma = torch.randn_like(net.fc3.weight)
    fc3w_mu_param = pyro.param("fc3w_mu", fc3w_mu)
    fc3w_sigma_param = softplus(pyro.param("fc3w_sigma", fc3w_sigma))
    fc3w_prior = Normal(loc=fc3w_mu_param, scale=fc3w_sigma_param)
    # Third layer bias distribution priors
    fc3b_mu = torch.randn_like(net.fc3.bias)
    fc3b_sigma = torch.randn_like(net.fc3.bias)
    fc3b_mu_param = pyro.param("fc3b_mu", fc3b_mu)
    fc3b_sigma_param = softplus(pyro.param("fc3b_sigma", fc3b_sigma))
    fc3b_prior = Normal(loc=fc3b_mu_param, scale=fc3b_sigma_param)
    # Output layer weight distribution priors
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = softplus(pyro.param("outw_sigma", outw_sigma))
    outw_prior = Normal(loc=outw_mu_param,
                        scale=outw_sigma_param).independent(1)
    # Output layer bias distribution priors
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)
    priors = {'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior,
              'fc3.weight': fc3w_prior, 'fc3.bias': fc3b_prior,
              'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
def guide(fc_network: BNN, x_data, y_data):
    """
    Variational approximation of the posterior p(w | x_data, y_data);
    the likelihood is p(y_data | w, x_data).

    :param fc_network:
    :param x_data:
    :param y_data:
    :return:
    """
    # create weight distribution parameters
    priors = {}
    for i, layer in enumerate(fc_network.fc):
        if not hasattr(layer, 'weight'):
            continue
        fciw_mu = Variable(torch.randn_like(layer.weight).type_as(x_data),
                           requires_grad=True)
        fcib_mu = Variable(torch.randn_like(layer.bias).type_as(x_data),
                           requires_grad=True)
        fciw_sigma = Variable(
            0.1 * torch.randn_like(layer.weight).type_as(x_data),
            requires_grad=True)
        fcib_sigma = Variable(
            0.1 * torch.randn_like(layer.bias).type_as(x_data),
            requires_grad=True)
        fciw_mu_param = pyro.param("guide.{}.w_mu".format(str(i)), fciw_mu)
        fcib_mu_param = pyro.param("guide.{}.b_mu".format(str(i)), fcib_mu)
        fciw_sigma_param = softplus(
            pyro.param("guide.{}.w_sigma".format(str(i)), fciw_sigma))
        fcib_sigma_param = softplus(
            pyro.param("guide.{}.b_sigma".format(str(i)), fcib_sigma))
        fciw_prior = Normal(fciw_mu_param, fciw_sigma_param)
        fcib_prior = Normal(fcib_mu_param, fcib_sigma_param)
        # TODO: the keys here must match the names returned by
        # fc_network.named_parameters(), see
        # https://forum.pyro.ai/t/how-does-pyro-random-module-match-priors-with-regressionmodel-parameters/528/7
        priors['model.{}.weight'.format(str(i))] = fciw_prior
        priors['model.{}.bias'.format(str(i))] = fcib_prior
    lifted_module = pyro.random_module("module", fc_network, priors)
    random_model = lifted_module()
    return random_model
def guide(is_cont_africa, ruggedness, log_gdp):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.rand(3))
    weights_scale = pyro.param('weights_scale', torch.ones(3),
                               constraint=constraints.positive)
    a = pyro.sample("a", Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", Normal(weights_loc[1], weights_scale[1]))
    b_ar = pyro.sample("bAR", Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample("sigma", Normal(sigma_loc, torch.tensor(0.05)))
    mean = (a + b_a * is_cont_africa + b_r * ruggedness
            + b_ar * is_cont_africa * ruggedness)
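# A minimal SVI sketch wiring this guide to the is_cont_africa/ruggedness
# model defined earlier; the three tensor arguments are assumed to be the
# rugged-terrain GDP data, as in Pyro's regression tutorial. Step count and
# learning rate are illustrative.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(model, guide, Adam({"lr": 0.05}), loss=Trace_ELBO())
for step in range(2000):
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)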
def linear():
    x_data = [1, 2, 3, 4, 5, 6]
    y_data = torch.tensor([2.2, 4.2, 5.5, 8.3, 9.9, 12.1])
    k = sample('k', pyro.distributions.Normal(0, 1))
    if k < 0:
        slope = sample('slope', Normal(0, 5))
    else:
        slope = sample('slope', pyro.distributions.Bernoulli(0.5))
    bias = sample('bias', Normal(0, 5))
    for i in range(len(x_data)):
        x = x_data[i]
        mu = x * slope + bias
        y = sample(f"y_{i}", Normal(mu, 1), obs=y_data[i])
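# Because `linear` branches on the sampled value of k, a fixed-shape guide is
# awkward; importance sampling sidesteps that. A sketch using the standard
# pyro.infer API (the sample count is illustrative):
from pyro.infer import EmpiricalMarginal, Importance

posterior = Importance(linear, num_samples=500).run()
slope_marginal = EmpiricalMarginal(posterior, sites="slope")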
def update_noise_svi(self, observed_steady_state, initial_noise):
    def guide(noise):
        noise_terms = list(noise.keys())
        mu_constraints = constraints.interval(-3., 3.)
        sigma_constraints = constraints.interval(.0001, 3.)
        mu = {
            k: pyro.param('{}_mu'.format(k), tensor(0.),
                          constraint=mu_constraints)
            for k in noise_terms
        }
        sigma = {
            k: pyro.param('{}_sigma'.format(k), tensor(1.),
                          constraint=sigma_constraints)
            for k in noise_terms
        }
        for noise_term in noise_terms:
            sample(noise_term, Normal(mu[noise_term], sigma[noise_term]))

    observation_model = condition(self.noisy_model, observed_steady_state)
    pyro.clear_param_store()
    svi = SVI(model=observation_model,
              guide=guide,
              optim=SGD({"lr": 0.001, "momentum": 0.1}),
              loss=Trace_ELBO())

    losses = []
    num_steps = 1000
    samples = defaultdict(list)
    for t in range(num_steps):
        losses.append(svi.step(initial_noise))
        for noise in initial_noise.keys():
            mu = '{}_mu'.format(noise)
            sigma = '{}_sigma'.format(noise)
            samples[mu].append(pyro.param(mu).item())
            samples[sigma].append(pyro.param(sigma).item())
    means = {k: statistics.mean(v) for k, v in samples.items()}
    updated_noise = {
        'N_Raf': Normal(means['N_Raf_mu'], means['N_Raf_sigma']),
        'N_Mek': Normal(means['N_Mek_mu'], means['N_Mek_sigma']),
        'N_Erk': Normal(means['N_Erk_mu'], means['N_Erk_sigma'])
    }
    return updated_noise, losses
def guide_step(t, n, prev, inputs):
    rnn_input = torch.cat(
        (inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))

    # ===== predict
    out = predict_l2(F.relu(predict_l1(h)))
    z_pres_p = torch.sigmoid(out[:, 0:1])
    z_where_loc = out[:, 1:4]
    z_where_scale = F.softplus(out[:, 4:])
    # ===== /predict

    infer_dict, bl_h, bl_c = baseline_step(prev, inputs)
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         Bernoulli(z_pres_p * prev.z_pres).to_event(1),
                         infer=infer_dict)
    sample_mask = z_pres if use_masking else torch.tensor(1.0)
    z_where = pyro.sample('z_where_{}'.format(t),
                          Normal(z_where_loc + z_where_loc_prior,
                                 z_where_scale * z_where_scale_prior)
                          .mask(sample_mask)
                          .to_event(1))
    x_att = image_to_window(z_where, window_size, x_size, inputs['raw'])

    # ===== encode
    a = encode_l2(F.relu(encode_l1(x_att)))
    z_what_loc = a[:, 0:50]
    z_what_scale = F.softplus(a[:, 50:])
    # ===== /encode

    z_what = pyro.sample('z_what_{}'.format(t),
                         Normal(z_what_loc, z_what_scale)
                         .mask(sample_mask)
                         .to_event(1))
    return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres,
                      z_where=z_where, z_what=z_what)
def guide(
    self,
    x: torch.Tensor,
    x_packed_reversed: nn.utils.rnn.PackedSequence,
    seq_mask: torch.Tensor,
    seq_lengths: torch.Tensor,
    annealing=1.0,
) -> Tensor:
    pyro.module("dmm", self)
    batch_dim, time_steps, _ = x.shape
    h0 = self.h0.expand(self.h0.size(0), batch_dim,
                        self.h0.size(-1)).contiguous()
    h_packed_reversed = self.encode(x_packed_reversed, h0)[0]
    h_reversed, _ = pad_packed_sequence(h_packed_reversed, batch_first=True)
    h = self.reverse_sequences(h_reversed, seq_lengths)
    z = self.qz0.expand(batch_dim, self.qz0.size(-1))
    with pyro.plate("data", batch_dim):
        for t in range(time_steps):
            z_params = self.combine(h[:, t, :], z)
            with poutine.scale(None, annealing):
                z = pyro.sample(
                    f"z_{t+1}",
                    Normal(*z_params).mask(seq_mask[:, t:t + 1]).to_event(1),
                )
    return z
def pyromodel(x, y):
    priors = {}
    for name, par in model.named_parameters():
        priors[name] = dist.Normal(
            torch.zeros(*par.shape),
            50 * torch.ones(*par.shape)).independent(par.dim())
    bayesian_model = pyro.random_module('bayesian_model', model, priors)
    sampled_model = bayesian_model()
    sigma = pyro.sample('sigma', Uniform(0, 50))
    with pyro.iarange("map", len(x)):
        prediction_mean = sampled_model(x)
        logging.debug(f"prediction_mean: {prediction_mean.shape}")
        d_dist = Normal(prediction_mean, sigma).to_event(1)
        logging.debug(f"batch shape: {d_dist.batch_shape}")
        logging.debug(f"event shape: {d_dist.event_shape}")
        logging.debug(f"event dim: {d_dist.event_dim}")
        if y is not None:
            logging.debug(f"y_data: {y.shape}")
            pyro.sample("obs", d_dist, obs=y)
        return prediction_mean
def reparam_dist(self, mu, sigma):
    if self.post_approx == 'diag':
        dist = Independent(Normal(mu, sigma), 1)
    elif self.post_approx == 'low_rank':
        if sigma.dim() == 2:
            W = sigma[..., self.dim_stochastic:].view(
                sigma.shape[0], self.dim_stochastic, self.rank)
        elif sigma.dim() == 3:
            W = sigma[..., self.dim_stochastic:].view(
                sigma.shape[0], sigma.shape[1], self.dim_stochastic,
                self.rank)
        else:
            raise NotImplementedError()
        D = sigma[..., :self.dim_stochastic]
        dist = LowRankMultivariateNormal(mu, W, D)
    else:
        raise ValueError(
            'unknown posterior approximation: {}'.format(self.post_approx))
    sample = torch.squeeze(dist.rsample((1,)))
    if len(sample.shape) == 1:
        sample = sample[None, ...]
    return sample, dist
def guide(observations={'x1': 0, 'x2': 0}):
    pyro.module("first", first)
    pyro.module("second", second)
    pyro.module("third", third)
    pyro.module("fourth", fourth)
    pyro.module("fifth", fifth)
    obs = torch.tensor([float(observations['x1']),
                        float(observations['x2'])])
    x1 = obs[0]
    x2 = obs[1]
    v = torch.cat((x1.view(1, 1), x2.view(1, 1)), 1)
    h1 = relu(first(v))
    h2 = relu(second(h1))
    h3 = relu(third(h2))
    h4 = relu(fourth(h3))
    out = fifth(h4)
    mean = out[0, 0]
    std = out[0, 1].exp()
    pyro.sample("z", Normal(mean, std))
def _sample(name, y):
    z_mu = get_module(
        f"{name}-mu",
        lambda: torch.nn.Sequential(
            torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
            torch.nn.BatchNorm2d(y.shape[1], momentum=0.05),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
        ),
        checkpoint=True,
    )
    z_sd = get_module(
        f"{name}-sd",
        lambda: torch.nn.Sequential(
            torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
            torch.nn.BatchNorm2d(y.shape[1], momentum=0.05),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(y.shape[1], y.shape[1], 1),
            torch.nn.Softplus(),
        ),
        checkpoint=True,
    )
    return p.sample(name, Normal(z_mu(y), 1e-8 + z_sd(y)).to_event(3))
def vnormal(name, target):
    softplus = nn.Softplus()
    return Normal(loc=pyro.param(name + '_m', torch.randn_like(target)),
                  scale=softplus(pyro.param(name + '_s',
                                            torch.randn_like(target))))
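# Usage sketch: vnormal builds a mean-field Normal whose loc/scale are fresh
# pyro.params shaped like `target`, handy inside a guide. `net` is a
# hypothetical module with an fc1 layer; the names are illustrative.
w_post = vnormal("fc1.weight", net.fc1.weight)
b_post = vnormal("fc1.bias", net.fc1.bias)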
def encode(self, src, src_mask, src_lengths, pad_pack=True, calc_z=True,
           deterministic=True):
    # TODO: add an option to use a surrogate; not that important at the moment
    X = self.src_embed(src)
    mu_x, sig_x, latent_input = self.inference_network(X, src_mask,
                                                       src_lengths,
                                                       pad_pack_x=True)
    if self.use_latent:
        if deterministic:
            self.z = mu_x
        else:
            self.z = Normal(mu_x, sig_x).to_event(1).sample()
        # z is otherwise only used as additional input, whereas project
        # initializes hidden states, so it has to align with the rnn
        # hidden state
        self.z = self.applyFlows(self.z, cond_inp=latent_input)
        self.z_hid = self.project(self.z)
    else:
        self.z = torch.zeros_like(mu_x)
        self.z_hid = self.project(self.z)
    z_hid = self.resize_z(self.z_hid, 2 * self.num_layers)
    hidden_states, encoder_final = super(
        GenerativeEncoderDecoder, self).encode(src, src_mask, src_lengths,
                                               pad_pack=pad_pack,
                                               hidden=z_hid)
    return hidden_states, encoder_final
def _sample_metagene(name, metagene):
    if metagene.profile is None:
        metagene = MetageneDefault(
            metagene.scale, torch.zeros(len(self._allocated_genes)))
    mu = get_param(
        f"{_encode_metagene_name(name)}_mu",
        # pylint: disable=unnecessary-lambda
        lambda: metagene.profile.float(),
        lr_multiplier=2.0,
    )
    sd = get_param(
        f"{_encode_metagene_name(name)}_sd",
        lambda: 1e-2 * torch.ones_like(mu),
        constraint=constraints.positive,
        lr_multiplier=2.0,
    )
    if len(self.__metagenes) < 2:
        mu = mu.detach()
        sd = sd.detach()
    pyro.sample(
        _encode_metagene_name(name),
        Normal(mu, 1e-8 + sd),
        infer={"is_global": True},
    )
def p_Zt_Ztm1(self, Zg, Zt_1T, A, B, Xt):
    mu0 = self.pre_t_mu(Zg)[:, None, :]
    sig0 = torch.nn.functional.softplus(self.pre_t_sigma(Zg))[:, None, :]
    Tmax = Zt_1T.shape[1]
    Z_rep = Zg[:, None, :].repeat(1, Tmax - 1, 1)
    if self.augmented:
        Zinp = torch.cat([Zt_1T, Xt], -1)
    else:
        Zinp = Zt_1T
    inp = torch.cat([Zinp[:, :-1, :], A[:, 1:Tmax, :], Z_rep], -1)
    if self.include_baseline:
        Aval = A[:, 1:Tmax, :]
        # include baseline in both control and input signals
        Acat = torch.cat([
            Aval[..., [0]],
            B[:, None, :].repeat(1, Aval.shape[1], 1),
            Aval[..., 1:]
        ], -1)
        inp = torch.cat([B[:, None, :].repeat(1, Aval.shape[1], 1), inp], -1)
        mu1T, sig1T = self.transition_fxn(inp, Acat)
    else:
        mu1T, sig1T = self.transition_fxn(inp, A[:, 1:Tmax, :])
    mu, sig = torch.cat([mu0, mu1T], 1), torch.cat([sig0, sig1T], 1)
    return Independent(Normal(mu, sig), 1)
def partially_pooled_with_logit(at_bats, hits):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player are normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    num_players = at_bats.shape[0]
    loc = pyro.sample("loc", Normal(scalar_like(at_bats, -1),
                                    scalar_like(at_bats, 1)))
    scale = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1)))
    with pyro.plate("num_players", num_players):
        alpha = pyro.sample("alpha", Normal(loc, scale))
        return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits)
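# A minimal NUTS sketch for the partially pooled model above; at_bats and
# hits are the tensors described in the docstring, and the sample counts are
# illustrative.
from pyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(partially_pooled_with_logit), num_samples=600,
            warmup_steps=300)
mcmc.run(at_bats, hits)
alpha_samples = mcmc.get_samples()["alpha"]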
def p_Zt_Ztm1(self, Zt, A, B, X, A0, Am, eps=0.):
    X0 = X[:, 0, :]
    Xt = X[:, 1:, :]
    inp_cat = torch.cat([B, X0, A0], -1)
    mu1 = self.prior_W(inp_cat)[:, None, :]
    sig1 = torch.nn.functional.softplus(
        self.prior_sigma(inp_cat))[:, None, :]
    Tmax = Zt.shape[1]
    if self.hparams['augmented']:
        Zinp = torch.cat([Zt[:, :-1, :], Xt[:, :-1, :]], -1)
    else:
        Zinp = Zt[:, :-1, :]
    Aval = A[:, 1:Tmax, :]
    sub_mask = np.triu(
        np.ones((Aval.shape[0], Aval.shape[1], Aval.shape[1])),
        k=1).astype('uint8')
    Zm = (torch.from_numpy(sub_mask) == 0).to(Am.device)
    res = self.attn(self.attn_lin(torch.cat([Xt[:, :-1, :], Aval], -1)),
                    Zinp, Zinp, mask=Zm, use_matmul=True)
    if self.hparams['include_baseline']:
        Acat = torch.cat([
            Aval[..., [0]],
            B[:, None, :].repeat(1, Aval.shape[1], 1),
            Aval[..., 1:]
        ], -1)
        mu2T, sig2T = self.transition_fxn(res, Acat, eps=eps)
    else:
        mu2T, sig2T = self.transition_fxn(res, A[:, 1:Tmax, :], eps=eps)
    mu, sig = torch.cat([mu1, mu2T], 1), torch.cat([sig1, sig2T], 1)
    return Independent(Normal(mu, sig), 1)