def model(y=None):
    with pyro.util.optional(pyro.plate("plate", 3), with_plate):
        x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
        x2 = pyro.deterministic("x2", x**2, event_dim=len(event_shape))
        pyro.deterministic("x3", x2)
        return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y)
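# A minimal, self-contained sketch (not from the test above) of what
# pyro.deterministic does: it records a named, noise-free function of other
# sample sites in the trace, so downstream tooling (tracing, Predictive) can
# retrieve it just like a sample site.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def tiny_model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.deterministic("y", torch.sigmoid(z))

trace = poutine.trace(tiny_model).get_trace()
print(trace.nodes["y"]["value"])  # deterministic transform of the sampled z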
def backward_model(data, *args, encoder=None, use_deltas=False, kl_beta=1.0, **kwargs):
    encoder = pyro.module("encoder", encoder)
    N = data.shape[0]
    with poutine.scale_messenger.ScaleMessenger(1 / N):
        with pyro.plate("batch", N):
            encoder_out = encoder(data)
            delta_sample_transformer_params(
                encoder.transformer.transformers, encoder_out["transform_params"]
            )
            pyro.deterministic("attention_input", encoder_out["view"])

            def make_dist():
                if use_deltas:
                    return D.Delta(encoder_out["z_mu"])
                else:
                    return D.Normal(
                        encoder_out["z_mu"], torch.exp(encoder_out["z_std"]) + 1e-6
                    )

            with poutine.scale_messenger.ScaleMessenger(kl_beta):
                z = pyro.sample("z", make_dist().to_event(1))
def forward_model(
    data,
    transforms=None,
    cond=True,
    decoder=None,
    output_size=40,
    device=torch.device("cpu"),
    **kwargs
):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", data.shape[0]):
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.latent_dim, device=device),
                torch.ones(decoder.latent_dim, device=device),
            ).to_event(1),
        )

        view = decoder(z)
        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(data.shape[0], *grid.shape)
        transform = random_pose_transform(transforms)
        transform_grid = transform(grid)
        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)
def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    batch_size = xs.shape[0]
    with pyro.plate("data"):
        # Prior network uses the baseline predictions as initial guess.
        # This is the generative process with recurrent connection
        with torch.no_grad():
            # this ensures the training process does not change the
            # baseline network
            y_hat = self.baseline_net(xs).view(xs.shape)

        # sample the handwriting style from the prior distribution, which is
        # modulated by the input xs.
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

        # the output y is generated from the distribution pθ(y|x, z)
        loc = self.generation_net(zs)

        if ys is not None:
            # In training, we will only sample in the masked image
            mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)
            mask_ys = ys[xs == -1].view(batch_size, -1)
            pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
        else:
            # In testing, no need to sample: the output is already a
            # probability in [0, 1] range, which better represents pixel
            # values considering grayscale. If we sample, we will force
            # each pixel to be either 0 or 1, killing the grayscale
            pyro.deterministic('y', loc.detach())

    # return the loc so we can visualize it later
    return loc
def model(self, data, transforms=None):
    output_size = self.encoder.insize
    decoder = pyro.module("decoder", self.decoder)
    # decoder takes z and std in the transformed coordinate frame
    # and the theta
    # and outputs an upright image
    with pyro.plate("batch", data.shape[0]):
        # prior for z
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.z_dim, device=data.device),
                torch.ones(decoder.z_dim, device=data.device),
            ).to_event(1),
        )

        # given a z, the decoder produces an "image"
        # this image must be transformed from the self consistent basis
        # to real world basis
        # first, z and std for the self consistent basis is outputted
        # then it is transformed
        view = decoder(z)

        pyro.deterministic("canonical_view", view)
        # pyro.deterministic is like pyro.sample but it is deterministic...?

        # all of this is completely independent of the input
        # maybe this is the "prior for the transformation"
        # and hence it looks completely independent of the input
        # but when the model is run again, these variables are replayed
        # with the theta generated by the guide
        # makes sense
        # so the model replays with theta and mu and sigma generated by the guide,
        # taking theta and mu sigma and applying the inverse transform
        # to get the output image.
        grid = coordinates.identity_grid(
            [output_size, output_size], device=data.device)
        grid = grid.expand(data.shape[0], *grid.shape)

        transforms = T.TransformSequence(T.Translation(), T.Rotation())
        transform = random_pose_transform(transforms)

        transform_grid = transform(grid)

        # output from decoder is transformed into a different coordinate system
        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        # view from decoder outputs an image

        pyro.sample(
            "pixels", D.Bernoulli(transformed_view).to_event(3), obs=data)
def backward_model(data, *args, encoder=None, **kwargs):
    encoder = pyro.module("encoder", encoder)
    N = data.shape[0]
    with poutine.scale_messenger.ScaleMessenger(1 / N):
        with pyro.plate("batch", N):
            encoder_out = encoder(data)
            delta_sample_transformer_params(
                encoder.transformer.transformers, encoder_out["transform_params"]
            )
            pyro.deterministic("attention_input", encoder_out["view"])
            z = pyro.sample(
                "z",
                D.Normal(
                    encoder_out["z_mu"], torch.exp(encoder_out["z_std"]) + 1e-3
                ).to_event(1),
            )
def state_space_model(data, N=1, T=2, prior_drift=0., verbose=False):
    # global rvs
    drift = pyro.sample('drift', dist.Normal(prior_drift, 1))
    vol = pyro.sample('vol', dist.LogNormal(0, 1))
    uncert = pyro.sample('uncert', dist.LogNormal(-5, 1))

    if verbose:
        print(f"Using drift = {drift}, vol = {vol}, uncert = {uncert}")

    # the latent time series you want to infer
    # since you want to output this, we initialize a vector where you'll
    # save the inferred values
    latent = torch.empty((T, N))  # +1 comes from hidden initial condition

    # I think you want to plate out the same state space model for N different obs
    with pyro.plate('data_plate', N) as n:
        x0 = pyro.sample('x0', dist.Normal(drift, vol))  # or whatever your IC might be
        latent[0, n] = x0

        # now comes the markov part, as you correctly noted
        for t in pyro.markov(range(1, T)):
            x_t = pyro.sample(f"x_{t}",
                              dist.Normal(latent[t - 1, n] + drift, vol))
            y_t = pyro.sample(f"y_{t}",
                              dist.Normal(x_t, uncert),
                              obs=data[t - 1, n] if data is not None else None)
            latent[t, n] = x_t

    return pyro.deterministic('latent', latent)
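# A hedged sketch (not from the original snippet) of running NUTS over the
# state-space model above; the observation tensor shape (T-1, N) and the step
# counts are assumptions chosen purely for illustration.
import torch
from pyro.infer import MCMC, NUTS

T, N = 5, 1
data = torch.randn(T - 1, N).cumsum(0)   # fake observed series
kernel = NUTS(state_space_model)
mcmc = MCMC(kernel, num_samples=200, warmup_steps=200)
mcmc.run(data, N=N, T=T)
samples = mcmc.get_samples()             # 'drift', 'vol', 'uncert', 'x_1', ...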
def model(self, home_team, away_team):
    sigma_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
    mu_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
    rho_raw = pyro.sample("rho_raw", dist.Beta(2, 2))
    rho = pyro.deterministic("rho", 2.0 * rho_raw - 1.0)

    log_gamma = pyro.sample("log_gamma", dist.Normal(0, 1))

    with pyro.plate("teams", self.n_teams):
        abilities = pyro.sample(
            "abilities",
            dist.MultivariateNormal(
                torch.tensor([0.0, mu_b]),
                covariance_matrix=torch.tensor(
                    [
                        [sigma_a ** 2.0, rho * sigma_a * sigma_b],
                        [rho * sigma_a * sigma_b, sigma_b ** 2.0],
                    ]
                ),
            ),
        )

    log_a = abilities[:, 0]
    log_b = abilities[:, 1]
    home_inds = torch.tensor([self.team_to_index[team] for team in home_team])
    away_inds = torch.tensor([self.team_to_index[team] for team in away_team])
    home_rate = torch.exp(log_a[home_inds] + log_b[away_inds] + log_gamma)
    away_rate = torch.exp(log_a[away_inds] + log_b[home_inds])

    pyro.sample("home_goals", dist.Poisson(home_rate))
    pyro.sample("away_goals", dist.Poisson(away_rate))
def quantize(name, x_real, min, max):
    """
    Randomly quantize in a way that preserves probability mass.
    We use a piecewise polynomial spline of order 3.
    """
    assert min < max
    lb = x_real.detach().floor()

    # This cubic spline interpolates over the nearest four integers, ensuring
    # piecewise quadratic gradients.
    s = x_real - lb
    ss = s * s
    t = 1 - s
    tt = t * t
    probs = torch.stack(
        [
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ],
        dim=-1,
    ) * (1 / 6)

    q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

    x = lb + q - 1
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)
    return pyro.deterministic(name, x)
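# A small numeric check (assumed value, not from the source) that the cubic
# spline weights above preserve probability mass: the expected quantized value
# over the four candidate integers {lb-1, lb, lb+1, lb+2} recovers x_real.
import torch

x_real = torch.tensor(3.7)
lb = x_real.floor()
s = x_real - lb
ss, t = s * s, 1 - s
tt = t * t
probs = torch.stack([t * tt, 4 + ss * (3 * s - 6), 4 + tt * (3 * t - 6), s * ss]) / 6
values = lb + torch.arange(4.0) - 1
print(torch.dot(probs, values))  # ~3.7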
def discrete_model(args, data):
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sequentially sample time-local variables.
    S = torch.tensor(args.population - 1.0)
    I = torch.tensor(1.0)
    for t, datum in enumerate(data):
        S2I = pyro.sample("S2I_{}".format(t),
                          dist.Binomial(S, -(rate_s * I).expm1()))
        I2R = pyro.sample("I2R_{}".format(t),
                          dist.Binomial(I, prob_i))
        S = pyro.deterministic("S_{}".format(t), S - S2I)
        I = pyro.deterministic("I_{}".format(t), I + S2I - I2R)
        pyro.sample("obs_{}".format(t),
                    dist.ExtendedBinomial(S2I, rho),
                    obs=datum)
def forward(self, diameters: Tensor, resp_model=configs.respiratorA, debit=_DEBIT):
    """
    Parameters
    ----------
    diameters (Tensor): particle diameters
    resp_model (callable): returns a tuple of the surface area and the list of
        `~MaskLayers`.
    """
    surface_area_, layers_ = resp_model()
    face_vel = debit / surface_area_

    with pyro.plate("diameters"):
        phi = pyro.deterministic(
            "phi",
            penetration.compute_penetration_profile(
                diameters, layers_, face_vel,
                configs.temperature, configs.viscosity,
                return_log=True))

        obs = pyro.sample('obs_log', dist.Normal(loc=phi, scale=self.obs_scale))
        return obs
def forward_model(data,
                  label,
                  N=-1,
                  transforms=None,
                  instantiate_label=False,
                  cond_label=True,
                  cond=True,
                  decoder=None,
                  latent_decoder=None,
                  output_size=40,
                  device=torch.device("cpu"),
                  **kwargs):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", N):
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(N, decoder.latent_dim, device=device),
                torch.ones(N, decoder.latent_dim, device=device),
            ).to_event(1),
        )

        # use supervision
        if instantiate_label:
            latent_decoder = pyro.module("latent_decoder", latent_decoder)
            label_logits = latent_decoder(z)
            obs_label = label if cond_label else None
            pyro.sample("y", D.Categorical(logits=label_logits), obs=obs_label)

        view = decoder(z)
        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(N, *grid.shape)
        transform = random_pose_transform(transforms, device=device)
        transform_grid = transform(grid)
        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)
def forward_model(
    data,
    transforms=None,
    instantiate_label=False,
    cond=True,
    decoder=None,
    output_size=128,
    device=torch.device("cpu"),
    kl_beta=1.0,
    **kwargs,
):
    decoder = pyro.module("view_decoder", decoder)
    N = data.shape[0]
    with poutine.scale_messenger.ScaleMessenger(1 / N):
        with pyro.plate("batch", N):
            with poutine.scale_messenger.ScaleMessenger(kl_beta):
                z = pyro.sample(
                    "z",
                    D.Normal(
                        torch.zeros(N, decoder.latent_dim, device=device),
                        torch.ones(N, decoder.latent_dim, device=device),
                    ).to_event(1),
                )

            # use supervision
            view = decoder(z)
            pyro.deterministic("canonical_view", view)

            grid = coordinates.identity_grid([output_size, output_size], device=device)
            grid = grid.expand(N, *grid.shape)
            scale = view.shape[-1] / output_size
            # rescales the image co-ordinates so one pixel of the recon
            # corresponds to 1 pixel of the view.
            grid = grid * (1 / scale)

            transform = random_pose_transform(transforms, device=device)
            transform_grid = transform(grid)
            transformed_view = T.broadcasting_grid_sample(view, transform_grid)
            obs = data if cond else None
            pyro.sample("pixels", D.Laplace(transformed_view, 0.5).to_event(3), obs=obs)
def pyro_sample_saas_lengthscales(
    dim: int, alpha: float = 0.1, pyro: Any = None, **tkwargs: Any
) -> Tensor:
    tausq = pyro.sample(
        "kernel_tausq",
        pyro.distributions.HalfCauchy(torch.tensor(alpha, **tkwargs)),
    )
    inv_length_sq = pyro.sample(
        "_kernel_inv_length_sq",
        pyro.distributions.HalfCauchy(torch.ones(dim, **tkwargs)),
    )
    inv_length_sq = pyro.deterministic("kernel_inv_length_sq", tausq * inv_length_sq)
    lengthscale = pyro.deterministic(
        "lengthscale",
        (1.0 / inv_length_sq).sqrt(),  # pyre-ignore [16]
    )
    return lengthscale
def sample_lengthscale(
    self, dim: int, alpha: float = 0.1, **tkwargs: Any
) -> Tensor:
    r"""Sample the lengthscale."""
    tausq = pyro.sample(
        "kernel_tausq",
        pyro.distributions.HalfCauchy(torch.tensor(alpha, **tkwargs)),
    )
    inv_length_sq = pyro.sample(
        "_kernel_inv_length_sq",
        pyro.distributions.HalfCauchy(torch.ones(dim, **tkwargs)),
    )
    inv_length_sq = pyro.deterministic(
        "kernel_inv_length_sq", tausq * inv_length_sq
    )
    lengthscale = pyro.deterministic(
        "lengthscale",
        (1.0 / inv_length_sq).sqrt(),
    )
    return lengthscale
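# A standalone sketch (plain torch.distributions, dimensions assumed) of the
# SAAS-style prior encoded above: a global HalfCauchy scale `tausq` shared
# across inputs and per-dimension HalfCauchy local scales, combined so that
# most sampled lengthscales come out large, i.e. most inputs are effectively
# switched off.
import torch
from torch.distributions import HalfCauchy

dim, alpha = 5, 0.1
tausq = HalfCauchy(torch.tensor(alpha)).sample()
local_inv_length_sq = HalfCauchy(torch.ones(dim)).sample()
lengthscale = (1.0 / (tausq * local_inv_length_sq)).sqrt()
print(lengthscale)  # long lengthscales correspond to irrelevant dimensions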
def spectral_matrix_gp(n_points, plate, suffix=''):
    if suffix != '':
        suffix = f'_{suffix}'
    with plate:
        l = 150
        pts = torch.arange(n_points, dtype=torch.float64).unsqueeze(-1)
        distance_squared = torch.pow(pts - pts.T, 2).unsqueeze(-1)
        cov = pyro.deterministic(
            'cov',
            torch.exp(-0.5 * distance_squared / l).T,
            event_dim=2).contiguous()
        diag_idx = np.diag_indices(n_points, ndim=1)
        cov[:, diag_idx, diag_idx] += torch.rand(1, 1, n_points) / 1000
        gp = pyro.sample(
            'gp',
            dist.MultivariateNormal(
                loc=torch.tensor([0.] * n_points),
                covariance_matrix=cov).to_event(0))
        S = pyro.deterministic('S', F.softplus(gp * 50) * 20, event_dim=1)
    return S
def quantize(name, x_real, min, max, num_quant_bins=4):
    """Randomly quantize in a way that preserves probability mass."""
    assert _all(min < max)
    if num_quant_bins == 1:
        x = x_real.detach().round()
        return pyro.deterministic(name, x, event_dim=0)

    lb = x_real.detach().floor()

    probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins)

    q = pyro.sample("Q_" + name, dist.Categorical(probs),
                    infer={"enumerate": "parallel"})
    q = q.type_as(x_real) - (num_quant_bins // 2 - 1)

    x = lb + q
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return pyro.deterministic(name, x, event_dim=0)
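# A small, self-contained sketch (assumed values, not from the source) of the
# reflection clamping used in both quantize variants above: candidates that
# fall below `min` or above `max` are folded back into the valid range rather
# than truncated, so boundary bins keep their probability mass.
import torch

min_, max_ = torch.tensor(0.0), torch.tensor(10.0)
x = torch.tensor([-1.0, 0.0, 5.0, 11.0])
x = torch.max(x, 2 * min_ - 1 - x)   # -1 is reflected up to 0
x = torch.min(x, 2 * max_ + 1 - x)   # 11 is reflected down to 10
print(x)  # tensor([ 0.,  0.,  5., 10.])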
def time_matrix(time, irf, t0, plate, suffix='', scattering=False):
    if suffix != '':
        suffix = f'_{suffix}'

    if scattering:
        lol = pyro.deterministic(f'T_sc', irf / irf.max())
        return lol

    with plate:
        if not scattering:
            tau = pyro.sample(f'tau{suffix}', dist.Gamma(5, 10))[:, np.newaxis]
        else:
            tau = pyro.sample(f'tau{suffix}',
                              dist.Uniform(0.00001, 0.005))[:, np.newaxis]
        T_unbound = pyro.deterministic(f'T_unbound{suffix}', torch.exp(-time / tau))
        T = pyro.deterministic(f'T{suffix}', T_unbound)
        T_convolved = conv1d(T, irf, mode='fft_circular', cut=False)
        if not scattering:
            circular_multiplier = 1 / (1 - torch.exp(-time[-1] / tau))
            circular_multiplier = pyro.deterministic('circular_multiplier',
                                                     circular_multiplier)
            T_convolved = T_convolved * circular_multiplier
        T_convolved = pyro.deterministic(f'T_convolved{suffix}', T_convolved)
        T_scaled = T_convolved / (T_convolved.max(dim=1)[0][:, np.newaxis])
        T_scaled = pyro.deterministic(f'T_scaled{suffix}', T_scaled)
    return T_scaled
def model_full(self, data=None, time=None, fix_time=False):
    components_plate = pyro.plate('components', self.n_components)

    if fix_time:
        t0 = pyro.deterministic('t0', self.t0)
    else:
        t0 = pyro.sample('t0', dist.Normal(loc=self.t0, scale=0.01))

    xi = pyro.sample(
        'xi',
        dist.Exponential(torch.rand(self.n_wavelenghs) / 10).to_event(1))

    irf = Interp1d()(self.time_padded, self.irf_padded, time - t0, None)

    T_scaled = time_matrix(time, irf, t0, components_plate, suffix='')
    S = spectral_matrix_gp(self.n_wavelenghs, components_plate, suffix='')
    ST = pyro.deterministic('ST', S.T @ T_scaled, event_dim=2)

    if self.scattering:
        scattering_plate = pyro.plate('scattering', 1)
        T_sc = time_matrix(time, irf, t0, scattering_plate,
                           suffix='sc', scattering=True)
        S_sc = spectral_matrix_unscaled(
            self.n_wavelenghs, scattering_plate, suffix='sc') * 20
        ST_sc = pyro.deterministic('ST_sc', S_sc.T @ T_sc, event_dim=2)
        ST = ST + ST_sc

    if data is not None:
        data = data[self.time_slice][self.wavelength_slice]

    pyro.sample(
        'I_obs',
        dist.Poisson(
            (ST.T + xi)[self.time_slice][self.wavelength_slice]).to_event(2),
        obs=data)
    pyro.deterministic('I', ST.T + xi, event_dim=2)
def spectral_matrix(n_points, plate, suffix=''):
    if suffix != '':
        suffix = f'_{suffix}'
    with plate:
        S_unscaled = pyro.sample(
            f'S_unscaled{suffix}',
            dist.Exponential(torch.rand(n_points)).to_event(1))
        S = pyro.deterministic(
            f'S{suffix}',
            S_unscaled / S_unscaled.max(dim=1)[0][:, np.newaxis],
            event_dim=1)
    return S
def forward(self, data):
    N_pop = 300
    p1 = self.ode_params1.view((-1, ))
    p2 = self.ode_params2.view((-1, ))
    p3 = self.ode_params3.view((-1, ))
    R0 = pyro.deterministic('R0', torch.zeros_like(p1))
    ode_params = torch.stack([p1, p2, p3, 1 - p3, R0], dim=1)
    SIR_sim = self._ode_op.apply(ode_params, (self._ode_model, ))
    for i in range(len(data)):
        pyro.sample("obs_{}".format(i),
                    dist.Poisson(SIR_sim[..., i, 1] * N_pop),
                    obs=data[i])
    return SIR_sim
def model(self, raw_expr, encoded_expr, read_depth):
    pyro.module("decoder", self.decoder)

    with pyro.plate("genes", self.num_genes):
        dispersion = pyro.sample(
            "dispersion",
            dist.Gamma(
                torch.tensor(2.).to(self.device),
                torch.tensor(0.5).to(self.device)))
        psi = pyro.sample(
            "dropout",
            dist.Beta(
                torch.tensor(1.).to(self.device),
                torch.tensor(10.).to(self.device)))

    # pyro.module("decoder", self.decoder)
    with pyro.plate("cells", encoded_expr.shape[0]):
        # Dirichlet prior 𝑝(𝜃|𝛼) is replaced by a log-normal distribution
        theta_loc = self.prior_mu * encoded_expr.new_ones(
            (encoded_expr.shape[0], self.num_topics))
        theta_scale = self.prior_std * encoded_expr.new_ones(
            (encoded_expr.shape[0], self.num_topics))
        theta = pyro.sample(
            "theta", dist.LogNormal(theta_loc, theta_scale).to_event(1))
        theta = theta / theta.sum(-1, keepdim=True)

        # conditional distribution of 𝑤𝑛 is defined as
        # 𝑤𝑛|𝛽,𝜃 ~ Categorical(𝜎(𝛽𝜃))
        expr_rate = pyro.deterministic("expr_rate", self.decoder(theta))

        mu = torch.multiply(read_depth, expr_rate)
        p = torch.minimum(mu / (mu + dispersion), self.max_prob)

        pyro.sample(
            'obs',
            dist.ZeroInflatedNegativeBinomial(
                total_count=dispersion, probs=p, gate=psi).to_event(1),
            obs=raw_expr)
def pyrocov_model(dataset):
    # Tensor shapes are commented at the end of some lines.
    features = dataset["features"]
    local_time = dataset["local_time"][..., None]  # [T, P, 1]
    T, P, _ = local_time.shape
    S, F = features.shape
    weekly_strains = dataset["weekly_strains"]
    assert weekly_strains.shape == (T, P, S)

    # Sample global random variables.
    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None]
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None]
    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None]
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None]

    # Assume relative growth rate depends strongly on mutations and weakly on place.
    coef_loc = torch.zeros(F)
    coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1))  # [F]
    rate_loc = pyro.deterministic(
        "rate_loc", 0.01 * coef @ features.T, event_dim=1
    )  # [S]

    # Assume initial infections depend strongly on strain and place.
    init_loc = pyro.sample(
        "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1)
    )  # [S]
    with pyro.plate("place", P, dim=-1):
        rate = pyro.sample(
            "rate", dist.Normal(rate_loc, rate_scale).to_event(1)
        )  # [P, S]
        init = pyro.sample(
            "init", dist.Normal(init_loc, init_scale).to_event(1)
        )  # [P, S]

        # Finally observe counts.
        with pyro.plate("time", T, dim=-2):
            logits = init + rate * local_time  # [T, P, S]
            pyro.sample(
                "obs",
                dist.Multinomial(logits=logits, validate_args=False),
                obs=weekly_strains,
            )
def model(data):
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", NonreparameterizedNormal(a, 0))
    c = pyro.sample("c", dist.Normal(b, 1))
    d = pyro.sample("d", dist.Normal(a, c.exp()))
    e = pyro.sample("e", dist.Normal(0, 1))
    f = pyro.sample("f", dist.Normal(0, 1))
    g = pyro.sample("g", dist.Bernoulli(logits=e + f), obs=torch.tensor(0.0))
    with pyro.plate("p", len(data)):
        d_ = d.detach()  # this results in a known failure
        h = pyro.sample("h", dist.Normal(c, d_.exp()))
        i = pyro.deterministic("i", h + 1)
        j = pyro.sample("j", dist.Delta(h + 1), obs=h + 1)
        k = pyro.sample("k", dist.Normal(a, j.exp()), obs=data)
    return [a, b, c, d, e, f, g, h, i, j, k]
def forward(self): with pyro.plate("gene_weights", self.G): b = pyro.sample("b", dist.Normal(-10.,3.)) theta = pyro.sample("theta", dist.Gamma(2., 0.5)) psi = pyro.sample("dropout", dist.Beta(1., 10.)) with pyro.plate("topic-gene_weights", self.K): beta = pyro.sample("beta", dist.Gamma(1., 5.)) with pyro.plate("gene", self.G) as gene: with pyro.plate("data", self.N, subsample_size=64) as ind: expr_rate = pyro.deterministic("rate", torch.matmul(self.cell_topics.index_select(0, ind), beta) + b) mu = torch.reshape(self.read_depth, (-1,1)).index_select(0, ind) * torch.exp(expr_rate) p = torch.minimum(mu / (mu + theta), torch.tensor([0.99999])) pyro.sample("obs", dist.ZeroInflatedNegativeBinomial(total_count=theta, probs=p, gate = psi), obs= self.gene_expr.index_select(0, ind))
def model(self, x_data, idx=None):
    # ===================== Gene expression level scaling ======================= #
    # Explains difference in expression between genes and
    # how it differs in single cell and spatial technology
    # compute hyperparameters from mean and sd
    gl_alpha_shape = self.gl_shape**2 / self.gl_shape_var
    gl_alpha_rate = self.gl_shape / self.gl_shape_var
    gl_beta_shape = self.gl_rate**2 / self.gl_rate_var
    gl_beta_rate = self.gl_rate / self.gl_rate_var

    self.gene_level_alpha_hyp = pyro.sample(
        'gene_level_alpha_hyp',
        dist.Gamma(
            torch.ones([1, 1]) * torch.tensor(gl_alpha_shape),
            torch.ones([1, 1]) * torch.tensor(gl_alpha_rate)))
    self.gene_level_beta_hyp = pyro.sample(
        'gene_level_beta_hyp',
        dist.Gamma(
            torch.ones([1, 1]) * torch.tensor(gl_beta_shape),
            torch.ones([1, 1]) * torch.tensor(gl_beta_rate)))
    self.gene_level = pyro.sample(
        'gene_level',
        dist.Gamma(
            torch.ones([self.n_var, 1]) * self.gene_level_alpha_hyp,
            torch.ones([self.n_var, 1]) * self.gene_level_beta_hyp))

    # scale cell state factors by gene_level
    self.gene_factors = pyro.deterministic('gene_factors', self.cell_state)

    # ===================== Spot factors ======================= #
    # prior on spot factors reflects the number of cells, fraction of their cytoplasm captured,
    # times heterogeneity in the total number of mRNA between individual cells of each cell type
    cps_shape = self.cell_number_prior['cells_per_spot'] ** 2 \
        / (self.cell_number_prior['cells_per_spot'] / self.cell_number_prior['cells_mean_var_ratio'])
    cps_rate = self.cell_number_prior['cells_per_spot'] \
        / (self.cell_number_prior['cells_per_spot'] / self.cell_number_prior['cells_mean_var_ratio'])
    self.cells_per_spot = pyro.sample(
        'cells_per_spot',
        dist.Gamma(
            torch.ones([self.n_obs, 1]) * torch.tensor(cps_shape),
            torch.ones([self.n_obs, 1]) * torch.tensor(cps_rate)))

    fps_shape = self.cell_number_prior['factors_per_spot'] ** 2 \
        / (self.cell_number_prior['factors_per_spot'] / self.cell_number_prior['factors_mean_var_ratio'])
    fps_rate = self.cell_number_prior['factors_per_spot'] \
        / (self.cell_number_prior['factors_per_spot'] / self.cell_number_prior['factors_mean_var_ratio'])
    self.factors_per_spot = pyro.sample(
        'factors_per_spot',
        dist.Gamma(
            torch.ones([self.n_obs, 1]) * torch.tensor(fps_shape),
            torch.ones([self.n_obs, 1]) * torch.tensor(fps_rate)))

    shape = self.factors_per_spot / torch.tensor(
        np.array(self.n_fact).reshape((1, 1)))
    rate = torch.ones([1, 1]) / self.cells_per_spot * self.factors_per_spot
    self.spot_factors = pyro.sample(
        'spot_factors',
        dist.Gamma(torch.matmul(shape, torch.ones([1, self.n_fact])),
                   torch.matmul(rate, torch.ones([1, self.n_fact]))))

    # ===================== Spot-specific additive component ======================= #
    # molecule contribution that cannot be explained by cell state signatures
    # these counts are distributed between all genes, not just expressed genes
    self.spot_add_hyp = pyro.sample(
        'spot_add_hyp',
        dist.Gamma(
            torch.ones([2, 1]) * torch.tensor(1.),
            torch.ones([2, 1]) * torch.tensor(0.1)))
    self.spot_add = pyro.sample(
        'spot_add',
        dist.Gamma(
            torch.ones([self.n_obs, 1]) * self.spot_add_hyp[0, 0],
            torch.ones([self.n_obs, 1]) * self.spot_add_hyp[1, 0]))

    # ===================== Gene-specific additive component ======================= #
    # per gene molecule contribution that cannot be explained by cell state signatures
    # these counts are distributed equally between all spots (e.g. background, free-floating RNA)
    self.gene_add_hyp = pyro.sample(
        'gene_add_hyp',
        dist.Gamma(
            torch.ones([2, 1]) * torch.tensor(1.),
            torch.ones([2, 1]) * torch.tensor(1.)))
    self.gene_add = pyro.sample(
        'gene_add',
        dist.Gamma(
            torch.ones([self.n_var, 1]) * self.gene_add_hyp[0, 0],
            torch.ones([self.n_var, 1]) * self.gene_add_hyp[1, 0]))

    # ===================== Gene-specific overdispersion ======================= #
    self.phi_hyp = pyro.sample(
        'phi_hyp',
        dist.Gamma(
            torch.ones([1, 1]) * torch.tensor(self.phi_hyp_prior['mean']),
            torch.ones([1, 1]) * torch.tensor(self.phi_hyp_prior['sd'])))
    self.gene_E = pyro.sample(
        'gene_E',
        dist.Exponential(torch.ones([self.n_var, 1]) * self.phi_hyp[0, 0]))

    # ===================== Expected expression ======================= #
    # expected expression
    self.mu_biol = torch.matmul(self.spot_factors[idx], self.gene_factors.T) * self.gene_level.T \
        + self.gene_add.T + self.spot_add[idx]

    # ===================== DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    self.data_target = pyro.sample(
        'data_target',
        NegativeBinomial(
            mu=self.mu_biol,
            theta=torch.ones([1, 1]) / (self.gene_E.T * self.gene_E.T)),
        obs=x_data)

    # ===================== Compute nUMI from each factor in spots ======================= #
    nUMI = (self.spot_factors * (self.gene_factors * self.gene_level).sum(0))
    self.nUMI_factors = pyro.deterministic('nUMI_factors', nUMI)
def model(
    s,
    m,
    y=None,
    gamma_hyper=1.0,
    pi0=1.0,
    rho0=1.0,
    epsilon0=0.01,
    alpha0=1000.0,
    dtype=torch.float32,
    device="cpu",
):
    # Cast inputs and set device
    m, gamma_hyper, pi0, rho0, epsilon0, alpha0 = [
        torch.tensor(v, dtype=dtype, device=device)
        for v in [m, gamma_hyper, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        y = torch.tensor(y)

    n, g = m.shape

    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            gamma = pyro.sample(
                "gamma",
                dist.RelaxedBernoulli(temperature=gamma_hyper, logits=0.0),
            )
    # gamma.shape == (s, g)

    rho_hyper = pyro.sample("rho_hyper", dist.Gamma(rho0, 1.0))
    rho = pyro.sample(
        "rho",
        dist.RelaxedOneHotCategorical(
            temperature=rho_hyper,
            logits=torch.zeros(s, dtype=dtype, device=device),
        ),
    )

    epsilon_hyper = pyro.sample("epsilon_hyper", dist.Beta(1.0, 1 / epsilon0))
    alpha_hyper = pyro.sample("alpha_hyper", dist.Gamma(alpha0, 1.0))
    pi_hyper = pyro.sample("pi_hyper", dist.Gamma(pi0, 1.0))

    with pyro.plate("sample", n, dim=-1):
        pi = pyro.sample(
            "pi",
            dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho),
        )
        alpha = pyro.sample("alpha", dist.Gamma(alpha_hyper, 1.0)).unsqueeze(-1)
        epsilon = pyro.sample(
            "epsilon", dist.Beta(1.0, 1 / epsilon_hyper)
        ).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n,)

    p_noerr = pyro.deterministic("p_noerr", pi @ gamma)
    p = pyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

    y = pyro.sample(
        "y",
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y,
    )
    # y.shape == (n, g)
    return y
def __call__(self):
    """
    Notes
    -----
    Labeling system:
    1. for kernel-level parameters such as rho, span, knots, kernel etc.,
       use the suffixes _lev and _coef to partition levels and regression
    2. for knot-level parameters such as coef, loc and scale priors,
       use the prefixes _lev and _rr/_pr to partition levels, regular and
       positive regressors
    3. reduce ambiguity by replacing all greeks with more intuitive labels:
       use _coef, _weight etc. instead of _beta, use _scale instead of _sigma
    """
    response = self.response
    which_valid = self.which_valid_res

    n_obs = self.n_obs
    # n_valid = self.n_valid_res
    sdy = self.sdy
    meany = self.mean_y
    dof = self.dof
    lev_knot_loc = self.lev_knot_loc
    seas_term = self.seas_term

    pr = self.pr
    rr = self.rr
    n_pr = self.n_pr
    n_rr = self.n_rr

    k_lev = self.k_lev
    k_coef = self.k_coef
    n_knots_lev = self.n_knots_lev
    n_knots_coef = self.n_knots_coef

    lev_knot_scale = self.lev_knot_scale
    # mult var norm stuff
    mvn = self.mvn
    geometric_walk = self.geometric_walk
    min_residuals_sd = self.min_residuals_sd
    if min_residuals_sd > 1.0:
        min_residuals_sd = torch.tensor(1.0)
    if min_residuals_sd < 0:
        min_residuals_sd = torch.tensor(0.0)
    # expand dim to n_rr x n_knots_coef
    rr_init_knot_loc = self.rr_init_knot_loc
    rr_init_knot_scale = self.rr_init_knot_scale
    rr_knot_scale = self.rr_knot_scale

    # this does not need to expand dim since it is used as latent grand mean
    pr_init_knot_loc = self.pr_init_knot_loc
    pr_init_knot_scale = self.pr_init_knot_scale
    pr_knot_scale = self.pr_knot_scale

    # transformation of data
    regressors = torch.zeros(n_obs)
    if n_pr > 0 and n_rr > 0:
        regressors = torch.cat([rr, pr], dim=-1)
    elif n_pr > 0:
        regressors = pr
    elif n_rr > 0:
        regressors = rr

    response_tran = response - meany - seas_term

    # sampling begins here
    extra_out = {}

    # levels sampling
    lev_knot_tran = pyro.sample(
        "lev_knot_tran",
        dist.Normal(lev_knot_loc - meany,
                    lev_knot_scale).expand([n_knots_lev]).to_event(1))
    lev = (lev_knot_tran @ k_lev.transpose(-2, -1))

    # using hierarchical priors vs. multivariate priors
    if mvn == 0:
        # regular regressor sampling
        if n_rr > 0:
            # pooling latent variables
            rr_init_knot = pyro.sample(
                "rr_init_knot",
                dist.Normal(rr_init_knot_loc, rr_init_knot_scale).to_event(1))
            rr_knot = pyro.sample(
                "rr_knot",
                dist.Normal(
                    rr_init_knot.unsqueeze(-1) * torch.ones(n_rr, n_knots_coef),
                    rr_knot_scale).to_event(2))
            rr_coef = (rr_knot @ k_coef.transpose(-2, -1)).transpose(-2, -1)

        # positive regressor sampling
        if n_pr > 0:
            if geometric_walk:
                # TODO: development method
                pr_init_knot = pyro.sample(
                    "pr_init_knot",
                    dist.FoldedDistribution(
                        dist.Normal(pr_init_knot_loc,
                                    pr_init_knot_scale)).to_event(1))
                pr_knot_step = pyro.sample(
                    "pr_knot_step",
                    # note that unlike rr_knot, the first one is ignored as we use the initial scale
                    # to sample the first knot
                    dist.Normal(torch.zeros(n_pr, n_knots_coef),
                                pr_knot_scale).to_event(2))
                pr_knot = pr_init_knot.unsqueeze(-1) * pr_knot_step.cumsum(-1).exp()
                pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(-2, -1)
            else:
                # TODO: original method
                # pooling latent variables
                pr_init_knot = pyro.sample(
                    "pr_knot_loc",
                    dist.FoldedDistribution(
                        dist.Normal(pr_init_knot_loc,
                                    pr_init_knot_scale)).to_event(1))
                pr_knot = pyro.sample(
                    "pr_knot",
                    dist.FoldedDistribution(
                        dist.Normal(
                            pr_init_knot.unsqueeze(-1) *
                            torch.ones(n_pr, n_knots_coef),
                            pr_knot_scale)).to_event(2))
                pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(-2, -1)
    else:
        # regular regressor sampling
        if n_rr > 0:
            rr_init_knot = pyro.deterministic(
                "rr_init_knot", torch.zeros(rr_init_knot_loc.shape))

            # updated mod
            loc_temp = rr_init_knot_loc.unsqueeze(-1) * torch.ones(n_rr, n_knots_coef)
            scale_temp = torch.diag_embed(
                rr_init_knot_scale.unsqueeze(-1) * torch.ones(n_rr, n_knots_coef))

            # the sampling
            rr_knot = pyro.sample(
                "rr_knot",
                dist.MultivariateNormal(
                    loc=loc_temp, covariance_matrix=scale_temp).to_event(1))
            rr_coef = (rr_knot @ k_coef.transpose(-2, -1)).transpose(-2, -1)

        # positive regressor sampling
        if n_pr > 0:
            # this part is junk just so that the pr_init_knot has a prior;
            # but it does not connect to anything else
            # pooling latent variables
            pr_init_knot = pyro.sample(
                "pr_init_knot",
                dist.FoldedDistribution(
                    dist.Normal(pr_init_knot_loc,
                                pr_init_knot_scale)).to_event(1))

            # updated mod
            loc_temp = pr_init_knot_loc.unsqueeze(-1) * torch.ones(n_pr, n_knots_coef)
            scale_temp = torch.diag_embed(
                pr_init_knot_scale.unsqueeze(-1) * torch.ones(n_pr, n_knots_coef))

            pr_knot = pyro.sample(
                "pr_knot",
                dist.MultivariateNormal(
                    loc=loc_temp, covariance_matrix=scale_temp).to_event(1))
            pr_knot = torch.exp(pr_knot)
            pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(-2, -1)

    # concatenating all latent variables
    coef_init_knot = torch.zeros(n_rr + n_pr)
    coef_knot = torch.zeros((n_rr + n_pr, n_knots_coef))
    coef = torch.zeros(n_obs)
    if n_pr > 0 and n_rr > 0:
        coef_knot = torch.cat([rr_knot, pr_knot], dim=-2)
        coef_init_knot = torch.cat([rr_init_knot, pr_init_knot], dim=-1)
        coef = torch.cat([rr_coef, pr_coef], dim=-1)
    elif n_pr > 0:
        coef_knot = pr_knot
        coef_init_knot = pr_init_knot
        coef = pr_coef
    elif n_rr > 0:
        coef_knot = rr_knot
        coef_init_knot = rr_init_knot
        coef = rr_coef

    # coefficients likelihood/priors
    coef_prior_list = self.coef_prior_list
    if coef_prior_list:
        for x in coef_prior_list:
            name = x['name']
            # TODO: we can move torch conversion to init to enhance speed
            m = torch.tensor(x['prior_mean'])
            sd = torch.tensor(x['prior_sd'])
            # tp = torch.tensor(x['prior_tp_idx'])
            # idx = torch.tensor(x['prior_regressor_col_idx'])
            start_tp_idx = x['prior_start_tp_idx']
            end_tp_idx = x['prior_end_tp_idx']
            idx = x['prior_regressor_col_idx']
            pyro.sample(
                "prior_{}".format(name),
                dist.Normal(m, sd).to_event(2),
                obs=coef[..., start_tp_idx:end_tp_idx, idx])

    # observation likelihood
    yhat = lev + (regressors * coef).sum(-1)
    obs_scale_base = pyro.sample("obs_scale_base", dist.Beta(2, 2)).unsqueeze(-1)
    # from 0.5 * sdy to sdy
    obs_scale = ((obs_scale_base * (1.0 - min_residuals_sd)) + min_residuals_sd) * sdy

    # with pyro.plate("response_plate", n_valid):
    #     pyro.sample("response",
    #                 dist.StudentT(dof, yhat[..., which_valid], obs_scale),
    #                 obs=response_tran[which_valid])

    pyro.sample(
        "response",
        dist.StudentT(dof, yhat[..., which_valid], obs_scale).to_event(1),
        obs=response_tran[which_valid])

    lev_knot = lev_knot_tran + meany

    extra_out.update({
        'yhat': yhat + seas_term + meany,
        'lev': lev + meany,
        'lev_knot': lev_knot,
        'coef': coef,
        'coef_knot': coef_knot,
        'coef_init_knot': coef_init_knot,
        'obs_scale': obs_scale,
    })
    return extra_out
def forward(self, x_data, idx, batch_index):
    obs2sample = one_hot(batch_index, self.n_batch)

    obs_plate = self.create_plates(x_data, idx, batch_index)

    # ===================== Cell abundances w_sf ======================= #
    # factorisation prior on w_sf models similarity in locations
    # across cell types f and reflects the absolute scale of w_sf
    with obs_plate:
        n_s_cells_per_location = pyro.sample(
            "n_s_cells_per_location",
            dist.Gamma(
                self.N_cells_per_location * self.N_cells_mean_var_ratio,
                self.N_cells_mean_var_ratio,
            ),
        )
        y_s_groups_per_location = pyro.sample(
            "y_s_groups_per_location",
            dist.Gamma(self.Y_groups_per_location, self.ones),
        )

    # cell group loadings
    shape = self.ones_1_n_groups * y_s_groups_per_location / self.n_groups_tensor
    rate = self.ones_1_n_groups / (n_s_cells_per_location / y_s_groups_per_location)
    with obs_plate:
        z_sr_groups_factors = pyro.sample(
            "z_sr_groups_factors",
            dist.Gamma(shape, rate),  # .to_event(1)#.expand([self.n_groups]).to_event(1)
        )  # (n_obs, n_groups)

    k_r_factors_per_groups = pyro.sample(
        "k_r_factors_per_groups",
        dist.Gamma(self.factors_per_groups,
                   self.ones).expand([self.n_groups, 1]).to_event(2),
    )  # (self.n_groups, 1)

    c2f_shape = k_r_factors_per_groups / self.n_factors_tensor

    x_fr_group2fact = pyro.sample(
        "x_fr_group2fact",
        dist.Gamma(c2f_shape, k_r_factors_per_groups).expand(
            [self.n_groups, self.n_factors]).to_event(2),
    )  # (self.n_groups, self.n_factors)

    with obs_plate:
        w_sf_mu = z_sr_groups_factors @ x_fr_group2fact
        w_sf = pyro.sample(
            "w_sf",
            dist.Gamma(
                w_sf_mu * self.w_sf_mean_var_ratio_tensor,
                self.w_sf_mean_var_ratio_tensor,
            ),
        )  # (self.n_obs, self.n_factors)

    # ===================== Location-specific detection efficiency ======================= #
    # y_s with hierarchical mean prior
    detection_mean_y_e = pyro.sample(
        "detection_mean_y_e",
        dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )
    detection_hyp_prior_alpha = pyro.deterministic(
        "detection_hyp_prior_alpha",
        self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
    )

    beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e)
    with obs_plate:
        detection_y_s = pyro.sample(
            "detection_y_s",
            dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta),
        )  # (self.n_obs, 1)

    # ===================== Gene-specific additive component ======================= #
    # per gene molecule contribution that cannot be explained by
    # cell state signatures (e.g. background, free-floating RNA)
    s_g_gene_add_alpha_hyp = pyro.sample(
        "s_g_gene_add_alpha_hyp",
        dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                   self.gene_add_alpha_hyp_prior_beta),
    )
    s_g_gene_add_mean = pyro.sample(
        "s_g_gene_add_mean",
        dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        ).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e_inv = pyro.sample(
        "s_g_gene_add_alpha_e_inv",
        dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1]).to_event(2),
    )  # (self.n_batch)
    s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2)

    s_g_gene_add = pyro.sample(
        "s_g_gene_add",
        dist.Gamma(s_g_gene_add_alpha_e,
                   s_g_gene_add_alpha_e / s_g_gene_add_mean).expand(
                       [self.n_batch, self.n_vars]).to_event(2),
    )  # (self.n_batch, n_vars)

    # ===================== Gene-specific overdispersion ======================= #
    alpha_g_phi_hyp = pyro.sample(
        "alpha_g_phi_hyp",
        dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                   self.alpha_g_phi_hyp_prior_beta),
    )
    alpha_g_inverse = pyro.sample(
        "alpha_g_inverse",
        dist.Exponential(alpha_g_phi_hyp).expand(
            [self.n_batch, self.n_vars]).to_event(2),
    )  # (self.n_batch, self.n_vars)

    # ===================== Expected expression ======================= #
    # expected expression
    mu = ((w_sf @ self.cell_state) + (obs2sample @ s_g_gene_add)) * detection_y_s
    alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
    # convert mean and overdispersion to total count and logits
    # total_count, logits = _convert_mean_disp_to_counts_logits(
    #     mu, alpha, eps=self.eps
    # )

    # ===================== DATA likelihood ======================= #
    # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
    with obs_plate:
        pyro.sample(
            "data_target",
            dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
            # dist.NegativeBinomial(total_count=total_count, logits=logits),
            obs=x_data,
        )

    # ===================== Compute mRNA count from each factor in locations ======================= #
    with obs_plate:
        mRNA = w_sf * (self.cell_state).sum(-1)
        pyro.deterministic("u_sf_mRNA_factors", mRNA)
def __call__(self):
    response = self.response
    num_of_obs = self.num_of_obs
    extra_out = {}

    # smoothing params
    if self.lev_sm_input < 0:
        lev_sm = pyro.sample("lev_sm", dist.Uniform(0, 1))
    else:
        lev_sm = torch.tensor(self.lev_sm_input, dtype=torch.double)
        extra_out['lev_sm'] = lev_sm
    if self.slp_sm_input < 0:
        slp_sm = pyro.sample("slp_sm", dist.Uniform(0, 1))
    else:
        slp_sm = torch.tensor(self.slp_sm_input, dtype=torch.double)
        extra_out['slp_sm'] = slp_sm

    # residual tuning parameters
    nu = pyro.sample("nu", dist.Uniform(self.min_nu, self.max_nu))

    # prior for residuals
    obs_sigma = pyro.sample("obs_sigma", dist.HalfCauchy(self.cauchy_sd))

    # regression parameters
    if self.num_of_pr == 0:
        pr = torch.zeros(num_of_obs)
        pr_beta = pyro.deterministic("pr_beta", torch.zeros(0))
    else:
        with pyro.plate("pr", self.num_of_pr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                pr_sigma = self.pr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                pr_sigma = pyro.sample("pr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                pr_beta = pyro.sample(
                    "pr_beta",
                    dist.FoldedDistribution(
                        dist.Normal(self.pr_beta_prior, pr_sigma)))
            else:
                pr_beta = pyro.sample(
                    "pr_beta",
                    dist.FoldedDistribution(
                        dist.Laplace(self.pr_beta_prior, self.lasso_scale)))
        pr = pr_beta @ self.pr_mat.transpose(-1, -2)

    if self.num_of_nr == 0:
        nr = torch.zeros(num_of_obs)
        nr_beta = pyro.deterministic("nr_beta", torch.zeros(0))
    else:
        with pyro.plate("nr", self.num_of_nr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                nr_sigma = self.nr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                nr_sigma = pyro.sample("nr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                nr_beta = pyro.sample(
                    "nr_beta",
                    dist.FoldedDistribution(
                        dist.Normal(self.nr_beta_prior, nr_sigma)))
            else:
                nr_beta = pyro.sample(
                    "nr_beta",
                    dist.FoldedDistribution(
                        dist.Laplace(self.nr_beta_prior, self.lasso_scale)))
        nr = nr_beta @ self.nr_mat.transpose(-1, -2)

    if self.num_of_rr == 0:
        rr = torch.zeros(num_of_obs)
        rr_beta = pyro.deterministic("rr_beta", torch.zeros(0))
    else:
        with pyro.plate("rr", self.num_of_rr):
            # fixed scale ridge
            if self.reg_penalty_type == 0:
                rr_sigma = self.rr_sigma_prior
            # auto scale ridge
            elif self.reg_penalty_type == 2:
                # weak prior for sigma
                rr_sigma = pyro.sample("rr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
            # case when it is not lasso
            if self.reg_penalty_type != 1:
                # weak prior for betas
                rr_beta = pyro.sample("rr_beta", dist.Normal(self.rr_beta_prior, rr_sigma))
            else:
                rr_beta = pyro.sample("rr_beta", dist.Laplace(self.rr_beta_prior, self.lasso_scale))
        rr = rr_beta @ self.rr_mat.transpose(-1, -2)

    # a hack to make sure we don't use a dimension "1" due to rr_beta and pr_beta sampling
    r = pr + nr + rr
    if r.dim() > 1:
        r = r.unsqueeze(-2)

    # trend parameters
    # local trend proportion
    lt_coef = pyro.sample("lt_coef", dist.Uniform(0, 1))
    # global trend proportion
    gt_coef = pyro.sample("gt_coef", dist.Uniform(-0.5, 0.5))
    # global trend parameter
    gt_pow = pyro.sample("gt_pow", dist.Uniform(0, 1))

    # seasonal parameters
    if self.is_seasonal:
        # seasonality smoothing parameter
        if self.sea_sm_input < 0:
            sea_sm = pyro.sample("sea_sm", dist.Uniform(0, 1))
        else:
            sea_sm = torch.tensor(self.sea_sm_input, dtype=torch.double)
            extra_out['sea_sm'] = sea_sm

        # initial seasonality
        # 33% lift is with 1 sd prob.
        init_sea = pyro.sample(
            "init_sea",
            dist.Normal(0, 0.33).expand([self.seasonality]).to_event(1))
        init_sea = init_sea - init_sea.mean(-1, keepdim=True)

    b = [None] * num_of_obs  # slope
    l = [None] * num_of_obs  # level
    if self.is_seasonal:
        s = [None] * (self.num_of_obs + self.seasonality)
        for t in range(self.seasonality):
            s[t] = init_sea[..., t]
        s[self.seasonality] = init_sea[..., 0]
    else:
        s = [torch.tensor(0.)] * num_of_obs

    # states initial condition
    b[0] = torch.zeros_like(slp_sm)
    if self.is_seasonal:
        l[0] = response[0] - r[..., 0] - s[0]
    else:
        l[0] = response[0] - r[..., 0]

    # update process
    for t in range(1, num_of_obs):
        # this update equation uses l[t-1] ONLY.
        # intentionally different from the Holt-Winters form;
        # this change is suggested from Slawek's original SLGT model
        l[t] = lev_sm * (response[t] - s[t] - r[..., t]) + (1 - lev_sm) * l[t - 1]
        b[t] = slp_sm * (l[t] - l[t - 1]) + (1 - slp_sm) * b[t - 1]
        if self.is_seasonal:
            s[t + self.seasonality] = \
                sea_sm * (response[t] - l[t] - r[..., t]) + (1 - sea_sm) * s[t]

    # evaluation process
    # vectorize as much math as possible
    for lst in [b, l, s]:
        # torch.stack requires all items to have the same shape, but the
        # initial items of our lists may not have batch_shape, so we expand.
        lst[0] = lst[0].expand_as(lst[-1])

    b = torch.stack(b, dim=-1).reshape(b[0].shape[:-1] + (-1, ))
    l = torch.stack(l, dim=-1).reshape(l[0].shape[:-1] + (-1, ))
    s = torch.stack(s, dim=-1).reshape(s[0].shape[:-1] + (-1, ))

    lgt_sum = l + gt_coef * l.abs() ** gt_pow + lt_coef * b
    lgt_sum = torch.cat([l[..., :1], lgt_sum[..., :-1]], dim=-1)  # shift by 1

    # a hack here as well to get rid of the extra "1" in r.shape
    if r.dim() >= 2:
        r = r.squeeze(-2)
    yhat = lgt_sum + s[..., :num_of_obs] + r

    with pyro.plate("response_plate", num_of_obs - 1):
        pyro.sample("response",
                    dist.StudentT(nu, yhat[..., 1:], obs_sigma),
                    obs=response[1:])

    # we care about beta, not pr_beta, nr_beta, ...
    extra_out['beta'] = torch.cat([pr_beta, nr_beta, rr_beta], dim=-1)

    extra_out.update({'b': b, 'l': l, 's': s, 'lgt_sum': lgt_sum})
    return extra_out