def test_stack_overwrite_failure(self):
    """Conditioning the same site twice with conflicting values must fail."""
    first = {"latent2": torch.randn(2)}
    second = {"latent2": torch.randn(2)}
    doubly_conditioned = poutine.condition(
        poutine.condition(self.model, data=first), data=second)
    with pytest.raises(AssertionError):
        doubly_conditioned()
def predict(self, forecast=0):
    """
    Predict latent variables and optionally forecast forward.

    This may be run only after :meth:`fit_mcmc` and draws the same
    ``num_samples`` as passed to :meth:`fit_mcmc`.

    :param int forecast: The number of time steps to forecast forward.
    :returns: A dictionary mapping sample site name (or compartment name)
        to a tensor whose first dimension corresponds to sample batching.
    :rtype: dict
    """
    if self.num_quant_bins > 1:
        _require_double_precision()
    if not self.samples:
        raise RuntimeError("Missing samples, try running .fit_mcmc() first")
    samples = self.samples
    # All sample tensors share a common leading batch dimension.
    num_samples = len(next(iter(samples.values())))
    particle_plate = pyro.plate("particles", num_samples,
                                dim=-1 - self.max_plate_nesting)

    # Sample discrete auxiliary variables conditioned on the continuous
    # variables sampled by _quantized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter
    # backward-sample algorithm.
    logger.info("Predicting latent variables for {} time steps..."
                .format(self.duration))
    model = self._sequential_model
    model = poutine.condition(model, samples)
    model = particle_plate(model)
    if not self.relaxed:
        # NOTE: first_available_dim must sit below both the particle plate
        # and the model's own plates.
        model = infer_discrete(model,
                               first_available_dim=-2 - self.max_plate_nesting)
    trace = poutine.trace(model).get_trace()
    # Expand each value to the full distribution shape so downstream code
    # sees consistent batch dimensions; drop bookkeeping sites.
    samples = OrderedDict((name, site["value"].expand(site["fn"].shape()))
                          for name, site in trace.nodes.items()
                          if site["type"] == "sample"
                          if not site_is_subsample(site)
                          if not site_is_factor(site))
    assert all(v.size(0) == num_samples for v in samples.values()), \
        {k: tuple(v.shape) for k, v in samples.items()}

    # Optionally forecast with the forward _generative_model. This samples
    # time steps [duration:duration+forecast].
    if forecast:
        logger.info("Forecasting {} steps ahead...".format(forecast))
        model = self._generative_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        trace = poutine.trace(model).get_trace(forecast)
        samples = OrderedDict((name, site["value"])
                              for name, site in trace.nodes.items()
                              if site["type"] == "sample"
                              if not site_is_subsample(site)
                              if not site_is_factor(site))
        # Stitch per-time-step sites into contiguous series tensors.
        self._concat_series(samples, trace, forecast)
    assert all(v.size(0) == num_samples for v in samples.values()), \
        {k: tuple(v.shape) for k, v in samples.items()}
    return samples
def save_posterior_predictive(model, guide, filename, N=300):
    """Draw N posterior-predictive samples and save them to an .npz file."""
    def _trace_once():
        # One guide draw, replayed through the model via conditioning.
        guide_trace = poutine.trace(guide).get_trace()
        return poutine.trace(
            poutine.condition(model, data=guide_trace)).get_trace()

    if N == 1:
        # Single draw: store bare arrays rather than length-1 lists.
        trace = _trace_once()
        mock = {tag: trace.nodes[tag]["value"].detach().cpu().numpy()
                for tag in trace
                if trace.nodes[tag]["type"] == "sample"}
    else:
        mock = defaultdict(list)
        for _ in range(N):
            trace = _trace_once()
            for tag in trace:
                node = trace.nodes[tag]
                if node["type"] == "sample":
                    mock[tag].append(node["value"].detach().cpu().numpy())
    np.savez(filename, **mock)
    print("Saved %i sample(s) from posterior predictive distribution to %s"
          % (N, filename))
