def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)
    marginalised_guide_block = poutine.block(
        model_full, expose_all=True, hide_all=False, hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    full_rank_guide_block = poutine.block(model_full, hide_all=True, expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(full_rank_guide_block,
                                   init_loc_fn=autoguide.init_to_value(values=init_loc),
                                   init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))
    return guide

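# A minimal usage sketch for the AutoMixed factory above (illustrative, not from the
# original source). It assumes the names used by AutoMixed (AutoGuideList, AutoNormal,
# AutoDelta, AutoMultivariateNormal, poutine, autoguide) are already imported; the toy
# model, data and learning rate below are made up for demonstration.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(100)


def toy_model():
    # 'tau' matches the site that AutoMixed handles with the full-rank guide block.
    tau = pyro.sample("tau", dist.HalfNormal(1.0))
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, tau), obs=data)


guide = AutoMixed(toy_model, init_loc={"loc": torch.tensor(0.0)})
svi = SVI(toy_model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(100):
    svi.step()
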
def test_svi_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    guide = AutoNormal(
        handlers.enum(
            handlers.block(infer.config_enumerate(model),
                           expose=["loc", "scale"])))
    guide()  # Initialize, but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {
        name: site["value"]
        for name, site in guide_trace.nodes.items()
        if site["type"] == "sample"
    }

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), guide_trace)
            handlers.condition(infer.config_enumerate(model), guide_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs

def __init__(self, **kwargs):
    super().__init__()
    self._model = BayesianRegressionPyroModel(**kwargs)
    self._guide = AutoNormal(self.model,
                             init_loc_fn=init_to_mean,
                             create_plates=self.model.create_plates)
    self._get_fn_args_from_batch = self._model._get_fn_args_from_batch

def __init__(self, model, data, covariates, *,
             guide=None,
             init_loc_fn=init_to_sample,
             init_scale=0.1,
             create_plates=None,
             optim=None,
             learning_rate=0.01,
             betas=(0.9, 0.99),
             learning_rate_decay=0.1,
             clip_norm=10.0,
             dct_gradients=False,
             subsample_aware=False,
             num_steps=1001,
             num_particles=1,
             vectorize_particles=True,
             warm_start=False,
             log_every=100):
    assert data.size(-2) == covariates.size(-2)
    super().__init__()
    self.model = model
    if guide is None:
        guide = AutoNormal(self.model, init_loc_fn=init_loc_fn,
                           init_scale=init_scale, create_plates=create_plates)
    self.guide = guide

    # Initialize.
    if warm_start:
        model = PrefixWarmStartMessenger()(model)
        guide = PrefixWarmStartMessenger()(guide)
    if dct_gradients:
        model = MarkDCTParamMessenger("time")(model)
        guide = MarkDCTParamMessenger("time")(guide)
    elbo = Trace_ELBO(num_particles=num_particles,
                      vectorize_particles=vectorize_particles)
    elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
    elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1)  # force a time plate

    losses = []
    if num_steps:
        if optim is None:
            optim = DCTAdam({"lr": learning_rate,
                             "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps),
                             "clip_norm": clip_norm,
                             "subsample_aware": subsample_aware})
        svi = SVI(self.model, self.guide, optim, elbo)
        for step in range(num_steps):
            loss = svi.step(data, covariates) / data.numel()
            if log_every and step % log_every == 0:
                logger.info("step {: >4d} loss = {:0.6g}".format(step, loss))
                print("step {: >4d} loss = {:0.6g}".format(step, loss))
            losses.append(loss)

    self.guide.create_plates = None  # Disable subsampling after training.
    self.max_plate_nesting = elbo.max_plate_nesting
    self.losses = losses

def test_subsample_smoke(Reparam, subsample):
    def model():
        with poutine.reparam(config={"x": Reparam()}):
            with pyro.plate("plate", 10):
                return pyro.sample("x", dist.Stable(1.5, 0))

    def create_plates():
        return pyro.plate("plate", 10, subsample_size=3)

    guide = AutoNormal(model, create_plates=create_plates if subsample else None)
    Trace_ELBO().loss(model, guide)  # smoke test

def test_end_to_end(model):
    # Test training.
    model = AutoReparam()(model)
    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 1e-9}), Trace_ELBO())
    for step in range(3):
        svi.step()

    # Test prediction.
    predictive = Predictive(model, guide=guide, num_samples=2)
    samples = predictive()
    assert set("abc").issubset(samples.keys())

def __init__(self, **kwargs):
    super().__init__()
    self.hist = []
    self._model = LocationModelLinearDependentWMultiExperimentModel(**kwargs)
    self._guide = AutoNormal(
        self.model,
        init_loc_fn=init_to_mean,
        create_plates=self.model.create_plates,
    )

def test_subsample_guide(auto_class, init_fn):

    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        assert plate.size == 50
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])
        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)

def main(args):
    # Create a model, synthetic data, and a guide.
    pyro.set_rng_seed(args.seed)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)

    if args.horovod:
        # Initialize Horovod and set PyTorch globals.
        import horovod.torch as hvd
        hvd.init()
        torch.set_num_threads(1)
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    device = torch.tensor(0).device

    if args.horovod:
        # Initialize parameters and broadcast to all workers.
        guide(covariates[:1], data[:1])  # Initializes model and guide.
        hvd.broadcast_parameters(guide.state_dict(), root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Create an ELBO loss and a Pyro optimizer.
    elbo = Trace_ELBO()
    optim = Adam({"lr": args.learning_rate})

    if args.horovod:
        # Wrap the basic optimizer in a distributed optimizer.
        optim = HorovodOptimizer(optim)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    if args.horovod:
        # Horovod requires a distributed sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, hvd.size(), hvd.rank())
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    config = {"batch_size": args.batch_size, "sampler": sampler}
    if args.cuda:
        config["num_workers"] = 1
        config["pin_memory"] = True
        # Try to use forkserver to spawn workers instead of fork.
        if (hasattr(mp, "_supports_context") and mp._supports_context and
                "forkserver" in mp.get_all_start_methods()):
            config["multiprocessing_context"] = "forkserver"
    dataloader = torch.utils.data.DataLoader(dataset, **config)

    # Run stochastic variational inference.
    svi = SVI(model, guide, optim, elbo)
    for epoch in range(args.num_epochs):
        if args.horovod:
            # Set rng seeds on distributed samplers. This is required.
            sampler.set_epoch(epoch)

        for step, (covariates_batch, data_batch) in enumerate(dataloader):
            loss = svi.step(covariates_batch.to(device), data_batch.to(device))

            if args.horovod:
                # Optionally average loss metric across workers.
                # You can do this with arbitrary torch.Tensors.
                loss = torch.tensor(loss)
                loss = hvd.allreduce(loss, "loss")
                loss = loss.item()

                # Print only on the rank=0 worker.
                if step % 100 == 0 and hvd.rank() == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))
            else:
                if step % 100 == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))

    if args.horovod:
        # After we're done with the distributed parts of the program,
        # we can shut down all but the rank=0 worker.
        hvd.shutdown()
        if hvd.rank() != 0:
            return

    if args.outfile:
        print("saving to {}".format(args.outfile))
        torch.save({"model": model, "guide": guide}, args.outfile)

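# A possible command-line entry point matching the flags main() reads above
# (args.seed, args.size, args.cuda, args.horovod, args.batch_size,
# args.learning_rate, args.num_epochs, args.outfile). The default values are
# illustrative and not taken from the original script; Model is assumed to be
# defined elsewhere in the same file.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Distributed SVI with an AutoNormal guide")
    parser.add_argument("--size", default=1000, type=int)
    parser.add_argument("--batch_size", default=100, type=int)
    parser.add_argument("--num_epochs", default=10, type=int)
    parser.add_argument("--learning_rate", default=0.01, type=float)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--horovod", action="store_true")
    parser.add_argument("--outfile", default="")
    args = parser.parse_args()
    main(args)
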
def fit_advi_iterative(self, n=3, method='advi', n_type='restart',
                       n_iter=None, learning_rate=None, progressbar=True,
                       num_workers=2, train_proportion=None, stratify_cv=None,
                       l2_weight=False, sample_scaling_weight=0.5,
                       checkpoints=None, checkpoint_dir='./checkpoints',
                       tracking=False):
    r"""Train posterior using ADVI method
    (maximising likelihood of the data and minimising KL-divergence of posterior to prior).

    :param n: number of independent initialisations
    :param method: to allow for potential use of SVGD or MCMC (currently only ADVI implemented).
    :param n_type: type of repeated initialisation:

        * 'restart' to pick different initial value,
        * 'cv' for molecular cross-validation - splits counts into n datasets,
          for now, only n=2 is implemented
        * 'bootstrap' for fitting the model to multiple downsampled datasets.
          Run `mod.bootstrap_data()` to generate variants of data
    :param n_iter: number of iterations, supersedes self.n_iter
    :param train_proportion: if not None, which proportion of cells to use for training
        and which for validation.
    :param checkpoints: int, list of int's or None, number of checkpoints to save
        while model training, or list of iterations to save checkpoints on
    :param checkpoint_dir: str, directory to save checkpoints in
    :param tracking: bool, track all latent variables during training
        - if True makes training 2 times slower
    :return: None
    """

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    if tracking:
        self.logp_hist = {}

    if n_iter is None:
        n_iter = self.n_iter

    if type(checkpoints) is int:
        if n_iter < checkpoints:
            checkpoints = n_iter
        checkpoints = np.linspace(0, n_iter, checkpoints + 1, dtype=int)[1:]
        self.checkpoints = list(checkpoints)
    else:
        self.checkpoints = checkpoints
    self.checkpoint_dir = checkpoint_dir

    self.n_type = n_type
    self.l2_weight = l2_weight
    self.sample_scaling_weight = sample_scaling_weight
    self.train_proportion = train_proportion

    if stratify_cv is not None:
        self.stratify_cv = stratify_cv

    if train_proportion is not None:
        self.validation_hist = {}
        self.training_hist = {}
        if tracking:
            self.logp_hist_val = {}
            self.logp_hist_train = {}

    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):

        ################### Initialise parameters & optimiser ###################
        # initialise Variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            normal_guide_block = poutine.block(
                self.model, expose_all=True, hide_all=False,
                hide=self.point_estim + flatten_iterable(self.custom_guides.keys()))
            self.guide_i[name].append(
                AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(poutine.block(self.model, hide_all=True,
                                        expose=self.point_estim)))
            for k, v in self.custom_guides.items():
                self.guide_i[name].append(v)
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model, self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())

        pyro.clear_param_store()
        self.set_initial_values()

        # record ELBO Loss history here
        self.hist[name] = []
        if tracking:
            self.logp_hist[name] = defaultdict(list)
        if train_proportion is not None:
            self.validation_hist[name] = []
            if tracking:
                self.logp_hist_val[name] = defaultdict(list)

        ################### Select data for this iteration ###################
        if np.isin(n_type, ['cv', 'bootstrap']):
            X_data = self.X_data_sample[i].astype(self.data_type)
        else:
            X_data = self.X_data.astype(self.data_type)

        ################### Training / validation split ###################
        # split into training and validation
        if train_proportion is not None:
            idx = np.arange(len(X_data))
            train_idx, val_idx = train_test_split(
                idx, train_size=train_proportion,
                shuffle=True, stratify=self.stratify_cv)

            extra_data_val = {
                k: torch.FloatTensor(v[val_idx]).to(self.device)
                for k, v in self.extra_data.items()
            }
            extra_data_train = {
                k: torch.FloatTensor(v[train_idx])
                for k, v in self.extra_data.items()
            }

            x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
            x_data = torch.FloatTensor(X_data[train_idx])
        else:
            # just convert data to CPU tensors
            x_data = torch.FloatTensor(X_data)
            extra_data_train = {
                k: torch.FloatTensor(v)
                for k, v in self.extra_data.items()
            }

        ################### Move data to cuda - FULL data ###################
        # if not minibatch do this:
        if self.minibatch_size is None:
            # move tensors to CUDA
            x_data = x_data.to(self.device)
            for k in extra_data_train.keys():
                extra_data_train[k] = extra_data_train[k].to(self.device)
            # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()}
        ################### MINIBATCH data ###################
        else:
            # create minibatches
            dataset = MiniBatchDataset(x_data, extra_data_train, return_idx=True)
            loader = DataLoader(dataset, batch_size=self.minibatch_size,
                                num_workers=0)  # TODO num_workers

        ################### Training the model ###################
        # start training in epochs
        epochs_iterator = tqdm(range(n_iter))
        for epoch in epochs_iterator:

            if self.minibatch_size is None:
                ################### Training FULL data ###################
                iter_loss = self.step_train(name, x_data, extra_data_train)
                self.hist[name].append(iter_loss)

                # save data for posterior sampling
                self.x_data = x_data
                self.extra_data_train = extra_data_train

                if tracking:
                    guide_tr, model_tr = self.step_trace(name, x_data, extra_data_train)
                    self.logp_hist[name]['guide'].append(guide_tr.log_prob_sum().item())
                    self.logp_hist[name]['model'].append(model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist[name][k].append(v["log_prob_sum"].item())

            else:
                ################### Training MINIBATCH data ###################
                aver_loss = []
                if tracking:
                    aver_logp_guide = []
                    aver_logp_model = []
                    aver_logp = defaultdict(list)

                for batch in loader:
                    x_data_batch, extra_data_batch = batch
                    x_data_batch = x_data_batch.to(self.device)
                    extra_data_batch = {
                        k: v.to(self.device) for k, v in extra_data_batch.items()
                    }

                    loss = self.step_train(name, x_data_batch, extra_data_batch)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(name, x_data_batch, extra_data_batch)
                        aver_logp_guide.append(guide_tr.log_prob_sum().item())
                        aver_logp_model.append(model_tr.log_prob_sum().item())
                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                aver_logp[k].append(v["log_prob_sum"].item())

                    aver_loss.append(loss)

                iter_loss = np.sum(aver_loss)

                # save data for posterior sampling
                self.x_data = x_data_batch
                self.extra_data_train = extra_data_batch

                self.hist[name].append(iter_loss)

                if tracking:
                    iter_logp_guide = np.sum(aver_logp_guide)
                    iter_logp_model = np.sum(aver_logp_model)
                    self.logp_hist[name]['guide'].append(iter_logp_guide)
                    self.logp_hist[name]['model'].append(iter_logp_model)
                    for k, v in aver_logp.items():
                        self.logp_hist[name][k].append(np.sum(v))

            if self.checkpoints is not None:
                if (epoch + 1) in self.checkpoints:
                    self.save_checkpoint(epoch + 1, prefix=name)

            ################### Evaluating cross-validation loss ###################
            if train_proportion is not None:
                iter_loss_val = self.step_eval_loss(name, x_data_val, extra_data_val)

                if tracking:
                    guide_tr, model_tr = self.step_trace(name, x_data_val, extra_data_val)
                    self.logp_hist_val[name]['guide'].append(guide_tr.log_prob_sum().item())
                    self.logp_hist_val[name]['model'].append(model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist_val[name][k].append(v["log_prob_sum"].item())

                self.validation_hist[name].append(iter_loss_val)
                epochs_iterator.set_description('ELBO Loss: ' + '{:.4e}'.format(iter_loss)
                                                + ': Val loss: ' + '{:.4e}'.format(iter_loss_val))
            else:
                epochs_iterator.set_description('ELBO Loss: ' + '{:.4e}'.format(iter_loss))

            if epoch % 20 == 0:
                torch.cuda.empty_cache()

        if train_proportion is not None:
            # rescale loss
            self.validation_hist[name] = [
                i / (1 - train_proportion) for i in self.validation_hist[name]
            ]
            self.hist[name] = [
                i / train_proportion for i in self.hist[name]
            ]

            # reassign the main loss to be displayed
            self.training_hist[name] = self.hist[name]
            self.hist[name] = self.validation_hist[name]

            if tracking:
                for k, v in self.logp_hist[name].items():
                    self.logp_hist[name][k] = [
                        i / train_proportion for i in self.logp_hist[name][k]
                    ]
                    self.logp_hist_val[name][k] = [
                        i / (1 - train_proportion) for i in self.logp_hist_val[name][k]
                    ]

                self.logp_hist_train[name] = self.logp_hist[name]
                self.logp_hist[name] = self.logp_hist_val[name]

        if self.verbose:
            print(plt.plot(np.log10(self.hist[name][0:])))

def fit_advi_iterative_simple(
    self,
    n: int = 3,
    method='advi',
    n_type='restart',
    n_iter=None,
    learning_rate=None,
    progressbar=True,
):
    r"""Find posterior using ADVI (deprecated)
    (maximising likelihood of the data and minimising KL-divergence of posterior to prior).

    :param n: number of independent initialisations
    :param method: which approximation of the posterior (guide) to use?

        * ``'advi'`` - Univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal)
        * ``'custom'`` - Custom guide using conjugate posteriors
    :return: self.svi dictionary with svi pyro objects for each n,
        and self.elbo dictionary storing training history.
    """

    # Pass data to pyro / pytorch
    self.x_data = torch.tensor(self.X_data.astype(self.data_type))  # .double()

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    self.n_type = n_type

    if n_iter is None:
        n_iter = self.n_iter
    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):

        # initialise Variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            self.guide_i[name].append(
                AutoNormal(poutine.block(self.model, expose_all=True,
                                         hide_all=False, hide=self.point_estim),
                           init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(poutine.block(self.model, hide_all=True,
                                        expose=self.point_estim)))
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model, self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())

        pyro.clear_param_store()

        # record ELBO Loss history here
        self.hist[name] = []

        # pick dataset depending on the training mode and move to GPU
        if np.isin(n_type, ['cv', 'bootstrap']):
            self.x_data = torch.tensor(self.X_data_sample[i].astype(self.data_type))
        else:
            self.x_data = torch.tensor(self.X_data.astype(self.data_type))

        if self.use_cuda:
            # move tensors and modules to CUDA
            self.x_data = self.x_data.cuda()

        # train for n_iter
        it_iterator = tqdm(range(n_iter))
        for it in it_iterator:
            hist = self.svi[name].step(self.x_data)
            it_iterator.set_description('ELBO Loss: ' + str(np.round(hist, 3)))
            self.hist[name].append(hist)

            # if it % 50 == 0 & self.verbose:
            #     logging.info("Elbo loss: {}".format(hist))
            if it % 500 == 0:
                torch.cuda.empty_cache()

def fit_svi(self, *,
            num_samples=100,
            num_steps=2000,
            num_particles=32,
            learning_rate=0.1,
            learning_rate_decay=0.01,
            betas=(0.8, 0.99),
            haar=True,
            init_scale=0.01,
            guide_rank=0,
            jit=False,
            log_every=200,
            **options):
    """
    Runs stochastic variational inference to generate posterior samples.

    This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
    attribute on completion.

    This approximate inference method is useful for quickly iterating on
    probabilistic models.

    :param int num_samples: Number of posterior samples to draw from the
        trained guide. Defaults to 100.
    :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
    :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
        particles per step.
    :param float learning_rate: Learning rate for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param float learning_rate_decay: Learning rate decay for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
        is decay over the entire schedule, not per-step decay.
    :param tuple betas: Momentum parameters for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
    :param int guide_rank: Rank of the auto normal guide. If zero (default)
        use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a positive
        integer or None, use an
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        If the string "full", use an
        :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
        latter two require more ``num_steps`` to fit.
    :param float init_scale: Initial scale of the
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
    :param bool jit: Whether to use a jit compiled ELBO.
    :param int log_every: How often to log svi losses.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024.
    :returns: Time series of SVI losses (useful to diagnose convergence).
    :rtype: list
    """
    # Save configuration for .predict().
    self.relaxed = True
    self.num_quant_bins = 1

    # Setup Haar wavelet transform.
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(0, self.duration, dims, supports)

    # Heuristically initialize to feasible latents.
    heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                         for k in list(options)
                         if k.startswith("heuristic_")}
    assert not options, "unrecognized options: {}".format(", ".join(options))
    init_strategy = self._heuristic(haar, **heuristic_options)

    # Configure variational inference.
    logger.info("Running inference...")
    model = self._relaxed_model
    if haar:
        model = haar.reparam(model)
    if guide_rank == 0:
        guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
    elif guide_rank == "full":
        guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                       init_scale=init_scale)
    elif guide_rank is None or isinstance(guide_rank, int):
        guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                              init_scale=init_scale, rank=guide_rank)
    else:
        raise ValueError("Invalid guide_rank: {}".format(guide_rank))
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                num_particles=num_particles, vectorize_particles=True,
                ignore_jit_warnings=True)
    optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                         "lrd": learning_rate_decay ** (1 / num_steps)})
    svi = SVI(model, guide, optim, elbo)

    # Run inference.
    start_time = default_timer()
    losses = []
    for step in range(1 + num_steps):
        loss = svi.step() / self.duration
        if step % log_every == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
        losses.append(loss)
    elapsed = default_timer() - start_time
    logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                .format(elapsed, (1 + num_steps) / elapsed))

    # Draw posterior samples.
    with torch.no_grad():
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)
        guide_trace = poutine.trace(particle_plate(guide)).get_trace()
        model_trace = poutine.trace(
            poutine.replay(particle_plate(model), guide_trace)).get_trace()
        self.samples = {name: site["value"]
                        for name, site in model_trace.nodes.items()
                        if site["type"] == "sample"
                        if not site["is_observed"]
                        if not site_is_subsample(site)}
        if haar:
            haar.aux_to_user(self.samples)
    assert all(v.size(0) == num_samples for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}

    return losses

def _create_autoguide(
    self,
    model,
    amortised,
    encoder_kwargs,
    data_transform,
    encoder_mode,
    init_loc_fn=init_to_mean,
    n_cat_list: list = [],
    encoder_instance=None,
):
    if not amortised:
        _guide = AutoNormal(
            model,
            init_loc_fn=init_loc_fn,
            create_plates=model.create_plates,
        )
    else:
        encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
        n_hidden = encoder_kwargs["n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200
        init_param_scale = (
            encoder_kwargs["init_param_scale"]
            if "init_param_scale" in encoder_kwargs.keys()
            else 1 / 50
        )
        if "init_param_scale" in encoder_kwargs.keys():
            del encoder_kwargs["init_param_scale"]
        amortised_vars = self.list_obs_plate_vars
        _guide = AutoGuideList(model, create_plates=model.create_plates)
        _guide.append(
            AutoNormal(
                pyro.poutine.block(model, hide=list(amortised_vars["sites"].keys())),
                init_loc_fn=init_loc_fn,
            )
        )
        if isinstance(data_transform, np.ndarray):
            # add extra info about gene clusters to the network
            self.register_buffer("gene_clusters",
                                 torch.tensor(data_transform.astype("float32")))
            n_in = model.n_vars + data_transform.shape[1]
            data_transform = self.data_transform_clusters()
        elif data_transform == "log1p":
            # use simple log1p transform
            data_transform = torch.log1p
            n_in = self.model.n_vars
        elif (
            isinstance(data_transform, dict)
            and "var_std" in list(data_transform.keys())
            and "var_mean" in list(data_transform.keys())
        ):
            # use data transform by scaling
            n_in = model.n_vars
            self.register_buffer(
                "var_mean",
                torch.tensor(data_transform["var_mean"].astype("float32").reshape((1, n_in))),
            )
            self.register_buffer(
                "var_std",
                torch.tensor(data_transform["var_std"].astype("float32").reshape((1, n_in))),
            )
            data_transform = self.data_transform_scale()
        else:
            # use custom data transform
            data_transform = data_transform
            n_in = model.n_vars
        if len(amortised_vars["input"]) >= 2:
            encoder_kwargs["n_cat_list"] = n_cat_list
        amortised_vars["input_transform"][0] = data_transform
        _guide.append(
            AutoNormalEncoder(
                pyro.poutine.block(model, expose=list(amortised_vars["sites"].keys())),
                amortised_plate_sites=amortised_vars,
                n_in=n_in,
                n_hidden=n_hidden,
                init_param_scale=init_param_scale,
                encoder_kwargs=encoder_kwargs,
                encoder_mode=encoder_mode,
                encoder_instance=encoder_instance,
            )
        )
    return _guide

# Load full unnormalised data
# X_train, y_train = torch.load('/Users/ricard/test/pyro/files/MNIST/processed/training.pt')
# X_test, y_test = torch.load('/Users/ricard/test/pyro/files/MNIST/processed/test.pt')

# Load data using torch.utils.data.DataLoader
train_loader, test_loader = setup_data_loaders(batch_size=512, subset=True)

# clear param store
pyro.clear_param_store()

# set up the Factor Analysis model
fa = FA()

# Define guide automatically (mean-field normal approximation)
guide = AutoNormal(fa)

# Define guide manually
# guide = fa.guide

optim = Adam({"lr": LEARNING_RATE})
svi = SVI(fa.forward, guide=guide, optim=optim, loss=Trace_ELBO())

# training loop
train_elbo = []
test_elbo = []
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader)
    train_elbo.append(-total_epoch_loss_train)
    # print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))
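
# A possible held-out evaluation to populate test_elbo after training (illustrative,
# not from the original script). It uses SVI.evaluate_loss, which computes the ELBO
# loss without taking a gradient step; the batch unpacking below assumes the loader
# yields (data, label) pairs - adapt it to whatever setup_data_loaders actually returns.
total_test_loss = 0.0
for x, _ in test_loader:
    total_test_loss += svi.evaluate_loss(x)
total_epoch_loss_test = total_test_loss / len(test_loader.dataset)
test_elbo.append(-total_epoch_loss_test)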