def test_categorical_gradient_with_logits(init_tensor_type):
    p = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
    categorical = OneHotCategorical(logits=p)
    log_pdf = categorical.batch_log_pdf(Variable(init_tensor_type([0, 1])))
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(p.grad.data[0], 0)
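# Added example (not from the source). The test above uses an older Pyro API
# (Variable, batch_log_pdf, assert_equal). A minimal sketch of the same
# property with the current torch.distributions API, assuming a recent
# PyTorch install: a -inf logit gives category 0 zero probability, so scoring
# category 1 yields log-prob 0 and a zero gradient on the -inf logit.
import torch

def check_onehot_categorical_inf_logit():
    logits = torch.tensor([-float('inf'), 0.0], requires_grad=True)
    d = torch.distributions.OneHotCategorical(logits=logits)
    log_prob = d.log_prob(torch.tensor([0.0, 1.0]))  # one-hot for category 1
    log_prob.sum().backward()
    assert log_prob.item() == 0.0        # log(1) == 0
    assert logits.grad[0].item() == 0.0  # no gradient through the -inf logit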
def main(xs, ys=None):
    """
    The model corresponds to the following generative process:
    p(z) = normal(0,I)              # handwriting style (latent)
    p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
    p(x|y,z) = bernoulli(loc(y,z))  # an image
    loc is given by a neural network `decoder`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels, i.e. the digit
               corresponding to the image(s)
    :return: None
    """
    xs = torch.reshape(xs, [200, 784])  # WL: ok for analyser? =====
    # ys = ...
    # ==========================

    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("decoder_fst", decoder_fst)
    pyro.module("decoder_snd", decoder_snd)

    # batch_size = xs.size(0)
    # batch_size = 200
    # z_dim = 50
    # output_size = 10
    with pyro.plate("data"):
        # sample the handwriting style from the constant prior distribution
        prior_loc = torch.zeros([200, 50])
        prior_scale = torch.ones([200, 50])
        zs = pyro.sample("z", Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is unsupervised, sample it from
        # the constant prior; otherwise, observe the given value (i.e. score it
        # against the constant prior)
        alpha_prior = torch.ones([200, 10]) / (1.0 * 10)
        # WL: edited. =====
        # ys = pyro.sample("y", OneHotCategorical(alpha_prior), obs=ys)
        if ys is None:
            ys = pyro.sample("y", OneHotCategorical(alpha_prior))
        else:
            ys = pyro.sample("y", OneHotCategorical(alpha_prior), obs=ys)
        # ================

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        hidden = softplus(decoder_fst(torch.cat([zs, ys], -1)))
        loc = sigmoid(decoder_snd(hidden))
        pyro.sample("x", Bernoulli(loc).to_event(1), obs=xs)
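# Added sketch (not from the source). The model above uses `decoder_fst`,
# `decoder_snd`, `softplus` and `sigmoid` as module-level globals. One
# plausible set of definitions consistent with the shapes it uses (50-dim z,
# 10 classes, 784 pixels); the hidden width of 500 is an assumption, not
# taken from the source.
import torch.nn as nn

decoder_fst = nn.Linear(50 + 10, 500)  # [z, y] -> hidden
decoder_snd = nn.Linear(500, 784)      # hidden -> per-pixel Bernoulli loc
softplus = nn.Softplus()
sigmoid = nn.Sigmoid()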
def label_variable(self, label):
    new_label = []
    options = {'device': label.device, 'dtype': label.dtype}
    for i, length in enumerate(self.latents_sizes):
        prior = torch.ones(label.shape[0], length, **options) / (1.0 * length)
        new_label.append(
            pyro.sample(
                "label_" + str(self.latents_names[i]),
                OneHotCategorical(prior),
                obs=one_hot(tensor(label[:, i], dtype=torch.int64),
                            int(length))))
    new_label = torch.cat(new_label, -1)
    return new_label.to(torch.float32).to(label.device)
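# Added illustration (not from the source). What each latent's observation
# looks like: `one_hot` turns the i-th integer column of `label` into a
# one-hot matrix, which is then scored against the uniform OneHotCategorical
# prior above.
import torch
from torch.nn.functional import one_hot

col = torch.tensor([0, 2, 1])       # one latent's integer codes, batch of 3
print(one_hot(col, num_classes=3))  # tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]])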
def main(xs, ys=None):
    """
    The guide corresponds to the following:
    q(y|x) = categorical(alpha(x))          # infer digit from an image
    q(z|x,y) = normal(loc(x,y),scale(x,y))  # infer handwriting style from an image and the digit
    loc, scale are given by a neural network `encoder_z`
    alpha is given by a neural network `encoder_y`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels, i.e. the digit
               corresponding to the image(s)
    :return: None
    """
    xs = torch.reshape(xs, [200, 784])  # WL: ok for analyser? =====
    # ys = ...
    # ==========================

    pyro.module("encoder_y_fst", encoder_y_fst)
    pyro.module("encoder_y_snd", encoder_y_snd)
    pyro.module("encoder_z_fst", encoder_z_fst)
    pyro.module("encoder_z_out1", encoder_z_out1)
    pyro.module("encoder_z_out2", encoder_z_out2)

    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            hidden = softplus(encoder_y_fst(xs))
            alpha = softmax(encoder_y_snd(hidden))
            ys = pyro.sample("y", OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        # shape = broadcast_shape(torch.Size([200]), ys.shape[:-1]) + (-1,)  # WL: ok for analyser? =====
        shape = ys.shape[:-1] + (-1,)
        hidden_z = softplus(
            encoder_z_fst(
                torch.cat([
                    torch.Tensor.expand(xs, shape),
                    torch.Tensor.expand(ys, shape)
                ], -1)))
        # ==========================
        loc = encoder_z_out1(hidden_z)
        scale = torch.exp(encoder_z_out2(hidden_z))
        pyro.sample("z", Normal(loc, scale).to_event(1))
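# Added sketch (not from the source). How a model/guide pair like the two
# `main` functions above is typically wired into Pyro's SVI loop. Assumes the
# generative function is bound to `model`, this guide to `guide`, and that
# the networks they reference are defined; `xs` below is a stand-in batch of
# 200 flattened, scaled images, `ys` an optional [200, 10] one-hot label batch.
import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
xs = torch.rand(200, 784)  # stand-in batch of scaled pixel vectors
ys = None                  # or a [200, 10] one-hot label batch
loss = svi.step(xs, ys)    # one gradient step on the ELBO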
def model():
    p = torch.tensor([0.25] * 4)
    pyro.sample("z", OneHotCategorical(probs=p))
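# Added usage note (not from the source). Tracing the model shows what the
# site "z" produces: a one-hot draw over the four equally likely categories.
# Assumes the pyro/torch imports used by the model above.
trace = pyro.poutine.trace(model).get_trace()
print(trace.nodes["z"]["value"])  # e.g. tensor([0., 0., 1., 0.])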
def model(self, x, zs):
    # pylint: disable=too-many-locals, too-many-statements
    dataset = require("dataloader").dataset

    def _compute_rim(decoded):
        shared_representation = get_module(
            "metagene_shared",
            lambda: torch.nn.Sequential(
                torch.nn.Conv2d(
                    decoded.shape[1], decoded.shape[1], kernel_size=1
                ),
                torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ),
        )(decoded)
        rim = torch.cat(
            [
                get_module(
                    f"decoder_{_encode_metagene_name(n)}",
                    partial(
                        self._create_metagene_decoder, decoded.shape[1], n
                    ),
                )(shared_representation)
                for n in self.metagenes
            ],
            dim=1,
        )
        rim = torch.nn.functional.softmax(rim, dim=1)
        return rim

    decoded = self._decode(zs)
    label = center_crop(x["label"], [None, *decoded.shape[-2:]])

    rim = checkpoint(_compute_rim, decoded)
    rim = center_crop(rim, [None, None, *label.shape[-2:]])
    rim = pyro.sample("rim", Delta(rim))

    scale = pyro.sample(
        "scale",
        Delta(
            center_crop(
                self._get_scale_decoder(decoded.shape[1])(decoded),
                [None, None, *label.shape[-2:]],
            )
        ),
    )
    rim = scale * rim

    rate_mg_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "rate_mg_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )
    with pyro.poutine.scale(scale=len(x["data"]) / dataset.size()):
        rate_mg = torch.stack(
            [
                pyro.sample(
                    _encode_metagene_name(n),
                    rate_mg_prior,
                    infer={"is_global": True},
                )
                for n in self.metagenes
            ]
        )
    rate_mg = pyro.sample("rate_mg", Delta(rate_mg))

    rate_g_conditions_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "rate_g_conditions_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )
    logits_g_conditions_prior = Normal(
        0.0,
        1e-8
        + get_param(
            "logits_g_conditions_prior_sd",
            lambda: torch.ones(len(self._allocated_genes)),
            constraint=constraints.positive,
        ),
    )

    rate_g, logits_g = [], []
    for batch_idx, (slide, covariates) in enumerate(
        zip(x["slide"], x["covariates"])
    ):
        rate_g_slide = get_param(
            "rate_g_condition_baseline",
            lambda: self.__init_rate_baseline().log(),
            lr_multiplier=5.0,
        )
        logits_g_slide = get_param(
            "logits_g_condition_baseline",
            self.__init_logits_baseline,
            lr_multiplier=5.0,
        )
        for covariate, condition in covariates.items():
            try:
                conditions = get("covariates")[covariate]
            except KeyError:
                continue
            if pd.isna(condition):
                with pyro.poutine.scale(
                    scale=1.0 / dataset.size(slide=slide)
                ):
                    pyro.sample(
                        f"condition-{covariate}-{batch_idx}",
                        OneHotCategorical(
                            to_device(torch.ones(len(conditions)))
                            / len(conditions)
                        ),
                        infer={"is_global": True},
                    )
                    # ^ NOTE 1: This statement affects the ELBO but not its
                    #           gradient. The pmf is non-differentiable, but
                    #           it doesn't matter: our prior over the
                    #           conditions is uniform; even if a gradient
                    #           existed, it would always be zero.
                    # ^ NOTE 2: The result is used to index the effect of
                    #           the condition. However, this takes place in
                    #           the guide to avoid sampling effects that are
                    #           not used in the current minibatch,
                    #           potentially (?) reducing noise in the
                    #           learning signal. Therefore, the result here
                    #           is discarded.
                condition_scale = 1e-99
                # ^ HACK: Pyro requires scale > 0
            else:
                condition_scale = 1.0 / dataset.size(
                    covariate=covariate, condition=condition
                )
            with pyro.poutine.scale(scale=condition_scale):
                rate_g_slide = rate_g_slide + pyro.sample(
                    f"rate_g_condition-{covariate}-{batch_idx}",
                    rate_g_conditions_prior,
                    infer={"is_global": True},
                )
                logits_g_slide = logits_g_slide + pyro.sample(
                    f"logits_g_condition-{covariate}-{batch_idx}",
                    logits_g_conditions_prior,
                    infer={"is_global": True},
                )
        rate_g.append(rate_g_slide)
        logits_g.append(logits_g_slide)

    logits_g = torch.stack(logits_g)[:, self._gene_indices]
    rate_g = torch.stack(rate_g)[:, self._gene_indices]
    rate_mg = rate_g.unsqueeze(1) + rate_mg[:, self._gene_indices]

    with scope(prefix=self.tag):
        self._sample_image(x, decoded)

    for i, (data, label, rim, rate_mg, logits_g) in enumerate(
        zip(x["data"], label, rim, rate_mg, logits_g)
    ):
        zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
        partial_idxs = np.unique(
            torch.cat([label[0], label[-1], label[:, 0], label[:, -1]])
            .cpu()
            .numpy()
        )
        partial_idxs = np.setdiff1d(
            partial_idxs, zero_count_idxs.cpu().numpy()
        )
        mask = np.invert(
            np.isin(label.cpu().numpy(), [0, *partial_idxs])
        )
        mask = torch.as_tensor(mask, device=label.device)

        if not mask.any():
            continue

        label = label[mask]
        idxs, label = torch.unique(label, return_inverse=True)
        data = data[idxs - 1]
        pyro.sample(f"idx-{i}", Delta(idxs.float()))

        rim = rim[:, mask]
        labelonehot = sparseonehot(label)
        rim = torch.sparse.mm(labelonehot.t().float(), rim.t())

        rsg = rim @ rate_mg.exp()
        expression_distr = NegativeBinomial(
            total_count=1e-8 + rsg, logits=logits_g
        )
        pyro.sample(f"xsg-{i}", expression_distr, obs=data)
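# Added numerical check (not from the source). Small verification of NOTE 1
# above: under a uniform OneHotCategorical prior the log-pmf is the same
# constant, -log(K), for every category, so the site only shifts the ELBO by
# a constant and contributes no gradient.
import torch
from pyro.distributions import OneHotCategorical

K = 3
prior = OneHotCategorical(torch.ones(K) / K)
print(prior.log_prob(torch.eye(K)))  # tensor([-1.0986, -1.0986, -1.0986])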
def model(self, x, zs):
    # pylint: disable=too-many-locals
    def _compute_rim(decoded):
        shared_representation = get_module(
            "metagene_shared",
            lambda: torch.nn.Sequential(
                torch.nn.Conv2d(
                    decoded.shape[1], decoded.shape[1], kernel_size=1),
                torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ),
        )(decoded)
        rim = torch.cat(
            [
                get_module(
                    f"decoder_{_encode_metagene_name(n)}",
                    partial(self._create_metagene_decoder, decoded.shape[1], n),
                )(shared_representation) for n in self.metagenes
            ],
            dim=1,
        )
        rim = torch.nn.functional.softmax(rim, dim=1)
        return rim

    num_genes = x["data"][0].shape[1]

    decoded = self._decode(zs)
    label = center_crop(x["label"], [None, *decoded.shape[-2:]])

    rim = checkpoint(_compute_rim, decoded)
    rim = center_crop(rim, [None, None, *label.shape[-2:]])
    rim = p.sample("rim", Delta(rim))

    scale = p.sample(
        "scale",
        Delta(
            center_crop(
                self._get_scale_decoder(decoded.shape[1])(decoded),
                [None, None, *label.shape[-2:]],
            )),
    )
    rim = scale * rim

    with p.poutine.scale(scale=len(x["data"]) / self.n):
        rate_mg_prior = Normal(
            0.0,
            1e-8 + get_param(
                "rate_mg_prior_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        rate_mg = torch.stack([
            p.sample(_encode_metagene_name(n), rate_mg_prior)
            for n in self.metagenes
        ])
        rate_mg = p.sample("rate_mg", Delta(rate_mg))

        rate_g_effects_baseline = get_param(
            "rate_g_effects_baseline",
            lambda: self.__init_rate_baseline().log(),
            lr_multiplier=5.0,
        )
        logits_g_effects_baseline = get_param(
            "logits_g_effects_baseline",
            # pylint: disable=unnecessary-lambda
            self.__init_logits_baseline,
            lr_multiplier=5.0,
        )
        rate_g_effects_prior = Normal(
            0.0,
            1e-8 + get_param(
                "rate_g_effects_prior_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        rate_g_effects = p.sample("rate_g_effects", rate_g_effects_prior)
        rate_g_effects = torch.cat(
            [rate_g_effects_baseline.unsqueeze(0), rate_g_effects])
        logits_g_effects_prior = Normal(
            0.0,
            1e-8 + get_param(
                "logits_g_effects_prior_sd",
                lambda: torch.ones(num_genes),
                constraint=constraints.positive,
            ),
        )
        logits_g_effects = p.sample(
            "logits_g_effects",
            logits_g_effects_prior,
        )
        logits_g_effects = torch.cat(
            [logits_g_effects_baseline.unsqueeze(0), logits_g_effects])

        effects = []
        for covariate, vals in require("covariates"):
            effect = p.sample(
                f"effect-{covariate}",
                OneHotCategorical(
                    to_device(torch.ones(len(vals))) / len(vals)),
            )
            effects.append(effect)
        effects = torch.cat(
            [
                to_device(torch.ones(x["effects"].shape[0], 1)),
                *effects,
            ],
            1,
        ).float()

        logits_g = effects @ logits_g_effects
        rate_g = effects @ rate_g_effects
        rate_mg = rate_g[:, None] + rate_mg

    with scope(prefix=self.tag):
        image_distr = self._sample_image(x, decoded)

    def _compute_sample_params(data, label, rim, rate_mg, logits_g):
        zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
        partial_idxs = np.unique(
            torch.cat([label[0], label[-1], label[:, 0],
                       label[:, -1]]).cpu().numpy())
        partial_idxs = np.setdiff1d(partial_idxs,
                                    zero_count_idxs.cpu().numpy())
        mask = np.invert(
            np.isin(label.cpu().numpy(), [0, *partial_idxs]))
        mask = torch.as_tensor(mask, device=label.device)

        if not mask.any():
            return (
                data[[]],
                torch.zeros(0, num_genes).to(rim),
                logits_g.expand(0, -1),
            )

        label = label[mask] - 1
        idxs, label = torch.unique(label, return_inverse=True)
        data = data[idxs]

        rim = rim[:, mask]
        labelonehot = sparseonehot(label)
        rim = torch.sparse.mm(labelonehot.t().float(), rim.t())

        rgs = rim @ rate_mg.exp()

        return data, rgs, logits_g.expand(len(rgs), -1)

    data, rgs, logits_g = zip(*it.starmap(
        _compute_sample_params,
        zip(x["data"], label, rim, rate_mg, logits_g),
    ))
    expression_distr = NegativeBinomial(
        total_count=1e-8 + torch.cat(rgs),
        logits=torch.cat(logits_g),
    )
    p.sample("xsg", expression_distr, obs=torch.cat(data))

    return image_distr, expression_distr
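# Added side note (not from the source). With the total_count/logits
# parameterization used above, the expected count is total_count * exp(logits):
# `rgs` sets the scale of the mean while `logits_g` shifts it multiplicatively
# per gene.
import torch
from pyro.distributions import NegativeBinomial

d = NegativeBinomial(total_count=torch.tensor(5.0), logits=torch.tensor(0.5))
print(d.mean)  # tensor(8.2436) == 5 * exp(0.5)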