def test_stack_overwrite_behavior(self):
    """The outermost (last-applied) condition wins for a shared site."""
    inner = {"latent2": torch.randn(2)}
    outer = {"latent2": torch.randn(2)}
    with poutine.trace() as tr:
        stacked = poutine.condition(
            poutine.condition(self.model, data=inner), data=outer)
        stacked()
    assert tr.trace.nodes['latent2']['value'] is outer['latent2']
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(),
                model_kwargs=None):
    """
    Draw predictive samples from ``model`` conditioned on ``posterior_samples``.

    :param model: a Pyro model callable.
    :param dict posterior_samples: site name -> tensor whose leading dim is
        ``num_samples``.
    :param int num_samples: number of predictive samples to draw.
    :param return_sites: sites to return; ``()`` means all sites not in
        ``posterior_samples``, ``None`` means all sites.
    :param bool return_trace: if True, return the vectorized trace instead.
    :param bool parallel: vectorize over samples via an outer plate; else
        fall back to a sequential loop.
    :param tuple model_args: positional args forwarded to ``model``.
    :param dict model_kwargs: keyword args forwarded to ``model``.
    :returns: dict of predictions keyed by site name, or a trace.
    """
    # FIX: the original signature used a mutable default `model_kwargs={}`,
    # a shared-dict-across-calls hazard; use a None sentinel instead.
    if model_kwargs is None:
        model_kwargs = {}
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples", num_samples,
                           dim=-max_plate_nesting - 1)
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        # Insert singleton dims so samples broadcast against model plates.
        sample = sample.reshape(
            (num_samples,) + (1,) * (max_plate_nesting - len(sample_shape))
            + sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples,) + (1,) * append_ndim \
            + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site: only a tensor return value gets a known shape.
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args,
                                      model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        # Under-sized values were broadcast by the plate: expand; otherwise
        # reshape into the canonical (num_samples, ...) layout.
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)
    return predictions
def test_stack_success(self):
    """Two stacked conditions on disjoint sites both take effect."""
    first = {"latent1": torch.randn(2)}
    second = {"latent2": torch.randn(2)}
    tr = poutine.trace(
        poutine.condition(poutine.condition(self.model, data=first),
                          data=second)).get_trace()
    for name, source in (("latent1", first), ("latent2", second)):
        node = tr.nodes[name]
        assert node["type"] == "sample"
        assert node["is_observed"]
        assert node["value"] is source[name]
def test_nested():
    """Exercise nested HaarReparam: reparam of a reparam'd site."""
    shape = (5, 6)

    # Reparameterize x along dim -1, then reparameterize the resulting
    # auxiliary site x_haar along dim -2, yielding x_haar_haar.
    @poutine.reparam(config={
        "x": HaarReparam(dim=-1),
        "x_haar": HaarReparam(dim=-2)
    })
    def model():
        pyro.sample("x", dist.Normal(torch.zeros(shape), 1).to_event(2))

    # Try without initialization, e.g. in AutoGuide._setup_prototype().
    trace = poutine.trace(model).get_trace()
    assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
    # Only the innermost auxiliary site is actually sampled.
    assert trace.nodes["x"]["is_observed"]
    assert trace.nodes["x_haar"]["is_observed"]
    assert not trace.nodes["x_haar_haar"]["is_observed"]
    assert trace.nodes["x"]["value"].shape == shape

    # Try conditioning on x_haar_haar, e.g. in Predictive.
    x = torch.randn(shape)
    x_haar = HaarTransform(dim=-1)(x)
    x_haar_haar = HaarTransform(dim=-2)(x_haar)
    with poutine.condition(data={"x_haar_haar": x_haar_haar}):
        trace = poutine.trace(model).get_trace()
    assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
    assert trace.nodes["x"]["is_observed"]
    assert trace.nodes["x_haar"]["is_observed"]
    assert trace.nodes["x_haar_haar"]["is_observed"]
    # Conditioning on the auxiliary value determines all downstream values.
    assert_close(trace.nodes["x"]["value"], x)
    assert_close(trace.nodes["x_haar"]["value"], x_haar)
    assert_close(trace.nodes["x_haar_haar"]["value"], x_haar_haar)

    # Try with custom initialization.
    # This is required for autoguides and MCMC.
    with InitMessenger(init_to_value(values={"x": x})):
        trace = poutine.trace(model).get_trace()
    assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
    assert trace.nodes["x"]["is_observed"]
    assert trace.nodes["x_haar"]["is_observed"]
    assert not trace.nodes["x_haar_haar"]["is_observed"]
    assert_close(trace.nodes["x"]["value"], x)

    # Try conditioning on x.
    x = torch.randn(shape)
    with poutine.condition(data={"x": x}):
        trace = poutine.trace(model).get_trace()
    assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
    assert trace.nodes["x"]["is_observed"]
    assert trace.nodes["x_haar"]["is_observed"]
    # TODO Decide whether it is worth fixing this failing assertion.
    # See https://github.com/pyro-ppl/pyro/issues/2878
    # assert trace.nodes["x_haar_haar"]["is_observed"]
    assert_close(trace.nodes["x"]["value"], x)
