def plate_custom_model(subsample):
    with pyro.plate('plate', 20, subsample=subsample) as batch:
        result = batch
    return result
def predict(args, data, samples, truth=None):
    logging.info("Forecasting {} steps ahead...".format(args.forecast))
    particle_plate = pyro.plate("particles", args.num_samples, dim=-1)

    # First we sample discrete auxiliary variables from the continuous
    # variables sampled in vectorized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter backward-sample
    # algorithm. We'll add these new samples to the existing dict of samples.
    model = poutine.condition(continuous_model, samples)
    model = particle_plate(model)
    model = infer_discrete(model, first_available_dim=-2)
    with poutine.trace() as tr:
        model(args, data)
    samples = OrderedDict((name, site["value"])
                          for name, site in tr.trace.nodes.items()
                          if site["type"] == "sample")

    # Next we'll run the forward generative process in discrete_model. This
    # samples time steps [duration:duration+forecast]. Again we'll update the
    # dict of samples.
    extended_data = list(data) + [None] * args.forecast
    model = poutine.condition(discrete_model, samples)
    model = particle_plate(model)
    with poutine.trace() as tr:
        model(args, extended_data)
    samples = OrderedDict((name, site["value"])
                          for name, site in tr.trace.nodes.items()
                          if site["type"] == "sample")

    # Finally we'll concatenate the sequentially sampled values into contiguous
    # tensors. This operates on the entire time interval [0:duration+forecast].
    for key in ("S", "I", "S2I", "I2R"):
        pattern = key + "_[0-9]+"
        series = [value
                  for name, value in samples.items()
                  if re.match(pattern, name)]
        assert len(series) == args.duration + args.forecast
        series[0] = series[0].expand(series[1].shape)
        samples[key] = torch.stack(series, dim=-1)
    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    logging.info("Median prediction of new infections (starting on day 0):\n{}"
                 .format(" ".join(map(str, map(int, median)))))

    # Optionally plot the latent and forecasted series of new infections.
    if args.plot:
        import matplotlib.pyplot as plt
        plt.figure()
        time = torch.arange(args.duration + args.forecast)
        p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
        p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
        plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
        plt.plot(time, median, "r-", label="median")
        plt.plot(time[:args.duration], data, "k.", label="observed")
        if truth is not None:
            plt.plot(time, truth, "k--", label="truth")
        plt.axvline(args.duration - 0.5, color="gray", lw=1)
        plt.xlim(0, len(time) - 1)
        plt.ylim(0, None)
        plt.xlabel("day after first infection")
        plt.ylabel("new infections per day")
        plt.title("New infections in population of {}".format(args.population))
        plt.legend(loc="upper left")
        plt.tight_layout()

    return samples
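# The condition -> plate -> infer_discrete -> trace pattern used in predict()
# can be exercised on a toy model. A minimal, self-contained sketch follows;
# toy_model and the observed value are illustrative assumptions, not part of
# the example above.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate, infer_discrete

def toy_model():
    z = pyro.sample("z", dist.Bernoulli(0.5))  # discrete latent
    pyro.sample("x", dist.Normal(z, 1.0))      # continuous observation

conditioned = poutine.condition(toy_model, {"x": torch.tensor(0.8)})
# Default temperature=1 gives forward-filter backward-sample, as above.
sampler = infer_discrete(config_enumerate(conditioned), first_available_dim=-1)
with poutine.trace() as tr:
    sampler()
print(tr.trace.nodes["z"]["value"])  # a posterior sample of the discrete latent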
def model(subsample):
    # `data` and `subsample_size` are taken from the enclosing scope.
    with pyro.plate("data", len(data), subsample_size, subsample) as ind:
        x = data[ind]
        z = pyro.sample("z", Normal(0, 1))
        pyro.sample("x", Normal(z, 1), obs=x)
def guide(data):
    data = torch.reshape(data, [60000, 50, 50])

    pyro.module('rnn', rnn)
    pyro.module('bl_rnn', bl_rnn)
    pyro.module('predict_l1', predict_l1)
    pyro.module('predict_l2', predict_l2)
    pyro.module('encode_l1', encode_l1)
    pyro.module('encode_l2', encode_l2)
    pyro.module('bl_predict_l1', bl_predict_l1)
    pyro.module('bl_predict_l2', bl_predict_l2)

    pyro.param('h_init', h_init)
    pyro.param('c_init', c_init)
    pyro.param('z_where_init', z_where_init)
    pyro.param('z_what_init', z_what_init)
    pyro.param('bl_h_init', bl_h_init)
    pyro.param('bl_c_init', bl_c_init)

    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]
        flattened_batch = torch.Tensor.view(batch, [64, 2500])

        # inputs_raw = batch
        # inputs_embed = flattened_batch
        # inputs_bl_embed = flattened_batch

        state_h = torch.Tensor.expand(h_init, [64, 256])
        state_c = torch.Tensor.expand(c_init, [64, 256])
        state_bl_h = torch.Tensor.expand(bl_h_init, [64, 256])
        state_bl_c = torch.Tensor.expand(bl_c_init, [64, 256])
        state_z_pres = torch.ones(64, 1)
        state_z_where = torch.Tensor.expand(z_where_init, [64, 3])
        state_z_what = torch.Tensor.expand(z_what_init, [64, 50])

        z_pres = []
        z_where = []

        for t in range(3):
            # =========== guide_step
            # prev_h = state_h
            # prev_c = state_c
            # prev_bl_h = state_bl_h
            # prev_bl_c = state_bl_c
            # prev_z_pres = state_z_pres
            # prev_z_where = state_z_where
            # prev_z_what = state_z_what

            # size = [64, 2554]
            rnn_input = torch.cat(
                (flattened_batch, state_z_where, state_z_what, state_z_pres), 1)
            # size = [64, 256], [64, 256]
            state_h, state_c = rnn(rnn_input, (state_h, state_c))

            # ===== predict
            # size = [64, 7]
            out = predict_l2(F.relu(predict_l1(state_h)))
            # size = [64, 1]
            z_pres_p = torch.sigmoid(out[:, 0:1])
            # size = [64, 3]
            z_where_loc = out[:, 1:4]
            # size = [64, 3]
            z_where_scale = F.softplus(out[:, 4:])
            # ===== predict

            # ===== baseline_step
            # size = [64, 2554]
            rnn_input = torch.cat(
                (flattened_batch,
                 torch.Tensor.detach(state_z_where),
                 torch.Tensor.detach(state_z_what),
                 torch.Tensor.detach(state_z_pres)), 1)
            # size = [64, 256], [64, 256]
            state_bl_h, state_bl_c = bl_rnn(rnn_input, (state_bl_h, state_bl_c))

            # ===== bl_predict
            # size = [64, 1]
            bl_value = bl_predict_l2(F.relu(bl_predict_l1(state_bl_h)))
            # ===== bl_predict

            bl_value = bl_value * state_z_pres
            infer_dict = dict(baseline=dict(
                baseline_value=torch.squeeze(bl_value, -1)))
            # ===== baseline_step

            # size = [64, 1]
            cur_z_pres = \
                pyro.sample('z_pres_{}'.format(t),
                            Bernoulli(z_pres_p * state_z_pres).to_event(1),
                            infer=infer_dict)

            # sample_mask = cur_z_pres

            # size = [64, 3]
            cur_z_where = \
                pyro.sample('z_where_{}'.format(t),
                            Normal(z_where_loc + z_where_loc_prior,
                                   z_where_scale * z_where_scale_prior)
                            .mask(cur_z_pres)
                            .to_event(1))

            # ===== image_to_window
            # images = batch

            # ===== z_where_inv
            # size = [64, 3]
            out = torch.cat((torch.ones(64, 1), -cur_z_where[:, 1:]), 1)
            out = out / cur_z_where[:, 0:1]
            cur_z_where_inv = out
            # ===== z_where_inv

            # ===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where_inv), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            out = torch.Tensor.view(out, [64, 2, 3])
            theta_inv = out
            # ===== expand_z_where

            # size = [64, 28, 28, 2]
            grid = F.affine_grid(theta_inv, [64, 1, 28, 28])
            # size = [64, 1, 28, 28]
            out = F.grid_sample(torch.Tensor.view(batch, [64, 1, 50, 50]), grid)
            x_att = torch.Tensor.view(out, [64, 784])
            # ===== image_to_window

            # ===== encode
            # size = [64, 100]
            a = encode_l2(F.relu(encode_l1(x_att)))
            # size = [64, 50]
            z_what_loc = a[:, 0:50]
            # size = [64, 50]
            z_what_scale = F.softplus(a[:, 50:])
            # ===== encode

            # size = [64, 50]
            cur_z_what = \
                pyro.sample('z_what_{}'.format(t),
                            Normal(z_what_loc, z_what_scale)
                            .mask(cur_z_pres)
                            .to_event(1))

            # state_h = h
            # state_c = c
            # state_bl_h = bl_h
            # state_bl_c = bl_c
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            state_z_what = cur_z_what
            # =========== guide_step

            z_where.append(state_z_where)
            z_pres.append(state_z_pres)

    return z_where, z_pres
def model(): with pyro.plate_stack("plates", shape[:dim]): with pyro.plate("particles", 10000): pyro.sample( "x", dist.Normal(loc, scale).expand(shape).to_event(-dim))
def guide(self, x, temp_id=None, anneal_id=None, anneal_t=None, anneal_dynamics=None):
    pyro.module('vdsm_seq', self)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    bs, seq_len, pixels = x.view(x.shape[0], x.shape[1],
                                 self.imsize ** 2 * self.nc).shape
    h_0_enc = self.h_0_enc.expand(2 * self.num_layers_rnn, bs,
                                  self.hid_dim).contiguous()
    c_0_enc = self.c_0_enc.expand(2 * self.num_layers_rnn, bs,
                                  self.hid_dim).contiguous()
    z_prev_ = self.z_q_0.expand(bs, 1, -1)  # z0
    dec_inp_0 = self.dec_inp_0.expand(bs, 1, -1)

    x = x.view(bs * seq_len, self.nc, self.imsize, self.imsize)
    pre_z, _, ID_loc, ID_scale = self.image_enc(x)
    ID_loc, ID_scale = self.id_layers(ID_loc, ID_scale)  # extra trainable layer
    ID_loc = torch.mean(ID_loc.view(bs, seq_len, -1), 1).unsqueeze(1)[:, 0]
    ID_scale = torch.mean(ID_scale.view(bs, seq_len, -1), 1).unsqueeze(1)[:, 0]
    pre_z = pre_z.view(bs, seq_len, -1)

    # from https://github.com/yatindandi/Disentangled-Sequential-Autoencoder/blob/master/model.py self.encode_f
    sequence = pre_z.permute(1, 0, 2)
    _, h, _, rnn_enc_raw, out = self.seq2seq_enc(sequence, h_0_enc, c_0_enc)
    h = h.permute(1, 0, 2).contiguous()
    h = h.view(bs, self.hid_dim * 2 * self.num_layers_rnn)
    d_params = self.cats(h)
    dz_loc = self.act(d_params[:, :self.dynamics_dim])
    dz_scale = self.softplus(d_params[:, self.dynamics_dim:])

    # infer dynamics and identity from data
    with pyro.plate('ID_plate', bs):
        IDdist = dist.Normal(ID_loc, ID_scale).to_event(1)
        dz_dist = dist.Normal(dz_loc, dz_scale).to_event(1)
        with poutine.scale(scale=anneal_id):
            ID = pyro.sample('ID', IDdist) * temp_id  # static factors
        with poutine.scale(scale=anneal_dynamics):
            dz = pyro.sample("dz", dz_dist)  # dynamics z

    h_dec = self.dz_to_dec_h(dz).view(-1, self.num_layers_rnn,
                                      self.hid_dim).permute(1, 0, 2)
    c_dec = self.dz_to_dec_c(dz).view(-1, self.num_layers_rnn,
                                      self.hid_dim).permute(1, 0, 2)

    for i in pyro.plate('batch_loop', bs):
        dec_inp = dec_inp_0[None, i].contiguous()
        z_prev = z_prev_[None, i].contiguous()
        h = h_dec[:, None, i]
        c = c_dec[:, None, i]
        dz_dec = dz[None, i, None, :]
        for t in pyro.markov(range(seq_len)):
            dec_inp, (h, c) = self.seq2seq_dec(dec_inp, (h, c))
            z_loc, z_scale = self.comb(z_prev, dec_inp, dz_dec)
            z_dist = dist.Normal(z_loc[0], z_scale[0]).to_event(1)
            with poutine.scale(scale=anneal_t):
                z = pyro.sample('z_{}_{}'.format(i, t), z_dist)
            z_prev = z.view(1, 1, -1)
def _fn(*args, **kwargs):
    with pyro.plate("num_particles_vectorized", num_samples,
                    dim=-max_plate_nesting):
        return fn(*args, **kwargs)
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.plate("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use
        # multiple samples from the guide at each z
        for t in pyro.markov(range(1, T_max + 1)):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of
            # transformations parameterized by self.iafs to the base distribution
            # defined in the previous line to yield a transformed distribution
            # that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 self.iafs)
                assert z_dist.event_shape == (self.z_q_0.size(0),)
                assert z_dist.batch_shape[-1:] == (len(mini_batch),)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape[-2:] == (len(mini_batch),
                                                   self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                if len(self.iafs) > 0:
                    # in output of normalizing flow, all dimensions are
                    # correlated (event shape is not empty)
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1]))
                else:
                    # when no normalizing flow used, ".to_event(1)" indicates
                    # latent dimensions are independent
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
                                      .to_event(1))

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
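# A guide like the DMM guide above is typically trained with SVI. A brief
# sketch, assuming a `dmm` instance and mini-batch tensors prepared as in the
# Pyro DMM example (the learning rate and annealing factor are illustrative):
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

svi = SVI(dmm.model, dmm.guide, ClippedAdam({"lr": 3e-4}), loss=Trace_ELBO())
loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                mini_batch_seq_lengths, annealing_factor=0.1)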
def model(): with pyro.plate("particles", 20000): return pyro.sample("x", dist.Stable(stability, skew))
def guide(self, mini_batch, mini_batch_reversed, annealing_factor=1.0):
    """
    The inference model q(z_{1:T} | y_{1:T})
    """
    # Number of time steps through mini-batch
    T_max = mini_batch.size(1)
    # Register all PyTorch modules with Pyro
    pyro.module("dkf", self)

    # Contiguous hidden state
    h_0 = self.h_0.expand(1, mini_batch.size(0),
                          self.rnn.hidden_size).contiguous()

    # Flatten and reverse input y
    batch_size = mini_batch_reversed.shape[0]
    seq_len = mini_batch_reversed.shape[1]
    flat_mini_batch_reversed = torch.zeros(batch_size, seq_len,
                                           self.rnn_input_dim).to(self.device)
    for t in range(seq_len):
        flat_mini_batch_reversed[:, t, :] = self.flatten(
            mini_batch_reversed[:, t, :, :, :])

    # Feed y through RNN;
    rnn_output, _ = self.rnn(flat_mini_batch_reversed, h_0)
    # Backwards to take future observations into account
    rnn_output = reversed_input(rnn_output)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # We enclose all the sample statements in the model in a plate
    # for conditional independence
    with pyro.plate("z_test", len(mini_batch)):
        # Sample the latents z
        for t in range(1, T_max + 1):
            # Mean and variance for the distribution q(z_t | z_{t-1}, y_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # If we are using normalizing flows, we apply the sequence of
            # transformations parameterized by self.iafs to the base distribution
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 self.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)

            # Sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                if len(self.iafs) > 0:
                    z_t = pyro.sample("z_%d" % t, z_dist)
                else:
                    # When no normalizing flow is used, ".to_event(1)"
                    # indicates latent dimensions are independent
                    z_t = pyro.sample("z_%d" % t, z_dist.to_event(1))

            # Update time step
            z_prev = z_t

    return z_t
def forward(self, data):
    loc, log_scale = self.z.unbind(-1)
    with pyro.plate("data"):
        pyro.sample("obs", dist.Cauchy(loc, log_scale.exp()), obs=data)
def iplate_cuda_model(subsample_size):
    loc = torch.zeros(20).cuda()
    scale = torch.ones(20).cuda()
    for i in pyro.plate("data", 20, subsample_size, device=loc.device):
        pyro.sample("x_{}".format(i), dist.Normal(loc[i], scale[i]))
def plate_cuda_model(subsample_size):
    loc = torch.zeros(20).cuda()
    scale = torch.ones(20).cuda()
    with pyro.plate("data", 20, subsample_size, device=loc.device) as batch:
        pyro.sample("x", dist.Normal(loc[batch], scale[batch]))
def iplate_custom_model(subsample):
    result = []
    for i in pyro.plate('plate', 20, subsample=subsample):
        result.append(i)
    return result
def fit(self, model_name, model_param_names, data_input, fitter=None,
        init_values=None):
    # verbose is passed through from orbit.models.base_estimator
    verbose = self.verbose
    message = self.message
    learning_rate = self.learning_rate
    learning_rate_total_decay = self.learning_rate_total_decay
    num_sample = self.num_sample
    seed = self.seed
    num_steps = self.num_steps

    pyro.set_rng_seed(seed)
    if fitter is None:
        fitter = get_pyro_model(model_name)  # abstract
    model = fitter(data_input)  # concrete

    # Perform stochastic variational inference using an auto guide.
    pyro.clear_param_store()
    guide = AutoLowRankMultivariateNormal(model)
    optim = ClippedAdam({
        "lr": learning_rate,
        "lrd": learning_rate_total_decay ** (1 / num_steps),
    })
    elbo = Trace_ELBO(num_particles=self.num_particles,
                      vectorize_particles=True)
    svi = SVI(model, guide, optim, elbo)

    for step in range(num_steps):
        loss = svi.step()
        if verbose and step % message == 0:
            scale_rms = guide._loc_scale()[1].detach().pow(2).mean().sqrt().item()
            print("step {: >4d} loss = {:0.5g}, scale = {:0.5g}".format(
                step, loss, scale_rms))

    # Extract samples.
    vectorize = pyro.plate("samples", num_sample,
                           dim=-1 - model.max_plate_nesting)
    with pyro.poutine.trace() as tr:
        samples = vectorize(guide)()
    with pyro.poutine.replay(trace=tr.trace):
        samples.update(vectorize(model)())

    # Convert from torch.Tensors to numpy.ndarrays.
    extract = {
        name: value.detach().squeeze().numpy()
        for name, value in samples.items()
    }

    # Make sure the requested model param names are a subset of the extracted keys.
    invalid_model_param = set(model_param_names) - set(list(extract.keys()))
    if invalid_model_param:
        raise EstimatorException(
            "Pyro model definition does not contain required parameters")

    # The trace above returns all defined parameters; filter out unnecessary keys.
    extract = {param: extract[param] for param in model_param_names}
    return extract
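# The trace/replay block above is roughly the pattern that pyro.infer.Predictive
# packages up; an equivalent sketch for drawing vectorized posterior samples
# (an assumption for illustration, not the method used by the code above):
from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=num_sample)
posterior_samples = predictive()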
def model(): with pyro.plate("plate", 10): with poutine.reparam(config={"x": Reparam()}): return pyro.sample("x", dist.Stable(1.5, 0))
def spire_model(priors, sub=1):
    if len(priors) != 3:
        raise ValueError

    band_plate = pyro.plate('bands', len(priors), dim=-2)
    src_plate = pyro.plate('nsrc', priors[0].nsrc, dim=-1)
    psw_plate = pyro.plate('psw_pixels', priors[0].sim.size, dim=-3,
                           subsample_size=np.rint(sub * priors[0].sim.size).astype(int))
    pmw_plate = pyro.plate('pmw_pixels', priors[1].sim.size, dim=-3,
                           subsample_size=np.rint(sub * priors[1].sim.size).astype(int))
    plw_plate = pyro.plate('plw_pixels', priors[2].sim.size, dim=-3,
                           subsample_size=np.rint(sub * priors[2].sim.size).astype(int))

    pointing_matrices = [
        torch.sparse.FloatTensor(torch.LongTensor([p.amat_row, p.amat_col]),
                                 torch.Tensor(p.amat_data),
                                 torch.Size([p.snpix, p.nsrc]))
        for p in priors
    ]

    bkg_prior = torch.tensor([p.bkg[0] for p in priors])
    bkg_prior_sig = torch.tensor([p.bkg[1] for p in priors])
    nsrc = priors[0].nsrc
    f_low_lim = torch.tensor([p.prior_flux_lower for p in priors],
                             dtype=torch.float)
    f_up_lim = torch.tensor([p.prior_flux_upper for p in priors],
                            dtype=torch.float)

    with band_plate as ind_band:
        # HalfCauchy takes a single scale parameter.
        sigma_conf = pyro.sample(
            'sigma_conf',
            dist.HalfCauchy(torch.tensor([0.5])).expand([1]).to_event(1)).squeeze(-1)
        bkg = pyro.sample(
            'bkg',
            dist.Normal(-5, 0.5).expand([1]).to_event(1)).squeeze(-1)
        with src_plate as ind_src:
            src_f = pyro.sample(
                'src_f',
                dist.Uniform(0, 1).expand([1]).to_event(1)).squeeze(-1)

    f_vec = (f_up_lim - f_low_lim) * src_f + f_low_lim

    db_hat_psw = torch.sparse.mm(pointing_matrices[0],
                                 f_vec[0, ...].unsqueeze(-1)) + bkg[0]
    db_hat_pmw = torch.sparse.mm(pointing_matrices[1].to_dense(),
                                 f_vec[1, ...].unsqueeze(-1)) + bkg[1]
    db_hat_plw = torch.sparse.mm(pointing_matrices[2].to_dense(),
                                 f_vec[2, ...].unsqueeze(-1)) + bkg[2]

    sigma_tot_psw = torch.sqrt(torch.pow(torch.tensor(priors[0].snim), 2) +
                               torch.pow(sigma_conf[0], 2))
    sigma_tot_pmw = torch.sqrt(torch.pow(torch.tensor(priors[1].snim), 2) +
                               torch.pow(sigma_conf[1], 2))
    sigma_tot_plw = torch.sqrt(torch.pow(torch.tensor(priors[2].snim), 2) +
                               torch.pow(sigma_conf[2], 2))

    with psw_plate as ind_psw:
        psw_map = pyro.sample("obs_psw",
                              dist.Normal(db_hat_psw.squeeze()[ind_psw],
                                          sigma_tot_psw[ind_psw]),
                              obs=torch.tensor(priors[0].sim[ind_psw]))
    with pmw_plate as ind_pmw:
        pmw_map = pyro.sample("obs_pmw",
                              dist.Normal(db_hat_pmw.squeeze()[ind_pmw],
                                          sigma_tot_pmw[ind_pmw]),
                              obs=torch.tensor(priors[1].sim[ind_pmw]))
    with plw_plate as ind_plw:
        plw_map = pyro.sample("obs_plw",
                              dist.Normal(db_hat_plw.squeeze()[ind_plw],
                                          sigma_tot_plw[ind_plw]),
                              obs=torch.tensor(priors[2].sim[ind_plw]))

    return psw_map, pmw_map, plw_map
def create_plates():
    return pyro.plate("plate", 10, subsample_size=3)
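# A create_plates callback like the one above is meant to be handed to a Pyro
# autoguide so the guide's sample sites share the model's subsampled plate.
# A sketch, assuming `model` is a Pyro model with a matching "plate" site:
from pyro.infer.autoguide import AutoNormal

guide = AutoNormal(model, create_plates=create_plates)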
def main(args):
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[2, 0])
    ax5 = fig.add_subplot(gs[3, 0])
    ax6 = fig.add_subplot(gs[1, 1])
    ax7 = fig.add_subplot(gs[2, 1])
    ax8 = fig.add_subplot(gs[3, 1])
    xlim = tuple(int(x) for x in args.x_lim.strip().split(','))
    ylim = tuple(int(x) for x in args.y_lim.strip().split(','))
    assert len(xlim) == 2
    assert len(ylim) == 2

    # 1. Plot samples drawn from BananaShaped distribution
    x1, x2 = torch.meshgrid([torch.linspace(*xlim, 100),
                             torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    ax1.contourf(x1, x2, p, cmap='OrRd')
    ax1.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    ax2.contourf(x1, x2, p, cmap='OrRd')
    ax2.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(vanilla HMC)')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()
    ax3.contourf(x1, x2, p, cmap='OrRd')
    ax3.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(DiagNormal autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3)

    # 3(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4)
    ax4.set(xlabel='x0', ylabel='x1',
            title='Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')
    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax5.contourf(x1, x2, p, cmap='OrRd')
    ax5.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior (transformed) \n(DiagNormal + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5)

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows,
                                               block_autoregressive))
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()
    ax6.contourf(x1, x2, p, cmap='OrRd')
    ax6.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior \n(BNAF autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6)

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7)
    ax7.set(xlabel='x0', ylabel='x1',
            title='Posterior (warped) samples \n(BNAF + NeuTra HMC)')
    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax8.contourf(x1, x2, p, cmap='OrRd')
    ax8.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim,
            title='Posterior (transformed) \n(BNAF + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
def model(): with pyro.plate_stack("plates", shape): with pyro.plate("particles", 200000): return pyro.sample("x", dist.Stable(stability, 0, scale, loc))
def guide(self, data): pyro.module("encoder", self.encoder) x_observation = data[0] with pyro.plate("data", x_observation.shape[0]): z_loc, z_scale = self.encoder.forward(x_observation) pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
def sample(self, n_samples=1):
    with pyro.plate('observations', n_samples):
        samples = self.model()
    return (*samples,)
def model(data):
    data = torch.reshape(data, [60000, 50, 50])

    pyro.module("decode_l1", decode_l1)
    pyro.module("decode_l2", decode_l2)

    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]

        # ================= prior
        state_x = torch.zeros([64, 50, 50])
        state_z_pres = torch.ones([64, 1])
        state_z_where = None

        z_pres = []
        z_where = []

        for t in range(3):
            # ==================== prior_step
            # size = [64, 50, 50]
            prev_x = state_x
            # size = [64, 1]
            prev_z_pres = state_z_pres
            # size = None or [64, 3]
            prev_z_where = state_z_where

            # size = [64, 1]
            cur_z_pres = \
                pyro.sample('z_pres_{}'.format(t),
                            Bernoulli(trial_probs[t] * prev_z_pres)
                            .to_event(1))
            sample_mask = cur_z_pres

            # size = [64, 3]
            cur_z_where = \
                pyro.sample('z_where_{}'.format(t),
                            Normal(torch.Tensor.expand(z_where_loc_prior, [64, 3]),
                                   torch.Tensor.expand(z_where_scale_prior, [64, 3]))
                            .mask(sample_mask)
                            .to_event(1))

            # size = [64, 50]
            cur_z_what = \
                pyro.sample('z_what_{}'.format(t),
                            Normal(torch.zeros([64, 50]),
                                   torch.ones([64, 50]))
                            .mask(sample_mask)
                            .to_event(1))

            # ===== decode
            # size = [64, 784]
            y_att = torch.sigmoid(decode_l2(F.relu(decode_l1(cur_z_what))) - 2.0)
            # ===== decode

            # ===== window_to_image
            windows = y_att

            # ===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            # size = [64, 2, 3]
            out = torch.Tensor.view(out, [64, 2, 3])
            theta = out
            # ===== expand_z_where

            # size = [64, 50, 50, 2]
            grid = F.affine_grid(theta, [64, 1, 50, 50])
            # size = [64, 1, 50, 50]
            out = F.grid_sample(torch.Tensor.view(windows, [64, 1, 28, 28]), grid)
            y = torch.Tensor.view(out, [64, 50, 50])
            # ===== window_to_image

            # size = [64, 50, 50]
            cur_x = prev_x + (y * torch.Tensor.view(cur_z_pres, [64, 1, 1]))

            state_x = cur_x
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            # ==================== prior_step

            z_where.append(state_z_where)
            z_pres.append(state_z_pres)

        # size = [64, 50, 50]
        x = state_x
        # ================== prior

        pyro.sample('obs',
                    Normal(torch.Tensor.view(x, [64, 2500]),
                           0.3 * torch.ones(64, 2500)).to_event(1),
                    obs=torch.Tensor.view(batch, [64, 2500]))
def sample_scm(self, n_samples=1):
    with pyro.plate('observations', n_samples):
        samples = self.scm()
    return (*samples,)
def model(): with pyro.plate_stack("plates", shape[:dim]): with pyro.plate("particles", 10000): pyro.sample("x", dist.Uniform(0, 1).expand(shape).to_event(-dim))
def sample_pgm(num_samples):
    with pyro.plate('observations', num_samples):
        return self.pyro_model.pgm_model()
def model(data):
    y_prob = pyro.sample("y_prob", dist.Beta(1., 1.))
    with pyro.plate("data", data.shape[0]):
        y = pyro.sample("y", dist.Bernoulli(y_prob))
        z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
        pyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)
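# The Bernoulli latents y and z above can be enumerated out rather than
# sampled. A sketch of the standard setup; the guide covering only the
# continuous site y_prob and the learning rate are illustrative assumptions:
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

guide = AutoNormal(poutine.block(model, hide=["y", "z"]))
svi = SVI(config_enumerate(model), guide, Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=1))
loss = svi.step(data)  # data: 1-D tensor of observations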
def model():
    lambda_latent = pyro.sample("lambda_latent", Gamma(alpha0, beta0))
    with pyro.plate("data", n_data):
        pyro.sample("obs", dist.Poisson(lambda_latent), obs=data)
    return lambda_latent
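# The Gamma prior above is conjugate to the Poisson likelihood. A minimal
# mean-field guide sketch; alpha_q, beta_q, and their initial values are
# illustrative names, not part of the example above:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def guide():
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    pyro.sample("lambda_latent", dist.Gamma(alpha_q, beta_q))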
def model(): with pyro.plate("data", len(data), subsample_size) as ind: x = data[ind] z = pyro.sample("z", Normal(0, 1).expand_by(x.shape)) pyro.sample("x", Normal(z, 1), obs=x)
def neals_funnel(dim=10):
    y = pyro.sample("y", dist.Normal(0, 3))
    with pyro.plate("D", dim):
        return pyro.sample("x", dist.Normal(0, torch.exp(y / 2)))
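# Neal's funnel is the standard motivating example for reparameterization:
# the scale of x depends on y, which frustrates HMC in the neck of the funnel.
# A sketch of decentering it with LocScaleReparam before running inference:
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

reparam_model = poutine.reparam(neals_funnel, config={"x": LocScaleReparam()})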