def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)
    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))
    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))
    return guide
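# A minimal usage sketch for AutoMixed above (not from the source): the toy
# model, init values, and optimiser settings are illustrative assumptions.
# The only structural requirement AutoMixed imposes is a latent site 'tau',
# which gets the full-rank (or delta) part of the guide.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, autoguide
from pyro.infer.autoguide import (AutoDelta, AutoGuideList,
                                  AutoMultivariateNormal, AutoNormal)
from pyro.optim import Adam


def toy_model():
    loc = pyro.sample('loc', dist.Normal(0., 1.))
    tau = pyro.sample('tau', dist.LogNormal(0., 1.))
    pyro.sample('obs', dist.Normal(loc, tau), obs=torch.tensor(1.0))


guide = AutoMixed(toy_model,
                  init_loc={'loc': torch.tensor(0.0),
                            'tau': torch.tensor(1.0)})
svi = SVI(toy_model, guide, Adam({'lr': 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()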
def guide(self, MAP=False, *args, **kwargs):
    if MAP:
        return AutoDelta(poutine.block(
            self.model,
            expose=['mixture_weights', 'cnv_probs', 'norm_sd']),
            init_loc_fn=self.init_fn())
    else:
        def guide_ret(*args, **kwargs):
            I, N = self._data['data'].shape
            # Use the full data size unless an explicit batch size is given.
            batch = N if self._params['batch_size'] is None \
                else self._params['batch_size']
            param_weights = pyro.param(
                "param_weights",
                lambda: torch.ones(self._params['K']) / self._params['K'],
                constraint=constraints.simplex)
            cnv_mean = pyro.param(
                "param_cnv_mean",
                lambda: self.create_gaussian_init_values(),
                constraint=constraints.positive)
            cnv_var = pyro.param(
                "param_cnv_var",
                lambda: torch.ones(1) * self._params['init_sd'],
                constraint=constraints.positive)
            pyro.sample('mixture_weights', dist.Dirichlet(param_weights))
            with pyro.plate('segments', I):
                with pyro.plate('components', self._params['K']):
                    pyro.sample(
                        'cnv_probs',
                        dist.LogNormal(torch.log(cnv_mean), cnv_var))

        return guide_ret
def guide(self, MAP=False, *args, **kwargs):
    if MAP:
        return AutoDelta(poutine.block(
            self.model,
            expose=['mixture_weights', 'norm_factor', 'cnv_probs']),
            init_loc_fn=self.init_fn())
    else:
        def guide_ret(*args, **kwargs):
            I, N = self._data['data'].shape
            # Use the full data size unless an explicit batch size is given.
            batch = N if self._params['batch_size'] is None \
                else self._params['batch_size']
            kappa = pyro.param(
                'param_kappa',
                lambda: dist.Uniform(0, 1).sample([self._params['T'] - 1]),
                constraint=constraints.positive)
            cnv_mean = pyro.param(
                "param_cnv_mean",
                lambda: self.create_gaussian_init_values(),
                constraint=constraints.positive)
            cnv_var = pyro.param(
                "param_cnv_var",
                lambda: torch.ones(1) * self._params['cnv_var'],
                constraint=constraints.positive)
            gamma_scale = pyro.param(
                "param_gamma_scale",
                lambda: torch.mean(
                    self._data['data'] /
                    (2 * self._data['mu'].reshape(self._data['data'].shape[0], 1)),
                    axis=0) * self._params['gamma_multiplier'],
                constraint=constraints.positive)
            gamma_rate = pyro.param(
                "param_rate",
                lambda: torch.ones(1) * self._params['gamma_multiplier'],
                constraint=constraints.positive)
            param_weights = pyro.param(
                "param_weights",
                lambda: torch.ones(self._params['T']) / self._params['T'],
                constraint=constraints.simplex)
            with pyro.plate("beta_plate", self._params['T'] - 1):
                pyro.sample("mixture_weights", dist.Beta(1, kappa))
            with pyro.plate('segments', I):
                with pyro.plate('components', self._params['T']):
                    pyro.sample(
                        'cnv_probs',
                        dist.LogNormal(torch.log(cnv_mean), cnv_var))
            with pyro.plate("data2", N, batch):
                pyro.sample('norm_factor',
                            dist.Gamma(gamma_scale, gamma_rate).expand([N]))

        return guide_ret
def test_reparam_stable():
    data = dist.Poisson(torch.randn(8).exp()).sample()

    @poutine.reparam(config={"dz": LatentStableReparam(),
                             "y": LatentStableReparam()})
    def model():
        stability = pyro.sample("stability", dist.Uniform(1.0, 2.0))
        trans_skew = pyro.sample("trans_skew", dist.Uniform(-1.0, 1.0))
        obs_skew = pyro.sample("obs_skew", dist.Uniform(-1.0, 1.0))
        scale = pyro.sample("scale", dist.Gamma(3, 1))
        # We use separate plates because the .cumsum() op breaks independence.
        with pyro.plate("time1", len(data)):
            dz = pyro.sample("dz", dist.Stable(stability, trans_skew))
        z = dz.cumsum(-1)
        with pyro.plate("time2", len(data)):
            y = pyro.sample("y", dist.Stable(stability, obs_skew, scale, z))
            pyro.sample("x", dist.Poisson(y.abs()), obs=data)

    guide = AutoDelta(model)
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}), Trace_ELBO())
    for step in range(100):
        loss = svi.step()
        if step % 20 == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
def guide(self, *args, **kwargs):
    return AutoDelta(poutine.block(
        self.model,
        expose=['mixture_weights', 'norm_factor', 'cnv_probs', 'gene_basal']),
        init_loc_fn=self.init_fn())
def nested_auto_guide_callable(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['x'])))
    guide_y = AutoGuideList(poutine.block(model, expose=['y']))
    guide_y.z = AutoIAFNormal(poutine.block(model, expose=['y']))
    guide.append(guide_y)
    return guide
def _get_initial_trace():
    guide = AutoDelta(
        poutine.block(
            model,
            expose_fn=lambda msg: not msg["name"].startswith("x") and
            not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo)
    for _ in range(100):
        svi.step(data)
    return poutine.trace(guide).get_trace(data)
def fit(self,
        model_name,
        model_param_names,
        data_input,
        fitter=None,
        init_values=None):
    verbose = self.verbose
    message = self.message
    learning_rate = self.learning_rate
    seed = self.seed
    num_steps = self.num_steps
    learning_rate_total_decay = self.learning_rate_total_decay
    pyro.set_rng_seed(seed)
    if fitter is None:
        fitter = get_pyro_model(model_name)  # abstract
    model = fitter(data_input)  # concrete

    # Perform MAP inference using an AutoDelta guide.
    pyro.clear_param_store()
    guide = AutoDelta(model)
    optim = ClippedAdam({
        "lr": learning_rate,
        "lrd": learning_rate_total_decay ** (1 / num_steps),
        "betas": (0.5, 0.8),
    })
    elbo = Trace_ELBO()
    loss_elbo = list()
    svi = SVI(model, guide, optim, elbo)
    for step in range(num_steps):
        loss = svi.step()
        loss_elbo.append(loss)
        if verbose and step % message == 0:
            print("step {: >4d} loss = {:0.5g}".format(step, loss))

    # Extract point estimates.
    values = guide()
    values.update(pyro.poutine.condition(model, values)())

    # Convert from torch.Tensors to numpy.ndarrays.
    extract = {
        name: value.detach().numpy()
        for name, value in values.items()
    }

    # make sure that model param names are a subset of the extracted keys
    invalid_model_param = set(model_param_names) - set(list(extract.keys()))
    if invalid_model_param:
        raise EstimatorException(
            "Pyro model definition does not contain required parameters")

    # `stan.optimizing` automatically returns all defined parameters;
    # filter out unnecessary keys
    posteriors = {param: extract[param] for param in model_param_names}
    training_metrics = {'loss_elbo': np.array(loss_elbo)}
    return posteriors, training_metrics
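# A standalone sketch of the MAP-fitting pattern used in fit() above:
# AutoDelta guide, ClippedAdam with per-step learning-rate decay ('lrd'),
# and point estimates read off by calling the trained guide. The model and
# data here are stand-ins, not the estimator's real model.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import ClippedAdam

data = 3.0 + torch.randn(100)


def toy_model():
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


pyro.clear_param_store()
guide = AutoDelta(toy_model)
num_steps = 100
optim = ClippedAdam({"lr": 0.1, "lrd": 0.1 ** (1 / num_steps)})
svi = SVI(toy_model, guide, optim, Trace_ELBO())
for step in range(num_steps):
    svi.step()
point_estimates = {k: v.detach().numpy() for k, v in guide().items()}
print(point_estimates)  # e.g. {'loc': array(...)}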
def _get_initial_trace():
    guide = AutoDelta(
        poutine.block(
            model,
            expose_fn=lambda msg: not msg["name"].startswith("x") and
            not msg["name"].startswith("y")))
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(model, guide, optim.Adam({"lr": .01}), elbo,
              num_steps=100).run(data)
    return svi.exec_traces[-1]
def guide(self, MAP=False, *args, **kwargs):
    if MAP:
        exposing = ['cnv_probs']
        if self._params['assignments'] is None:
            exposing.append('mixture_weights')
        if self._params['norm_factor'] is None:
            exposing.append('norm_factor')
        return AutoDelta(poutine.block(self.model, expose=exposing),
                         init_loc_fn=self.init_fn())
    else:
        def guide_ret(*args, **kwargs):
            I, N = self._data['data'].shape
            # Use the full data size unless an explicit batch size is given.
            batch = N if self._params['batch_size'] is None \
                else self._params['batch_size']
            if self._params['assignments'] is None:
                param_weights = pyro.param(
                    "param_mixture_weights",
                    lambda: torch.ones(self._params['K']) / self._params['K'],
                    constraint=constraints.simplex)
            if self._params['cnv_locs'] is None:
                cnv_mean = pyro.param(
                    "param_cnv_probs",
                    lambda: self.create_gaussian_init_values(),
                    constraint=constraints.positive)
            else:
                cnv_mean = self._params['cnv_locs']
            cnv_var = pyro.param(
                "param_cnv_var",
                lambda: torch.ones([self._params['K'], I]) *
                self._params['cnv_sd'],
                constraint=constraints.positive)
            if self._params['norm_factor'] is None:
                gamma_scale = pyro.param(
                    "param_norm_factor",
                    lambda: torch.mean(
                        self._data['data'] /
                        (self._data['mu'].reshape(self._data['data'].shape[0], 1)),
                        axis=0),
                    constraint=constraints.positive)
            if self._params['assignments'] is None:
                pyro.sample('mixture_weights', dist.Dirichlet(param_weights))
            with pyro.plate('segments', I):
                with pyro.plate('components', self._params['K']):
                    pyro.sample(
                        'cnv_probs',
                        dist.LogNormal(torch.log(cnv_mean), cnv_var))
            with pyro.plate("data2", N, batch):
                if self._params['norm_factor'] is None:
                    pyro.sample('norm_factor', dist.Delta(gamma_scale))

        return guide_ret
def initialize(data):
    pyro.clear_param_store()
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    # global global_guide
    global_guide = AutoDelta(
        poutine.block(model, expose=['weights', 'mus', 'lambdas']))
    svi = SVI(model, global_guide, optim, loss=elbo)
    svi.loss(model, global_guide, data)
    return svi
def test_posterior_predictive_svi_auto_delta_guide(parallel):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDelta(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)),
              Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model,
                                      guide=guide,
                                      num_samples=10000,
                                      parallel=parallel)
    marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0),
                 torch.ones(5) * 700,
                 rtol=0.05)
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000,))
    guide = AutoDelta(one_hot_model)
    svi = SVI(one_hot_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(pseudocounts, classes=classes)
    posterior_samples = Predictive(
        guide, num_samples=10000).get_samples(pseudocounts)
    posterior_predictive = Predictive(one_hot_model, posterior_samples)
    marginal_return_vals = posterior_predictive.get_samples(pseudocounts)["obs"]
    assert_close(marginal_return_vals.mean(dim=0),
                 true_probs.unsqueeze(0),
                 rtol=0.1)
def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(
        poutine.block(model, expose=['weights', 'locs', 'scale']),
        init_loc_fn=init_loc_fn)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)
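# Typical use of initialize() above (the pattern from Pyro's GMM tutorial,
# which this snippet appears to follow): scan candidate seeds, keep the one
# with the lowest initial loss, then re-initialise with it. The seed count
# is illustrative; `model`, `data`, `optim`, and `elbo` are the globals the
# function already assumes.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('best seed = {}, initial loss = {}'.format(seed, loss))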
def main(_argv):
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas,
                                           emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in trace.nodes.items()
                     if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
    guide = AutoDelta(
        poutine.block(model,
                      hide_fn=lambda site: site['name'].startswith('state')),
        init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam(dict(lr=0.1)), JitTraceEnum_ELBO())
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ",
          median['emission_probs'].squeeze().detach().numpy())
def test_subsample_guide(auto_class, init_fn):

    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        assert plate.size == 50
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])
        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)
def main(_argv):
    aas, ds, phis, psis, lengths = ProteinParser.parsef_tensor(
        'data/TorusDBN/top500.txt')
    guide = AutoDelta(
        poutine.block(torus_dbn,
                      hide_fn=lambda site: site['name'].startswith('state')),
        init_loc_fn=init_to_sample)
    svi = SVI(torus_dbn, guide, Adam(dict(lr=0.1)), TraceEnum_ELBO())
    plot_rama(lengths, phis, psis, filename='ground_truth')
    total_iters = 100
    num_states = 55
    plot_rate = 5
    dataset = TensorDataset(phis, psis, lengths)
    batch_size = 32
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    total_losses = []
    with tqdm.trange(total_iters) as pbar:
        total_loss = float('inf')
        for i in pbar:
            losses = []
            num_batches = 0
            for j, (phis, psis, lengths) in enumerate(dataloader):
                loss = svi.step(phis, psis, lengths, num_states=num_states)
                losses.append(loss)
                num_batches += 1
                pbar.set_description_str(
                    f"SVI (batch {j}/{len(dataset)//batch_size}):"
                    f" {loss / batch_size:.4} [epoch loss: {total_loss:.4}]",
                    refresh=True)
            total_loss = np.sum(losses) / (batch_size * num_batches)
            total_losses.append(total_loss)
            pbar.set_description_str(
                f"SVI (batch {j}/{len(dataset)//batch_size}):"
                f" {loss / batch_size:.4} [epoch loss: {total_loss:.4}]",
                refresh=True)
            if i % plot_rate == 0:
                sample_and_plot(torus_dbn, guide,
                                filename=f'learned_{i}',
                                num_sequences=len(dataset),
                                num_states=num_states)
    sample_and_plot(torus_dbn, guide,
                    filename='learned_finish',
                    num_sequences=len(dataset),
                    num_states=num_states)
    plot_losses(total_losses)
def test_posterior_predictive_svi_auto_delta_guide():
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    opt = optim.Adam(dict(lr=1.0))
    loss = Trace_ELBO()
    guide = AutoDelta(conditioned_model)
    svi_run = SVI(conditioned_model, guide, opt, loss,
                  num_steps=1000, num_samples=100).run(num_trials)
    posterior_predictive = TracePredictive(model, svi_run,
                                           num_samples=10000).run(num_trials)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    assert_close(marginal_return_vals.mean, torch.ones(5) * 700, rtol=0.05)
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000,))
    opt = optim.Adam(dict(lr=0.1))
    loss = Trace_ELBO()
    guide = AutoDelta(one_hot_model)
    svi_run = SVI(one_hot_model, guide, opt, loss,
                  num_steps=1000,
                  num_samples=1000).run(pseudocounts, classes=classes)
    posterior_predictive = TracePredictive(
        one_hot_model, svi_run, num_samples=10000).run(pseudocounts)
    marginal_return_vals = posterior_predictive.marginal().empirical["_RETURN"]
    assert_close(marginal_return_vals.mean, true_probs.unsqueeze(0), rtol=0.1)
def guide(self, MAP=False, *args, **kwargs):
    if MAP:
        return AutoDelta(poutine.block(
            self.model,
            expose=['mixture_weights', 'norm_factor', 'cnv_probs']),
            init_loc_fn=self.init_fn())
    else:
        def guide_ret(*args, **kwargs):
            I, N = self._data['data'].shape
            # Use the full data size unless an explicit batch size is given.
            batch = N if self._params['batch_size'] is None \
                else self._params['batch_size']
            param_weights = pyro.param(
                "param_weights",
                lambda: torch.ones(self._params['K']) / self._params['K'],
                constraint=constraints.simplex)
            hidden_vals = pyro.param(
                "param_hidden_weights",
                lambda: self.create_dirichlet_init_values(),
                constraint=constraints.simplex)
            gamma_scale = pyro.param(
                "param_gamma_scale",
                lambda: torch.mean(
                    self._data['data'] /
                    (2 * self._data['mu'].reshape(self._data['data'].shape[0], 1)),
                    axis=0) * self._params['gamma_multiplier'],
                constraint=constraints.positive)
            gamma_rate = pyro.param(
                "param_rate",
                lambda: torch.ones(1) * self._params['gamma_multiplier'],
                constraint=constraints.positive)
            weights = pyro.sample('mixture_weights',
                                  dist.Dirichlet(param_weights))
            with pyro.plate('segments', I):
                with pyro.plate('components', self._params['K']):
                    pyro.sample("cnv_probs", dist.Dirichlet(hidden_vals))
            with pyro.plate("data2", N, batch):
                pyro.sample('norm_factor',
                            dist.Gamma(gamma_scale, gamma_rate))

        return guide_ret
def initialize(seed, model, data):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    exposed_params = []
    # set the parameters inferred through the guide based on the kind of data
    if 'gr' in mtype:
        if dtype == 'norm':
            exposed_params = ['weights', 'concentration']
        elif dtype == 'raw':
            exposed_params = ['weights', 'alpha', 'beta']
    elif 'dim' in mtype:
        if dtype == 'norm':
            exposed_params = ['topic_weights', 'topic_concentration',
                              'participant_topics']
        elif dtype == 'raw':
            exposed_params = ['topic_weights', 'topic_a', 'topic_b',
                              'participant_topics']
    global_guide = AutoDelta(poutine.block(model, expose=exposed_params))
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)
def test_svi_custom_smoke(subsample_aware):
    t_obs = 5
    t_forecast = 4
    cov_dim = 3
    obs_dim = 2
    model = Model0()
    data = torch.randn(t_obs, obs_dim)
    covariates = torch.randn(t_obs + t_forecast, cov_dim)
    guide = AutoDelta(model)
    optim = Adam({})
    Forecaster(model, data, covariates[..., :t_obs, :],
               guide=guide,
               optim=optim,
               subsample_aware=subsample_aware,
               num_steps=2,
               log_every=1)
def guide(self, MAP=False, *args, **kwargs):
    if MAP:
        return AutoDelta(poutine.block(
            self.model,
            expose=['mixture_weights', 'norm_factor', 'cnv_probs',
                    'segment_mean']),
            init_loc_fn=self.init_fn())
    else:
        def guide_ret(*args, **kwargs):
            I, N = self._data['data'].shape
            seg_mean = torch.mean(
                self._data['data'] / self._data['pld'].reshape([I, 1]),
                axis=1)
            # Use the full data size unless an explicit batch size is given.
            batch = N if self._params['batch_size'] is None \
                else self._params['batch_size']
            param_weights = pyro.param(
                "param_mixture_weights",
                lambda: torch.ones(self._params['K']) / self._params['K'],
                constraint=constraints.simplex)
            cnv_mean = pyro.param(
                "param_cnv_probs",
                lambda: self.create_gaussian_init_values(),
                constraint=constraints.positive)
            cnv_var = pyro.param(
                "param_cnv_var",
                lambda: torch.ones(I) * self._params['cnv_sd'],
                constraint=constraints.positive)
            seg_var = pyro.param(
                "param_seg_var",
                lambda: torch.ones(I) * self._params['cnv_sd'],
                constraint=constraints.positive)
            seg_mean = pyro.param("param_seg_mean", lambda: seg_mean)
            gamma_scale = pyro.param(
                "param_norm_factor",
                lambda: torch.sum(self._data['data'], axis=0) /
                torch.sum(seg_mean),
                constraint=constraints.positive)
            pyro.sample('mixture_weights', dist.Dirichlet(param_weights))
            with pyro.plate('segments', I):
                pyro.sample(
                    'segment_mean',
                    dist.LogNormal(torch.log(seg_mean) - seg_var ** 2 / 2,
                                   seg_var))
                with pyro.plate('components', self._params['K']):
                    pyro.sample(
                        'cnv_probs',
                        dist.LogNormal(torch.log(cnv_mean) - cnv_var ** 2 / 2,
                                       cnv_var))
            with pyro.plate("data2", N, batch):
                pyro.sample('norm_factor', dist.Delta(gamma_scale))

        return guide_ret
def fit(self, x: torch.Tensor) -> MixtureModel:

    def init_loc_fn(site):
        K = self.num_components
        if site["name"] == "weights":
            return torch.ones(K) / K
        if site["name"] == "scales":
            return torch.tensor([[(x.var() / 2).sqrt()] * 2] * K)
        if site["name"] == "locs":
            return x[torch.multinomial(
                torch.ones(x.shape[0]) / x.shape[0], K), :]
        raise ValueError(site["name"])

    self.guide = AutoDelta(
        poutine.block(self.model, expose=['weights', 'locs', 'scales']),
        init_loc_fn=init_loc_fn)
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    loss = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(self.model, self.guide, optim, loss=loss)
    for i in range(self.optim_steps):
        elbo = svi.step(x)
        self.history["loss"].append(elbo)
    return self
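# Hypothetical usage of fit() above. `GaussianMixture` stands in for the
# surrounding MixtureModel subclass (which must define self.model,
# self.num_components, self.optim_steps, and self.history); the data is 2-D,
# matching the 'locs'/'scales' initialisations.
# mixture = GaussianMixture(num_components=3, optim_steps=200)
# mixture = mixture.fit(torch.randn(500, 2))
# print(mixture.guide.median()['locs'])  # MAP component locations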
from pyro.infer.autoguide import AutoDelta

from BackammonGamer import Analyzing, NeuralNetwork
from Backgammon import Checker, Game
from Player import AiPlayer, RandomPlayer


def play_against_ai():
    player1 = AiPlayer(Checker.WHITE)
    player2 = RandomPlayer(Checker.BLACK)
    game = Game(player_1=player1, player_2=player2, create_protocol=True)
    game.run()


if __name__ == '__main__':
    # Analyzing(NeuralNetwork()).analyzing()
    # print("guide auto delta")
    model4 = NeuralNetwork()
    Analyzing(model4, AutoDelta(model4)).analyzing()
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)
    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths, args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit
    # easier and for numerical stability) this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
def main(args):
    # setup hyperparameters for the model
    hypers = {
        'expected_sparsity': max(1.0, args.num_dimensions / 10),
        'alpha1': 3.0,
        'beta1': 1.0,
        'alpha2': 3.0,
        'beta2': 1.0,
        'alpha3': 1.0,
        'c': 1.0
    }
    P = args.num_dimensions
    S = args.active_dimensions
    Q = args.quadratic_dimensions

    # generate artificial dataset
    X, Y, expected_thetas, expected_quad_dims = get_data(
        N=args.num_data, P=P, S=S, Q=Q, sigma_obs=args.sigma)

    loss_fn = Trace_ELBO().differentiable_loss

    # We initialize the AutoDelta guide (for MAP estimation) with
    # args.num_trials many initial parameters sampled from the vicinity of
    # the median of the prior distribution and then continue optimizing with
    # the best performing initialization.
    init_losses = []
    for restart in range(args.num_restarts):
        pyro.clear_param_store()
        pyro.set_rng_seed(restart)
        guide = AutoDelta(model, init_loc_fn=init_loc_fn)
        with torch.no_grad():
            init_losses.append(loss_fn(model, guide, X, Y, hypers).item())

    pyro.set_rng_seed(np.argmin(init_losses))
    pyro.clear_param_store()
    guide = AutoDelta(model, init_loc_fn=init_loc_fn)

    # Instead of using pyro.infer.SVI and pyro.optim we instead construct our
    # own PyTorch optimizer and take charge of gradient-based optimization
    # ourselves.
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        guide(X, Y, hypers)
    params = list(
        [pyro.param(name).unconstrained() for name in param_capture.trace])
    adam = Adam(params, lr=args.lr)

    report_frequency = 50
    print("Beginning MAP optimization...")

    # the optimization loop
    for step in range(args.num_steps):
        loss = loss_fn(model, guide, X, Y, hypers) / args.num_data
        loss.backward()
        adam.step()
        adam.zero_grad()

        # we manually reduce the learning rate according to this schedule
        if step in [100, 300, 700, 900]:
            adam.param_groups[0]['lr'] *= 0.2

        if step % report_frequency == 0 or step == args.num_steps - 1:
            print("[step %04d] loss: %.5f" % (step, loss))

    print("Expected singleton thetas:\n", expected_thetas.data.numpy())

    # we do the final computation using double precision
    median = guide.median()  # == mode for MAP inference
    active_dims, active_quad_dims = compute_posterior_stats(
        X.double(), Y.double(), median['msq'].double(),
        median['lambda'].double(), median['eta1'].double(),
        median['xisq'].double(), torch.tensor(hypers['c']).double(),
        median['sigma'].double())

    expected_active_dims = np.arange(S).tolist()

    tp_singletons = len(set(active_dims) & set(expected_active_dims))
    fp_singletons = len(set(active_dims) - set(expected_active_dims))
    fn_singletons = len(set(expected_active_dims) - set(active_dims))
    singleton_stats = (tp_singletons, fp_singletons, fn_singletons)

    tp_quads = len(set(active_quad_dims) & set(expected_quad_dims))
    fp_quads = len(set(active_quad_dims) - set(expected_quad_dims))
    fn_quads = len(set(expected_quad_dims) - set(active_quad_dims))
    quad_stats = (tp_quads, fp_quads, fn_quads)

    # We report how well we did, i.e. did we recover the sparse set of
    # coefficients that we expected for our artificial dataset?
    print("[SUMMARY STATS]")
    print("Singletons (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % singleton_stats)
    print("Quadratic (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % quad_stats)
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=["x"])))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
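# A minimal, self-contained sketch of auto_guide_list_x above: a MAP
# (AutoDelta) guide for the site 'x' and a diagonal-normal guide for
# everything else. The toy model and training settings are illustrative
# assumptions, not from the source.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import (AutoDelta, AutoDiagonalNormal,
                                  AutoGuideList)
from pyro.optim import Adam


def toy_model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    y = pyro.sample("y", dist.Normal(x, 1.))
    pyro.sample("obs", dist.Normal(y, 0.1), obs=torch.tensor(0.5))


guide = auto_guide_list_x(toy_model)
svi = SVI(toy_model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()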
def fit_advi_iterative_simple(
    self,
    n: int = 3,
    method='advi',
    n_type='restart',
    n_iter=None,
    learning_rate=None,
    progressbar=True,
):
    r"""Find posterior using ADVI (deprecated)
    (maximising likelihood of the data and minimising KL-divergence
    of posterior to prior).

    :param n: number of independent initialisations
    :param method: which approximation of the posterior (guide) to use?

        * ``'advi'`` - Univariate normal approximation
          (pyro.infer.autoguide.AutoDiagonalNormal)
        * ``'custom'`` - Custom guide using conjugate posteriors
    :return: self.svi dictionary with svi pyro objects for each n,
        and self.elbo dictionary storing training history.
    """
    # Pass data to pyro / pytorch
    self.x_data = torch.tensor(self.X_data.astype(self.data_type))  # .double()

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    self.n_type = n_type

    if n_iter is None:
        n_iter = self.n_iter
    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):
        # initialise Variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            self.guide_i[name].append(
                AutoNormal(poutine.block(self.model,
                                         expose_all=True,
                                         hide_all=False,
                                         hide=self.point_estim),
                           init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(
                    poutine.block(self.model,
                                  hide_all=True,
                                  expose=self.point_estim)))
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model,
            self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())
        pyro.clear_param_store()

        # record ELBO Loss history here
        self.hist[name] = []

        # pick dataset depending on the training mode and move to GPU
        if np.isin(n_type, ['cv', 'bootstrap']):
            self.x_data = torch.tensor(
                self.X_data_sample[i].astype(self.data_type))
        else:
            self.x_data = torch.tensor(self.X_data.astype(self.data_type))

        if self.use_cuda:
            # move tensors and modules to CUDA
            self.x_data = self.x_data.cuda()

        # train for n_iter iterations
        it_iterator = tqdm(range(n_iter))
        for it in it_iterator:
            hist = self.svi[name].step(self.x_data)
            it_iterator.set_description('ELBO Loss: ' +
                                        str(np.round(hist, 3)))
            self.hist[name].append(hist)
            # if it % 50 == 0 & self.verbose:
            #     logging.info("Elbo loss: {}".format(hist))
            if it % 500 == 0:
                torch.cuda.empty_cache()
def fit_advi_iterative(self,
                       n=3,
                       method='advi',
                       n_type='restart',
                       n_iter=None,
                       learning_rate=None,
                       progressbar=True,
                       num_workers=2,
                       train_proportion=None,
                       stratify_cv=None,
                       l2_weight=False,
                       sample_scaling_weight=0.5,
                       checkpoints=None,
                       checkpoint_dir='./checkpoints',
                       tracking=False):
    r"""Train posterior using the ADVI method
    (maximising likelihood of the data and minimising KL-divergence
    of posterior to prior).

    :param n: number of independent initialisations
    :param method: to allow for potential use of SVGD or MCMC
        (currently only ADVI implemented).
    :param n_type: type of repeated initialisation:

        * 'restart' to pick different initial value,
        * 'cv' for molecular cross-validation - splits counts into n datasets
          (for now, only n=2 is implemented),
        * 'bootstrap' for fitting the model to multiple downsampled datasets.
          Run `mod.bootstrap_data()` to generate variants of data.
    :param n_iter: number of iterations, supersedes self.n_iter
    :param train_proportion: if not None, which proportion of cells to use
        for training and which for validation.
    :param checkpoints: int, list of int's or None; number of checkpoints to
        save while training the model, or a list of iterations to save
        checkpoints on
    :param checkpoint_dir: str, directory to save checkpoints in
    :param tracking: bool, track all latent variables during training -
        if True makes training 2 times slower
    :return: None
    """

    # initialise parameter store
    self.svi = {}
    self.hist = {}
    self.guide_i = {}
    self.samples = {}
    self.node_samples = {}

    if tracking:
        self.logp_hist = {}

    if n_iter is None:
        n_iter = self.n_iter

    if type(checkpoints) is int:
        if n_iter < checkpoints:
            checkpoints = n_iter
        checkpoints = np.linspace(0, n_iter, checkpoints + 1, dtype=int)[1:]
        self.checkpoints = list(checkpoints)
    else:
        self.checkpoints = checkpoints
    self.checkpoint_dir = checkpoint_dir

    self.n_type = n_type
    self.l2_weight = l2_weight
    self.sample_scaling_weight = sample_scaling_weight
    self.train_proportion = train_proportion

    if stratify_cv is not None:
        self.stratify_cv = stratify_cv

    if train_proportion is not None:
        self.validation_hist = {}
        self.training_hist = {}
        if tracking:
            self.logp_hist_val = {}
            self.logp_hist_train = {}

    if learning_rate is None:
        learning_rate = self.learning_rate

    if np.isin(n_type, ['bootstrap']):
        if self.X_data_sample is None:
            self.bootstrap_data(n=n)
    elif np.isin(n_type, ['cv']):
        self.generate_cv_data()  # cv data added to self.X_data_sample

    init_names = ['init_' + str(i + 1) for i in np.arange(n)]

    for i, name in enumerate(init_names):

        ################### Initialise parameters & optimiser ###################
        # initialise Variational distribution = guide
        if method == 'advi':
            self.guide_i[name] = AutoGuideList(self.model)
            normal_guide_block = poutine.block(
                self.model,
                expose_all=True,
                hide_all=False,
                hide=self.point_estim +
                flatten_iterable(self.custom_guides.keys()))
            self.guide_i[name].append(
                AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
            self.guide_i[name].append(
                AutoDelta(
                    poutine.block(self.model,
                                  hide_all=True,
                                  expose=self.point_estim)))
            for k, v in self.custom_guides.items():
                self.guide_i[name].append(v)
        elif method == 'custom':
            self.guide_i[name] = self.guide

        # initialise SVI inference method
        self.svi[name] = SVI(
            self.model,
            self.guide_i[name],
            optim.ClippedAdam({
                'lr': learning_rate,
                # limit the gradient step from becoming too large
                'clip_norm': self.total_grad_norm_constraint
            }),
            loss=JitTrace_ELBO())

        pyro.clear_param_store()
        self.set_initial_values()

        # record ELBO Loss history here
        self.hist[name] = []
        if tracking:
            self.logp_hist[name] = defaultdict(list)
        if train_proportion is not None:
            self.validation_hist[name] = []
            if tracking:
                self.logp_hist_val[name] = defaultdict(list)

        ################### Select data for this iteration ###################
        if np.isin(n_type, ['cv', 'bootstrap']):
            X_data = self.X_data_sample[i].astype(self.data_type)
        else:
            X_data = self.X_data.astype(self.data_type)

        ################### Training / validation split ###################
        # split into training and validation
        if train_proportion is not None:
            idx = np.arange(len(X_data))
            train_idx, val_idx = train_test_split(
                idx,
                train_size=train_proportion,
                shuffle=True,
                stratify=self.stratify_cv)

            extra_data_val = {
                k: torch.FloatTensor(v[val_idx]).to(self.device)
                for k, v in self.extra_data.items()
            }
            extra_data_train = {
                k: torch.FloatTensor(v[train_idx])
                for k, v in self.extra_data.items()
            }

            x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
            x_data = torch.FloatTensor(X_data[train_idx])
        else:
            # just convert data to CPU tensors
            x_data = torch.FloatTensor(X_data)
            extra_data_train = {
                k: torch.FloatTensor(v)
                for k, v in self.extra_data.items()
            }

        ################### Move data to cuda - FULL data ###################
        # if not minibatch do this:
        if self.minibatch_size is None:
            # move tensors to CUDA
            x_data = x_data.to(self.device)
            for k in extra_data_train.keys():
                extra_data_train[k] = extra_data_train[k].to(self.device)
            # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()}
        ################### MINIBATCH data ###################
        else:
            # create minibatches
            dataset = MiniBatchDataset(x_data,
                                       extra_data_train,
                                       return_idx=True)
            loader = DataLoader(dataset,
                                batch_size=self.minibatch_size,
                                num_workers=0)  # TODO num_workers

        ################### Training the model ###################
        # start training in epochs
        epochs_iterator = tqdm(range(n_iter))
        for epoch in epochs_iterator:

            if self.minibatch_size is None:
                ################### Training FULL data ###################
                iter_loss = self.step_train(name, x_data, extra_data_train)

                self.hist[name].append(iter_loss)

                # save data for posterior sampling
                self.x_data = x_data
                self.extra_data_train = extra_data_train

                if tracking:
                    guide_tr, model_tr = self.step_trace(
                        name, x_data, extra_data_train)
                    self.logp_hist[name]['guide'].append(
                        guide_tr.log_prob_sum().item())
                    self.logp_hist[name]['model'].append(
                        model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist[name][k].append(
                                v["log_prob_sum"].item())

            else:
                ################### Training MINIBATCH data ###################
                aver_loss = []
                if tracking:
                    aver_logp_guide = []
                    aver_logp_model = []
                    aver_logp = defaultdict(list)

                for batch in loader:
                    x_data_batch, extra_data_batch = batch
                    x_data_batch = x_data_batch.to(self.device)
                    extra_data_batch = {
                        k: v.to(self.device)
                        for k, v in extra_data_batch.items()
                    }

                    loss = self.step_train(name, x_data_batch,
                                           extra_data_batch)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data_batch, extra_data_batch)
                        aver_logp_guide.append(guide_tr.log_prob_sum().item())
                        aver_logp_model.append(model_tr.log_prob_sum().item())
                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                aver_logp[k].append(v["log_prob_sum"].item())

                    aver_loss.append(loss)

                iter_loss = np.sum(aver_loss)

                # save data for posterior sampling
                self.x_data = x_data_batch
                self.extra_data_train = extra_data_batch

                self.hist[name].append(iter_loss)

                if tracking:
                    iter_logp_guide = np.sum(aver_logp_guide)
                    iter_logp_model = np.sum(aver_logp_model)
                    self.logp_hist[name]['guide'].append(iter_logp_guide)
                    self.logp_hist[name]['model'].append(iter_logp_model)
                    for k, v in aver_logp.items():
                        self.logp_hist[name][k].append(np.sum(v))

            if self.checkpoints is not None:
                if (epoch + 1) in self.checkpoints:
                    self.save_checkpoint(epoch + 1, prefix=name)

            ################### Evaluating cross-validation loss ###################
            if train_proportion is not None:

                iter_loss_val = self.step_eval_loss(name, x_data_val,
                                                    extra_data_val)

                if tracking:
                    guide_tr, model_tr = self.step_trace(
                        name, x_data_val, extra_data_val)
                    self.logp_hist_val[name]['guide'].append(
                        guide_tr.log_prob_sum().item())
                    self.logp_hist_val[name]['model'].append(
                        model_tr.log_prob_sum().item())
                    for k, v in model_tr.nodes.items():
                        if "log_prob_sum" in v:
                            self.logp_hist_val[name][k].append(
                                v["log_prob_sum"].item())

                self.validation_hist[name].append(iter_loss_val)
                epochs_iterator.set_description(
                    'ELBO Loss: ' + '{:.4e}'.format(iter_loss) +
                    ': Val loss: ' + '{:.4e}'.format(iter_loss_val))
            else:
                epochs_iterator.set_description(
                    'ELBO Loss: ' + '{:.4e}'.format(iter_loss))

            if epoch % 20 == 0:
                torch.cuda.empty_cache()

        if train_proportion is not None:
            # rescale loss
            self.validation_hist[name] = [
                i / (1 - train_proportion)
                for i in self.validation_hist[name]
            ]
            self.hist[name] = [i / train_proportion for i in self.hist[name]]

            # reassign the main loss to be displayed
            self.training_hist[name] = self.hist[name]
            self.hist[name] = self.validation_hist[name]

            if tracking:
                for k, v in self.logp_hist[name].items():
                    self.logp_hist[name][k] = [
                        i / train_proportion
                        for i in self.logp_hist[name][k]
                    ]
                    self.logp_hist_val[name][k] = [
                        i / (1 - train_proportion)
                        for i in self.logp_hist_val[name][k]
                    ]

                self.logp_hist_train[name] = self.logp_hist[name]
                self.logp_hist[name] = self.logp_hist_val[name]

        if self.verbose:
            print(plt.plot(np.log10(self.hist[name][0:])))