def test_counterfactual_query(intervene, observe, flip):
    """Check interactions of poutine.do and poutine.condition on a chain.

    The model is a linear Gaussian chain x -> y -> z -> w.  ``intervene``
    applies do-interventions, ``observe`` applies conditioning, and ``flip``
    reverses the order in which the two handlers are stacked.
    """
    # x -> y -> z -> w
    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    # Stack handlers; when flip=True only the combined counterfactual
    # ordering (do outside condition) is exercised.
    if not flip:
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)
    elif flip and intervene and observe:
        model = poutine.do(
            poutine.condition(model, data=observations),
            data=interventions)

    tr = poutine.trace(model).get_trace()
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not tr.nodes[name]['is_observed']
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
                # do() replaces the returned value but the trace still
                # records the latent (pre-intervention) sample.
                assert_not_equal(observations[name], tr.nodes[name]['value'])
                assert_not_equal(interventions[name], tr.nodes[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], tr.nodes[name]['value'])
def test_condition(self):
    """poutine.condition marks the site observed and pins its value."""
    observed = {"latent2": torch.randn(2)}
    tr2 = poutine.trace(
        poutine.condition(self.model, data=observed)).get_trace()
    assert "latent2" in tr2
    node = tr2.nodes["latent2"]
    assert node["type"] == "sample"
    assert node["is_observed"]
    assert node["value"] is observed["latent2"]
def generate_data(args):
    """
    Generate a synthetic outbreak by repeatedly simulating ``discrete_model``.

    :param args: namespace with ``basic_reproduction_number``,
        ``response_rate``, ``duration``, ``forecast`` and
        ``min_observations`` attributes.
    :returns: dict with "S2I" (new-infection counts) and "obs" (observed
        counts) time-series tensors of length ``duration + forecast``.
    :raises ValueError: if no attempt produces enough observations.
    """
    logging.info("Generating data...")
    params = {"R0": torch.tensor(args.basic_reproduction_number),
              "rho": torch.tensor(args.response_rate)}
    # None entries make the model sample (rather than observe) each step.
    empty_data = [None] * (args.duration + args.forecast)

    # We'll retry until we get an actual outbreak.
    for attempt in range(100):
        with poutine.trace() as tr:
            with poutine.condition(data=params):
                discrete_model(args, empty_data)

        # Concatenate sequential time series into tensors.
        obs = torch.stack([site["value"]
                           for name, site in tr.trace.nodes.items()
                           if re.match("obs_[0-9]+", name)])
        S2I = torch.stack([site["value"]
                           for name, site in tr.trace.nodes.items()
                           if re.match("S2I_[0-9]+", name)])
        assert len(obs) == len(empty_data)

        obs_sum = int(obs[:args.duration].sum())
        S2I_sum = int(S2I[:args.duration].sum())
        # Accept only simulations with a large enough observed outbreak.
        if obs_sum >= args.min_observations:
            logging.info("Observed {:d}/{:d} infections:\n{}".format(
                obs_sum, S2I_sum,
                " ".join([str(int(x)) for x in obs[:args.duration]])))
            return {"S2I": S2I, "obs": obs}

    raise ValueError("Failed to generate {} observations. Try increasing "
                     "--population or decreasing --min-observations"
                     .format(args.min_observations))
