# Shared imports for the snippets below, consolidated here for convenience;
# each snippet originally lived in a separate file. Project-specific helpers
# referenced later (MiniBatchDataset, flatten_iterable, and the methods
# accessed via `self`) are assumed to be defined elsewhere in the code base.
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.distributions import constraints
from torch.utils.data import DataLoader
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro import optim, poutine
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO, TraceEnum_ELBO, autoguide
from pyro.infer.autoguide import (
    AutoCallable,
    AutoDelta,
    AutoDiagonalNormal,
    AutoDiscreteParallel,
    AutoGuide,
    AutoGuideList,
    AutoIAFNormal,
    AutoMultivariateNormal,
    AutoNormal,
    init_to_mean,
)
from pyro.nn import PyroParam
from pyro.optim import Adam


def AutoMixed(model_full, init_loc=None, delta=None):
    """Compose a guide: mean-field normal (or delta) factors over every site
    except 'tau', which gets a full-rank normal (or delta) factor."""
    if init_loc is None:  # avoid a mutable default argument
        init_loc = {}
    guide = AutoGuideList(model_full)

    # Every site except 'tau' (expose_all=True, hide_all=False are the
    # poutine.block defaults and are therefore omitted here).
    marginalised_guide_block = poutine.block(model_full, hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta in ('part', 'all'):
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    # Only 'tau'.
    full_rank_guide_block = poutine.block(model_full, hide_all=True, expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(full_rank_guide_block,
                                   init_loc_fn=autoguide.init_to_value(values=init_loc),
                                   init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))
    return guide
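# A minimal usage sketch for AutoMixed, assuming a toy model that exposes a
# site named 'tau' plus one other latent. The model, priors, and data below
# are illustrative stand-ins, not taken from the original code base.
def _toy_model():
    tau = pyro.sample("tau", dist.LogNormal(0., 1.))
    mu = pyro.sample("mu", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(mu, tau), obs=torch.tensor(0.5))


_guide = AutoMixed(_toy_model, init_loc={"tau": torch.tensor(1.)})
_svi = SVI(_toy_model, _guide, Adam({"lr": 0.01}), Trace_ELBO())
_svi.step()  # 'mu' is handled mean-field, 'tau' by the full-rank factor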
def test_guide_list(auto_class):
    def model():
        pyro.sample("x", dist.Normal(0., 1.).expand([2]))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    guide = AutoGuideList(model)
    guide.append(auto_class(poutine.block(model, expose=["x"])))
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    guide()
def nested_auto_guide_callable(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['x'])))
    # An AutoGuideList part can itself be an AutoGuideList, and parts can be
    # attached as attributes instead of via append().
    guide_y = AutoGuideList(poutine.block(model, expose=['y']))
    guide_y.z = AutoIAFNormal(poutine.block(model, expose=['y']))
    guide.append(guide_y)
    return guide
def auto_guide_module_callable(model):
    class GuideX(AutoGuide):
        def __init__(self, model):
            super().__init__(model)
            self.x_loc = nn.Parameter(torch.tensor(1.))
            self.x_scale = PyroParam(torch.tensor(.1), constraint=constraints.positive)

        def forward(self, *args, **kwargs):
            return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))}

        def median(self, *args, **kwargs):
            return {"x": self.x_loc.detach()}

    guide = AutoGuideList(model)
    guide.custom = GuideX(model)
    guide.diagnorm = AutoDiagonalNormal(poutine.block(model, hide=["x"]))
    return guide
def auto_guide_callable(model):
    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(1.))
        x_scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(x_loc, x_scale))

    def median_x():
        return {"x": pyro.param("x_loc", torch.tensor(1.))}

    guide = AutoGuideList(model)
    guide.append(AutoCallable(model, guide_x, median_x))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
def test_callable(auto_class):
    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(x_loc))

    guide = AutoGuideList(model)
    guide.append(guide_x)
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    values = guide()
    assert set(values) == set(["y"])
def test_subsample_guide(auto_class, init_fn):
    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        assert plate.size == 50
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])
        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))
        with pyro.plate('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.append(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=["x"])))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
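# Hedged sketch: AutoGuideList aggregates median() across its parts, so the
# composite guide above yields point estimates for both blocks. The two-site
# model here is an illustrative stand-in.
def _xy_model():
    pyro.sample("x", dist.Normal(0., 1.))
    pyro.sample("y", dist.Normal(0., 1.))


_g = auto_guide_list_x(_xy_model)
_g()                # run the guide once to instantiate its parameters
print(_g.median())  # {'x': tensor(...), 'y': tensor(...)}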
def fit_advi_iterative(self, n=3, method='advi', n_type='restart',
                       n_iter=None, learning_rate=None, progressbar=True,
                       num_workers=2, train_proportion=None, stratify_cv=None,
                       l2_weight=False, sample_scaling_weight=0.5,
                       checkpoints=None, checkpoint_dir='./checkpoints',
                       tracking=False):
    r"""Train the posterior using ADVI (maximising the likelihood of the data
    and minimising the KL-divergence of the posterior to the prior).

    :param n: number of independent initialisations
    :param method: allows for potential use of SVGD or MCMC (currently only ADVI is implemented)
    :param n_type: type of repeated initialisation:

        * ``'restart'`` - pick a different initial value,
        * ``'cv'`` - molecular cross-validation - splits counts into n datasets (for now, only n=2 is implemented),
        * ``'bootstrap'`` - fit the model to multiple downsampled datasets; run ``mod.bootstrap_data()`` to generate variants of the data
    :param n_iter: number of iterations, supersedes self.n_iter
    :param train_proportion: if not None, the proportion of cells to use for training (the rest for validation)
    :param checkpoints: int, list of ints, or None - the number of checkpoints to save during training, or a list of iterations to save checkpoints on
    :param checkpoint_dir: str, directory to save checkpoints in
    :param tracking: bool, track all latent variables during training - if True, makes training 2 times slower
    :return: None
    """

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    if tracking:
        self.logp_hist = {}

    if n_iter is None:
        n_iter = self.n_iter

    if type(checkpoints) is int:
        if n_iter < checkpoints:
            checkpoints = n_iter
        checkpoints = np.linspace(0, n_iter, checkpoints + 1, dtype=int)[1:]
        self.checkpoints = list(checkpoints)
    else:
        self.checkpoints = checkpoints

    self.checkpoint_dir = checkpoint_dir
    self.n_type = n_type
    self.l2_weight = l2_weight
    self.sample_scaling_weight = sample_scaling_weight
    self.train_proportion = train_proportion

    if stratify_cv is not None:
        self.stratify_cv = stratify_cv

    if train_proportion is not None:
        self.validation_hist = {}
        self.training_hist = {}
        if tracking:
            self.logp_hist_val = {}
            self.logp_hist_train = {}

    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):
        ################### Initialise parameters & optimiser ###################
        # initialise variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            normal_guide_block = poutine.block(
                self.model, expose_all=True, hide_all=False,
                hide=self.point_estim + flatten_iterable(self.custom_guides.keys()))
            self.guide_i[name].append(
                AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(poutine.block(self.model, hide_all=True,
                                        expose=self.point_estim)))
            for k, v in self.custom_guides.items():
                self.guide_i[name].append(v)
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model, self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())

        pyro.clear_param_store()
        self.set_initial_values()

        # record ELBO loss history here
        self.hist[name] = []
        if tracking:
            self.logp_hist[name] = defaultdict(list)
        if train_proportion is not None:
            self.validation_hist[name] = []
            if tracking:
                self.logp_hist_val[name] = defaultdict(list)

        ################### Select data for this iteration ###################
        if np.isin(n_type, ['cv', 'bootstrap']):
            X_data = self.X_data_sample[i].astype(self.data_type)
        else:
            X_data = self.X_data.astype(self.data_type)

        ################### Training / validation split ###################
        # split into training and validation
        if train_proportion is not None:
            idx = np.arange(len(X_data))
            train_idx, val_idx = train_test_split(
                idx, train_size=train_proportion,
                shuffle=True, stratify=self.stratify_cv)

            extra_data_val = {
                k: torch.FloatTensor(v[val_idx]).to(self.device)
                for k, v in self.extra_data.items()
            }
            extra_data_train = {
                k: torch.FloatTensor(v[train_idx])
                for k, v in self.extra_data.items()
            }

            x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
            x_data = torch.FloatTensor(X_data[train_idx])
        else:
            # just convert data to CPU tensors
            x_data = torch.FloatTensor(X_data)
            extra_data_train = {
                k: torch.FloatTensor(v)
                for k, v in self.extra_data.items()
            }

        ################### Move data to cuda - FULL data ###################
        # if not minibatch do this:
        if self.minibatch_size is None:
            # move tensors to CUDA
            x_data = x_data.to(self.device)
            for k in extra_data_train.keys():
                extra_data_train[k] = extra_data_train[k].to(self.device)

        ################### MINIBATCH data ###################
        else:
            # create minibatches
            dataset = MiniBatchDataset(x_data, extra_data_train, return_idx=True)
            loader = DataLoader(dataset, batch_size=self.minibatch_size,
                                num_workers=0)  # TODO num_workers

        ################### Training the model ###################
        # start training in epochs
        epochs_iterator = tqdm(range(n_iter))
        for epoch in epochs_iterator:

            if self.minibatch_size is None:
                ################### Training FULL data ###################
                iter_loss = self.step_train(name, x_data, extra_data_train)
                self.hist[name].append(iter_loss)

                # save data for posterior sampling
                self.x_data = x_data
                self.extra_data_train = extra_data_train

                if tracking:
                    guide_tr, model_tr = self.step_trace(name, x_data, extra_data_train)
                    self.logp_hist[name]['guide'].append(guide_tr.log_prob_sum().item())
                    self.logp_hist[name]['model'].append(model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist[name][k].append(v["log_prob_sum"].item())

            else:
                ################### Training MINIBATCH data ###################
                aver_loss = []
                if tracking:
                    aver_logp_guide = []
                    aver_logp_model = []
                    aver_logp = defaultdict(list)

                for batch in loader:
                    x_data_batch, extra_data_batch = batch
                    x_data_batch = x_data_batch.to(self.device)
                    extra_data_batch = {
                        k: v.to(self.device)
                        for k, v in extra_data_batch.items()
                    }

                    loss = self.step_train(name, x_data_batch, extra_data_batch)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(name, x_data_batch,
                                                             extra_data_batch)
                        aver_logp_guide.append(guide_tr.log_prob_sum().item())
                        aver_logp_model.append(model_tr.log_prob_sum().item())
                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                aver_logp[k].append(v["log_prob_sum"].item())

                    aver_loss.append(loss)

                iter_loss = np.sum(aver_loss)

                # save data for posterior sampling
                self.x_data = x_data_batch
                self.extra_data_train = extra_data_batch

                self.hist[name].append(iter_loss)

                if tracking:
                    iter_logp_guide = np.sum(aver_logp_guide)
                    iter_logp_model = np.sum(aver_logp_model)
                    self.logp_hist[name]['guide'].append(iter_logp_guide)
                    self.logp_hist[name]['model'].append(iter_logp_model)
                    for k, v in aver_logp.items():
                        self.logp_hist[name][k].append(np.sum(v))

            if self.checkpoints is not None:
                if (epoch + 1) in self.checkpoints:
                    self.save_checkpoint(epoch + 1, prefix=name)

            ################### Evaluating cross-validation loss ###################
            if train_proportion is not None:
                iter_loss_val = self.step_eval_loss(name, x_data_val, extra_data_val)

                if tracking:
                    guide_tr, model_tr = self.step_trace(name, x_data_val, extra_data_val)
                    self.logp_hist_val[name]['guide'].append(guide_tr.log_prob_sum().item())
                    self.logp_hist_val[name]['model'].append(model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist_val[name][k].append(v["log_prob_sum"].item())

                self.validation_hist[name].append(iter_loss_val)
                epochs_iterator.set_description('ELBO Loss: ' + '{:.4e}'.format(iter_loss)
                                                + ': Val loss: ' + '{:.4e}'.format(iter_loss_val))
            else:
                epochs_iterator.set_description('ELBO Loss: ' + '{:.4e}'.format(iter_loss))

            if epoch % 20 == 0:
                torch.cuda.empty_cache()

        if train_proportion is not None:
            # rescale loss
            self.validation_hist[name] = [
                i / (1 - train_proportion) for i in self.validation_hist[name]
            ]
            self.hist[name] = [
                i / train_proportion for i in self.hist[name]
            ]

            # reassign the main loss to be displayed
            self.training_hist[name] = self.hist[name]
            self.hist[name] = self.validation_hist[name]

            if tracking:
                for k, v in self.logp_hist[name].items():
                    self.logp_hist[name][k] = [
                        i / train_proportion for i in self.logp_hist[name][k]
                    ]
                    self.logp_hist_val[name][k] = [
                        i / (1 - train_proportion) for i in self.logp_hist_val[name][k]
                    ]

                self.logp_hist_train[name] = self.logp_hist[name]
                self.logp_hist[name] = self.logp_hist_val[name]

    if self.verbose:
        print(plt.plot(np.log10(self.hist[name][0:])))
def fit_advi_iterative_simple(self, n: int = 3, method='advi',
                              n_type='restart', n_iter=None,
                              learning_rate=None, progressbar=True):
    r"""Find the posterior using ADVI (deprecated)
    (maximising the likelihood of the data and minimising the KL-divergence
    of the posterior to the prior).

    :param n: number of independent initialisations
    :param method: which approximation of the posterior (guide) to use?

        * ``'advi'`` - univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal)
        * ``'custom'`` - custom guide using conjugate posteriors
    :return: self.svi dictionary with SVI pyro objects for each n, and
        self.elbo dictionary storing the training history.
    """

    # Pass data to pyro / pytorch
    self.x_data = torch.tensor(self.X_data.astype(self.data_type))

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    self.n_type = n_type

    if n_iter is None:
        n_iter = self.n_iter

    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):
        # initialise variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            self.guide_i[name].append(
                AutoNormal(poutine.block(self.model, expose_all=True,
                                         hide_all=False, hide=self.point_estim),
                           init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(poutine.block(self.model, hide_all=True,
                                        expose=self.point_estim)))
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model, self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())

        pyro.clear_param_store()

        # record ELBO loss history here
        self.hist[name] = []

        # pick the dataset depending on the training mode and move it to the GPU
        if np.isin(n_type, ['cv', 'bootstrap']):
            self.x_data = torch.tensor(self.X_data_sample[i].astype(self.data_type))
        else:
            self.x_data = torch.tensor(self.X_data.astype(self.data_type))

        if self.use_cuda:
            # move tensors and modules to CUDA
            self.x_data = self.x_data.cuda()

        # train for n_iter iterations
        it_iterator = tqdm(range(n_iter))
        for it in it_iterator:
            hist = self.svi[name].step(self.x_data)
            it_iterator.set_description('ELBO Loss: ' + str(np.round(hist, 3)))
            self.hist[name].append(hist)
            if it % 500 == 0:
                torch.cuda.empty_cache()
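# A hedged, self-contained sketch of the ADVI pattern used by the two methods
# above, written outside the class: AutoNormal over most sites plus AutoDelta
# point estimates for a chosen subset, optimised with ClippedAdam and
# JitTrace_ELBO. The model, data, and the 'point_estim' list are illustrative
# stand-ins, not the original model.
def _advi_model(x_data):
    w = pyro.sample("w", dist.Normal(0., 1.))     # approximated by AutoNormal
    s = pyro.sample("s", dist.LogNormal(0., 1.))  # point-estimated by AutoDelta
    with pyro.plate("obs_plate", x_data.shape[0]):
        pyro.sample("obs", dist.Normal(w, s), obs=x_data)


_point_estim = ["s"]
_advi_guide = AutoGuideList(_advi_model)
_advi_guide.append(AutoNormal(poutine.block(_advi_model, hide=_point_estim),
                              init_loc_fn=init_to_mean))
_advi_guide.append(AutoDelta(poutine.block(_advi_model, expose=_point_estim)))

_advi_svi = SVI(_advi_model, _advi_guide,
                optim.ClippedAdam({'lr': 0.01, 'clip_norm': 200}),
                loss=JitTrace_ELBO())

pyro.clear_param_store()
_x = torch.randn(100) + 2.
for _step in range(200):
    _loss = _advi_svi.step(_x)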