def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
                           num_samples, sample_sites, return_trace=False):
    """Run ``model`` once per posterior sample, collecting traces or values."""
    per_sample_conditioning = [
        {name: values[idx] for name, values in posterior_samples.items()}
        for idx in range(num_samples)]
    collected = []
    for conditioning in per_sample_conditioning:
        trace = poutine.trace(
            poutine.condition(model, conditioning)).get_trace(
                *model_args, **model_kwargs)
        if return_trace:
            collected.append(trace)
        else:
            collected.append(
                {site: trace.nodes[site]['value'] for site in sample_sites})
    if return_trace:
        return collected
    return {site: torch.stack([entry[site] for entry in collected])
            for site in sample_sites}
def run_pyro(site_values, data, model, transformed_data, n_samples, params):
    """
    Score ``model`` at the given site values, returning per-sample log pdfs.

    :param dict site_values: site name -> array of values (leading dimension
        is the sample index when ``n_samples > 1``).
    :param data: data argument forwarded to the model trace.
    :param model: Pyro model callable; must have been imported successfully.
    :param transformed_data: unused in this body — presumably kept for
        interface parity with a Stan pipeline; confirm at call sites.
    :param int n_samples: number of samples to score.
    :param dict params: model parameters, converted to Variables in place.
    :returns: ``(log_pdfs, n_log_probs)`` — a list of float log densities and
        the (constant) number of log-prob factors per trace.
    """
    # import model, transformed_data functions (if exists) from pyro module
    assert model is not None, "model couldn't be imported"
    variablize_params(params)
    log_pdfs = []
    n_log_probs = None
    for j in range(n_samples):
        if n_samples > 1:
            sample_site_values = {v: site_values[v][j] for v in site_values}
        else:
            # Single sample: unwrap 0-d arrays to floats, otherwise take the
            # first (only) row of each site's array.
            sample_site_values = {v: float(site_values[v])
                                  if site_values[v].shape == ()
                                  else site_values[v][0]
                                  for v in site_values}
        process_2d_sites(sample_site_values)
        variablize_params(sample_site_values)
        model_trace = poutine.trace(
            poutine.condition(model, data=sample_site_values),
            graph_type="flat").get_trace(data, params)
        log_p = model_trace.log_pdf()
        # Every trace must contain the same number of log-prob factors;
        # a mismatch would indicate data-dependent model structure.
        if n_log_probs is None:
            n_log_probs = get_num_log_probs(model_trace)
        else:
            assert n_log_probs == get_num_log_probs(model_trace)
        log_pdfs.append(to_float(log_p))
    return log_pdfs, n_log_probs
def posterior_replay(model, posterior_samples, *args, **kwargs):
    r"""
    Given a model and samples from the posterior (potentially with conjugate
    sites collapsed), return a `dict` of samples from the posterior with
    conjugate sites uncollapsed.

    Note that this can also be used to generate samples from the posterior
    predictive distribution.

    :param model: Python callable.
    :param dict posterior_samples: posterior samples keyed by site name.
    :param args: arguments to `model`.
    :param kwargs: keyword arguments to `model`.
    :return: `dict` of samples from the posterior.
    """
    posterior_samples = posterior_samples.copy()
    num_samples = kwargs.pop("num_samples", None)
    # NOTE(review): assert-based validation is stripped under `python -O`;
    # callers should not rely on it for user input checking.
    assert posterior_samples or num_samples, \
        "`num_samples` must be provided if `posterior_samples` is empty."
    if num_samples is None:
        num_samples = list(posterior_samples.values())[0].shape[0]

    return_samples = defaultdict(list)
    for i in range(num_samples):
        conditioned_nodes = {k: v[i] for k, v in posterior_samples.items()}
        # First trace the collapsed model conditioned on this sample, then
        # replay the uncollapsed model against that trace to fill in the
        # conjugate sites.
        collapsed_trace = poutine.trace(
            poutine.condition(collapse_conjugate(model), conditioned_nodes))\
            .get_trace(*args, **kwargs)
        trace = poutine.trace(
            uncollapse_conjugate(model, collapsed_trace)).get_trace(
                *args, **kwargs)
        for name, site in trace.iter_stochastic_nodes():
            if not site_is_subsample(site):
                return_samples[name].append(site["value"])

    return {k: torch.stack(v) for k, v in return_samples.items()}
def get_log_prob(mcmc, data, site_names):
    """Gets the pointwise log probability of the posterior density conditioned
    on the data

    Arguments:
        mcmc (pyro.infer.mcmc.MCMC): the fitted MC model
        data (dict): dictionary containing all the input data (including
            return sites)
        site_names (str or List[str]): names of return sites to measure log
            likelihood at

    Returns:
        Tensor: pointwise log-likelihood of shape
            (num posterior samples, num data points)
    """
    samples = mcmc.get_samples()
    model = mcmc.kernel.model

    # get number of samples; every site must agree on the sample count.
    # BUG FIX: the original `assert [n == N[0] for n in N]` asserted the
    # truthiness of a non-empty list, which never fails; use all() so a
    # mismatched sample count is actually detected.
    sample_counts = [v.shape[0] for v in samples.values()]
    assert all(n == sample_counts[0] for n in sample_counts)
    N = sample_counts[0]

    if isinstance(site_names, str):
        site_names = [site_names]

    # iterate over samples
    log_prob = torch.zeros(N, len(data[site_names[0]]))
    for i in range(N):
        # condition on samples and get trace
        s = {k: v[i] for k, v in samples.items()}
        for nm in site_names:
            s[nm] = data[nm]
        tr = poutine.trace(poutine.condition(model, data=s)).get_trace(data)
        # get pointwise log probability
        for nm in site_names:
            node = tr.nodes[nm]
            log_prob[i] += node["fn"].log_prob(node["value"])
    return log_prob
def test_trace_data(self):
    """A trace object can itself serve as conditioning data."""
    tr1 = poutine.trace(
        poutine.block(self.model, expose_types=["sample"])).get_trace()
    tr2 = poutine.trace(
        poutine.condition(self.model, data=tr1)).get_trace()
    node = tr2.nodes["latent2"]
    assert node["type"] == "sample"
    assert node["is_observed"]
    assert node["value"] is tr1.nodes["latent2"]["value"]
def nested():
    """Run a short NUTS chain on the binomial model with sampled successes."""
    success_probs = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    successes = dist.Binomial(trials, success_probs).sample()
    kernel = NUTS(poutine.condition(model, data={"obs": successes}),
                  adapt_step_size=True)
    return MCMC(kernel, num_samples=10, warmup_steps=2).run(trials)
def __init__(self, model, data, descriptor, block_info=None, stan_info=None):
    """
    Wrap ``model`` conditioned on ``data`` and record its random variables.

    :param model: a Pyro model callable.
    :param dict data: site name -> observed value, baked into the stored
        model via ``poutine.condition``.
    :param descriptor: description object, stored as-is.
    :param block_info: optional metadata, stored as-is (semantics not
        visible here).
    :param stan_info: optional Stan-related metadata, stored as-is
        (semantics not visible here).
    """
    self.descriptor = descriptor
    # Condition once up front so all later traces see the observations.
    self.model = poutine.condition(model, data=data)
    # get_rvs(..., False/True) — presumably latent vs. observed sites;
    # confirm against get_rvs' definition.
    self.parameters = get_rvs(self.model, False)
    self.observed = get_rvs(self.model, True)
    self.data = data
    self.block_info = block_info
    self.stan_info = stan_info
def _conditioned_model(self, model, ratings):
    """Run ``model`` conditioned on every nonzero entry of ``ratings``."""
    data = {}
    for user in range(self.n_user):
        for item in range(self.n_item):
            if ratings[user, item] != 0:
                # Flattened (user, item) index names the observation site.
                key = "obs" + str(user * self.n_item + item)
                data[key] = torch.tensor(ratings[user, item],
                                         dtype=torch.float64)
    return poutine.condition(model, data=data)()
def conditioned_model(model, at_bats, hits):
    """
    Condition the model on observed data, for inference.

    :param model: python callable with Pyro primitives.
    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :returns: the result of running ``model(at_bats)`` with the "obs" site
        fixed to ``hits``.
    """
    return poutine.condition(model, data={"obs": hits})(at_bats)
def _potential_fn(self, params):
    """Negative log-joint of the model at unconstrained ``params``."""
    # Map each unconstrained parameter back to its constrained support.
    constrained = {name: self.transforms[name].inv(value)
                   for name, value in params.items()}
    conditioned = poutine.condition(self.model, constrained)
    model_trace = poutine.trace(conditioned).get_trace(
        *self.model_args, **self.model_kwargs)
    log_joint = self.trace_prob_evaluator.log_prob(model_trace)
    # Correct for the change of variables with the log |det Jacobian|.
    for name, transform in self.transforms.items():
        log_joint = log_joint - torch.sum(
            transform.log_abs_det_jacobian(constrained[name], params[name]))
    return -log_joint
def test_posterior_predictive():
    """NUTS posterior predictive mean should match the generating process."""
    success_probs = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    successes = dist.Binomial(trials, success_probs).sample()
    kernel = NUTS(poutine.condition(model, data={"obs": successes}),
                  adapt_step_size=True)
    mcmc_run = MCMC(kernel, num_samples=1000, warmup_steps=200).run(trials)
    predictive = TracePredictive(model, mcmc_run, num_samples=10000).run(trials)
    marginal = EmpiricalMarginal(predictive)
    assert_equal(marginal.mean, torch.ones(5) * 700, prec=30)
def _conditioned_model(self, model, sigma, ratings):
    """Run ``model(sigma)`` conditioned on the rating column of ``ratings``."""
    # Column index 2 holds the actual rating values.
    rating_column = ratings.take(2, axis=1)
    data = {"obs" + str(i): torch.tensor(value, dtype=torch.float64)
            for i, value in enumerate(rating_column)}
    return poutine.condition(model, data=data)(sigma)
def score_latent(zs, ys):
    """Log joint probability of latent states ``zs`` and observations ``ys``."""
    model = HarmonicModel()
    conditioning = {"z_{}".format(t): z for t, z in enumerate(zs)}
    with poutine.trace() as trace:
        with poutine.condition(data=conditioning):
            model.init()
            # The first observation is consumed by init(); step the rest.
            for y in ys[1:]:
                model.step(y)
    return trace.trace.log_prob_sum()
def conditioned_model(model, t, yt):
    """
    Condition a BaseModel on observations ``yt`` and run it at times ``t``.

    :param model: a BaseModel; its ``output_type`` selects the observation
        transform ("yt" for raw values, "logyt" for log of clamped values).
    :param t: time argument forwarded to the model's forward function.
    :param yt: observed series tensor.
    :returns: the model output under the conditioned observation site.
    :raises ValueError: if ``model.output_type`` is unrecognized.
        (Previously an unknown output_type left ``obs`` unbound and raised
        a confusing NameError at the condition call.)
    """
    # model must be a BaseModel
    assert isinstance(model, BaseModel)
    fcn = model.forward
    if model.output_type == "yt":
        obs = yt
    elif model.output_type == "logyt":
        # Clamp to >= 1 so the log is finite and non-negative.
        obs = torch.log(torch.clamp(yt, min=1.0))
        # fcn = poutine.mask(fcn, mask=(yt > 0))
    else:
        raise ValueError(
            "Unknown output_type: {}".format(model.output_type))
    return poutine.condition(fcn, data={model.output_type: obs})(t)
def WAIC(model, x, y, out_var_nm, num_samples=100):
    """
    Pointwise Widely Applicable Information Criterion (lower is better).

    :param model: Pyro model exposing a ``guide`` attribute to sample from.
    :param x: input passed to the traced model.
    :param y: observed targets scored under the posterior predictive.
    :param str out_var_nm: trace site name holding the output distribution.
    :param int num_samples: number of posterior draws for the estimate.
    :returns: tensor of pointwise WAIC values, one per element of ``y``.
    """
    p = torch.zeros((num_samples, len(y)))
    # Get log probability samples
    for i in range(num_samples):
        tr = poutine.trace(
            poutine.condition(model, data=model.guide())).get_trace(x)
        # FIX: renamed from `dist`, which shadowed the pyro.distributions
        # alias used elsewhere in this file.
        out_dist = tr.nodes[out_var_nm]["fn"]
        p[i] = out_dist.log_prob(y).detach()
    # Numerically stable log-mean-exp across samples.
    pmax = p.max(axis=0).values
    lppd = pmax + (p - pmax).exp().mean(axis=0).log()
    penalty = p.var(axis=0)
    return -2 * (lppd - penalty)
def test_posterior_predictive_svi_auto_delta_guide(parallel):
    """SVI with an AutoDelta guide recovers the generating mean."""
    success_probs = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    successes = dist.Binomial(trials, success_probs).sample()
    conditioned = poutine.condition(model, data={"obs": successes})
    guide = AutoDelta(conditioned)
    svi = SVI(conditioned, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO())
    for _ in range(1000):
        svi.step(trials)
    predictive = Predictive(model, guide=guide, num_samples=10000,
                            parallel=parallel)
    obs_samples = predictive.get_samples(trials)["obs"]
    assert_close(obs_samples.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
    """SVI with an AutoDiagonalNormal guide recovers the generating mean."""
    success_probs = torch.ones(5) * 0.7
    trials = torch.ones(5) * 1000
    successes = dist.Binomial(trials, success_probs).sample()
    conditioned = poutine.condition(model, data={"obs": successes})
    guide = AutoDiagonalNormal(conditioned)
    svi = SVI(conditioned, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for _ in range(1000):
        svi.step(trials)
    predictive = Predictive(model, guide=guide, num_samples=10000,
                            parallel=True)
    if return_trace:
        obs_samples = predictive.get_vectorized_trace(
            trials).nodes["obs"]["value"]
    else:
        obs_samples = predictive.get_samples(trials)["obs"]
    assert_close(obs_samples.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
def sample_and_calculate_log_impweight(x_data, y_data, model, guide,
                                       num_post_samples=int(1e3)):
    """
    Draw samples from ``guide`` and compute their log importance weights.

    :param x_data: model inputs, passed to the model as ``x``.
    :param y_data: observations conditioned on at the "obs" site.
    :param model: Pyro model callable.
    :param guide: an AutoDiagonalNormal or AutoMultivariateNormal guide
        (other autoguides are not handled by the site-name lookup below).
    :param int num_post_samples: number of posterior draws.
    :returns: (samples, log_impweights) where log_impweights is normalized
        so that logsumexp(log_impweights) == 0 (weights sum to 1).
    """
    log_impweights = torch.zeros(torch.Size([num_post_samples]))
    samples = torch.zeros(torch.Size([num_post_samples, guide.latent_dim]))
    for i in range(num_post_samples):
        trace_guide = poutine.trace(guide).get_trace()
        param_sample = trace_guide.nodes["_RETURN"]["value"]
        # need to evaluate the log probs so that they appear in
        # ["_Auto...Normal_latent"]["log_prob_sum"]
        trace_guide.log_prob_sum()
        # HACK: the per-site ["log_prob_sum"] lookup below is used instead of
        # trace_guide.log_prob_sum(), because the latter also counts sigma
        # (a deterministic function of log sigma) and is therefore wrong
        # here (i.e. not zero) — per the original author's note; verify.
        if isinstance(guide, AutoDiagonalNormal):
            samples[i, :] = trace_guide.nodes["_AutoDiagonalNormal_latent"][
                "value"]
            param_sample_logprob = trace_guide.nodes[
                "_AutoDiagonalNormal_latent"]["log_prob_sum"]
        else:
            samples[i, :] = trace_guide.nodes[
                "_AutoMultivariateNormal_latent"]["value"]
            param_sample_logprob = trace_guide.nodes[
                "_AutoMultivariateNormal_latent"]["log_prob_sum"]
        # Score the joint model density at this parameter draw.
        cond_model = poutine.condition(model, data={
            "obs": y_data,
            **param_sample
        })
        trace_cond_model = poutine.trace(cond_model).get_trace(x=x_data)
        # Unnormalized log-posterior of the draw.
        joint_logprob = trace_cond_model.log_prob_sum()
        log_impweights[i] = joint_logprob - param_sample_logprob
    # Normalize in log space so the exponentiated weights sum to 1.
    log_impweights = log_impweights - torch.logsumexp(log_impweights, dim=0)
    return samples, log_impweights
def transform_samples(self, aux_samples, save_params=None):
    """
    Given latent samples from the warped posterior (with a possible batch
    dimension), return a `dict` of samples from the latent sites in the
    model.

    :param dict aux_samples: Dict site name to tensor value for each latent
        auxiliary site (or if ``save_params`` is specified, then for only
        those latent auxiliary sites needed to compute requested params).
    :param list save_params: An optional list of site names to save. This
        is useful in models with large nuisance variables. Defaults to None,
        saving all params.
    :return: a `dict` of samples keyed by latent sites in the model.
    :rtype: dict
    """
    # Condition the guide's auxiliary sites on the provided samples and
    # mask out log-prob accumulation: only the deterministic transform of
    # auxiliary values into model-space deltas is wanted here.
    with poutine.condition(data=aux_samples), poutine.mask(mask=False):
        deltas = self.guide.get_deltas(save_params)
        return {name: delta.v for name, delta in deltas.items()}
def fit(
    self,
    df,
    max_iter=6000,
    patience=200,
    optimiser_settings=None,
    elbo_kwargs=None,
):
    """
    Fit the team-strength model to a dataframe of match results via SVI.

    :param df: dataframe with columns ``home_team``, ``away_team``,
        ``home_goals``, ``away_goals`` and ``date``.
    :param int max_iter: maximum number of SVI steps.
    :param int patience: early-stopping patience.
    :param dict optimiser_settings: kwargs for the Adam optimiser
        (default ``{"lr": 1.0e-2}``).
    :param dict elbo_kwargs: kwargs for Trace_ELBO
        (default ``{"num_particles": 5}``).
    :returns: the losses recorded during fitting.
    """
    # FIX: the original defaults were mutable dicts shared across calls;
    # use None sentinels and build fresh dicts per call.
    if optimiser_settings is None:
        optimiser_settings = {"lr": 1.0e-2}
    if elbo_kwargs is None:
        elbo_kwargs = {"num_particles": 5}

    teams = sorted(list(set(df["home_team"]) | set(df["away_team"])))
    home_team = df["home_team"].values
    away_team = df["away_team"].values
    home_goals = torch.tensor(df["home_goals"].values, dtype=torch.float32)
    away_goals = torch.tensor(df["away_goals"].values, dtype=torch.float32)
    # Week index relative to the first recorded match date.
    gameweek = ((df["date"] - df["date"].min()).dt.days // 7).values

    self.team_to_index = {team: i for i, team in enumerate(teams)}
    self.index_to_team = {
        value: key for key, value in self.team_to_index.items()
    }
    self.n_teams = len(teams)
    self.min_date = df["date"].min()

    conditioned_model = condition(self.model,
                                  data={
                                      "home_goals": home_goals,
                                      "away_goals": away_goals
                                  })
    guide = AutoDiagonalNormal(conditioned_model)
    optimizer = Adam(optimiser_settings)
    elbo = Trace_ELBO(**elbo_kwargs)
    svi = SVI(conditioned_model, guide, optimizer, loss=elbo)
    # Clear previously learned params so repeated fits start fresh.
    pyro.clear_param_store()
    fitted_svi, losses = early_stopping(svi,
                                        home_team,
                                        away_team,
                                        gameweek,
                                        max_iter=max_iter,
                                        patience=patience)
    self.guide = guide
    return losses