import torch
from torch.autograd import grad

from pyro.distributions import RelaxedOneHotCategoricalStraightThrough


def test_onehot_shapes(probs):
    # `probs` is supplied by the test harness (e.g. a pytest parametrization).
    temperature = torch.tensor(0.5)
    probs = torch.tensor(probs, requires_grad=True)
    d = RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs)
    sample = d.rsample()
    log_prob = d.log_prob(sample)
    # The straight-through sample is differentiable, so the gradient of the
    # log-probability must have the same shape as the parameter vector.
    grad_probs = grad(log_prob.sum(), [probs])[0]
    assert grad_probs.shape == probs.shape
from torch.distributions import Normal

from pyro.distributions import RelaxedOneHotCategoricalStraightThrough


def rsample(self, n_samples, ret_z, temperature=0.1):
    # Relaxed one-hot component indicators with straight-through gradients:
    # the forward value is hard one-hot, shape (n_samples, K).
    oh = RelaxedOneHotCategoricalStraightThrough(
        logits=self.logits, temperature=temperature
    ).rsample((n_samples,))
    # Select each sample's component parameters via the one-hot indicator.
    mus = (oh.unsqueeze(2) * self.mus).sum(dim=1)
    # Note that the selected scale is squared before being used below.
    sigmas = (oh.unsqueeze(2) * self.sigmas).sum(dim=1) ** 2
    z = Normal(mus, sigmas).rsample((1,)).squeeze(0)
    if ret_z:
        return z, oh.argmax(dim=1)
    return z
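A minimal sketch of how this sampler might be hosted and called, assuming `self.logits` has shape (K,) and `self.mus`/`self.sigmas` have shape (K, D); the `MixtureSampler` class and its sizes are hypothetical, not from the source.

import torch
import torch.nn as nn


class MixtureSampler(nn.Module):  # hypothetical host for the method above
    def __init__(self, K=4, D=2):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(K))    # mixture logits
        self.mus = nn.Parameter(torch.randn(K, D))    # component means
        self.sigmas = nn.Parameter(torch.ones(K, D))  # component scales

    rsample = rsample  # bind the free function above as a method


sampler = MixtureSampler()
x, z = sampler.rsample(n_samples=8, ret_z=True)  # x: (8, D), z: (8,)
x.sum().backward()  # gradients reach logits, mus, and sigmas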
import torch
from torch.distributions import constraints

import pyro
from pyro.distributions import RelaxedOneHotCategoricalStraightThrough


def guide():
    # Variational parameter constrained to the probability simplex.
    q = pyro.param(
        "q", torch.tensor([0.1, 0.2, 0.3, 0.4]), constraint=constraints.simplex
    )
    temp = torch.tensor(0.10)
    pyro.sample(
        "z", RelaxedOneHotCategoricalStraightThrough(temperature=temp, probs=q)
    )
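A sketch of a model and inference loop this guide could be paired with; the model's probabilities and the optimizer settings are assumptions for illustration. Because the straight-through sample is hard one-hot, the model can score it with an ordinary OneHotCategorical.

from pyro.distributions import OneHotCategorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def model():
    pyro.sample("z", OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25])))


svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for _ in range(1000):
    svi.step()  # q is trained through the reparameterized relaxed sample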
import torch
from torch.autograd import grad
from torch.distributions import RelaxedOneHotCategorical

from pyro.distributions import RelaxedOneHotCategoricalStraightThrough
from tests.common import assert_equal


def test_onehot_entropy_grad(temp):
    # `temp` is supplied by the test harness (e.g. a pytest parametrization).
    num_samples = 2000000
    q = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)
    temp = torch.tensor(temp)

    # Monte Carlo estimate of the gradient of E[log q(z)] under the relaxed
    # distribution, used as the reference value.
    dist_q = RelaxedOneHotCategorical(temperature=temp, probs=q)
    z = dist_q.rsample(sample_shape=(num_samples,))
    expected = grad(dist_q.log_prob(z).sum(), [q])[0] / num_samples

    # The straight-through variant should match the reference in expectation.
    dist_q = RelaxedOneHotCategoricalStraightThrough(temperature=temp, probs=q)
    z = dist_q.rsample(sample_shape=(num_samples,))
    actual = grad(dist_q.log_prob(z).sum(), [q])[0] / num_samples

    assert_equal(
        expected,
        actual,
        prec=0.08,
        msg="bad grad for RelaxedOneHotCategoricalStraightThrough "
        "(expected {}, got {})".format(expected, actual),
    )
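The mechanism under test is the straight-through estimator: the hard one-hot sample is used in the forward pass while gradients follow the relaxed Gumbel-softmax sample. A standalone sketch of that trick (the Pyro class additionally evaluates log_prob at the relaxed, unquantized value; the helper name here is illustrative):

import torch


def straight_through_onehot(logits, temperature=0.5):
    # Relaxed sample via the Gumbel-softmax trick.
    gumbels = -torch.empty_like(logits).exponential_().log()  # Gumbel(0, 1)
    soft = torch.softmax((logits + gumbels) / temperature, dim=-1)
    # Hard one-hot in the forward pass, gradient of `soft` in the backward pass.
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    return hard + soft - soft.detach()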
def guide(self, x):
    # `p` is the host project's alias for pyro; `require`, `get_param`,
    # `to_device`, and `Delta` are project helpers, so this snippet is
    # not standalone.
    with p.poutine.scale(scale=len(x["data"]) / self.n):
        self._sample_globals()
    for covariate, _ in require("covariates"):
        is_observed = x["effects"][covariate].values.any(1)
        effect_distr = RelaxedOneHotCategoricalStraightThrough(
            temperature=to_device(torch.as_tensor(0.1)),
            logits=torch.stack(
                [
                    get_param(
                        f"effect-{covariate}-{sample}-logits",
                        torch.zeros(len(vals)),
                    )
                    for sample, vals in x["effects"][covariate].iterrows()
                ]
            ),
        )
        # Only unobserved effects contribute to the ELBO; observed rows are
        # masked out and then overwritten with their known values.
        with p.poutine.mask(mask=~to_device(torch.as_tensor(is_observed))):
            effect = p.sample(f"effect-{covariate}-all", effect_distr)
        effect[is_observed] = torch.as_tensor(
            x["effects"][covariate].values[is_observed]
        ).to(effect)
        p.sample(f"effect-{covariate}", Delta(effect))
    return super().guide(x)
import torch
from torch.distributions import (
    Bernoulli,
    Categorical,
    Distribution,
    OneHotCategorical,
    RelaxedBernoulli,
    RelaxedOneHotCategorical,
)

from pyro.distributions import (
    RelaxedBernoulliStraightThrough,
    RelaxedOneHotCategoricalStraightThrough,
)


def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    # Map a discrete distribution to its Gumbel-softmax relaxation and draw
    # n reparameterized samples from it.
    if isinstance(distr, (Categorical, OneHotCategorical)):
        if straight_through:
            gumbel_distr = RelaxedOneHotCategoricalStraightThrough(
                temperature, probs=distr.probs
            )
        else:
            gumbel_distr = RelaxedOneHotCategorical(temperature, probs=distr.probs)
    elif isinstance(distr, Bernoulli):
        if straight_through:
            gumbel_distr = RelaxedBernoulliStraightThrough(
                temperature, probs=distr.probs
            )
        else:
            gumbel_distr = RelaxedBernoulli(temperature, probs=distr.probs)
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")
    return gumbel_distr.rsample((n,))
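An illustrative call, reusing the imports above; the logits are made up. With straight_through=True each sample is hard one-hot while gradients still reach the base distribution's parameters.

logits = torch.randn(5, requires_grad=True)
samples = rsample_gumbel_softmax(
    Categorical(logits=logits),
    n=16,
    temperature=torch.tensor(0.5),
    straight_through=True,
)
assert samples.shape == (16, 5)   # each row is a hard one-hot vector
samples.sum(dim=0)[0].backward()  # gradients flow back to `logits`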
def _sample_condition(batch_idx, slide, covariate, condition):
    # Nested helper: `self`, `dataset`, `get`, `get_param`, and `to_device`
    # come from the enclosing scope of the host project.
    try:
        conditions = get("covariates")[covariate]
    except KeyError:
        return

    if pd.isna(condition):
        # Unknown condition: infer it with a straight-through relaxed
        # one-hot over the possible conditions.
        condition_distr = RelaxedOneHotCategoricalStraightThrough(
            temperature=to_device(torch.as_tensor(0.1)),
            logits=get_param(
                f"logits-{slide}-{covariate}",
                lambda: torch.zeros(len(conditions)),
                lr_multiplier=2.0,
            ),
        )
        with pyro.poutine.scale(scale=1.0 / dataset.size(slide=slide)):
            condition_onehot = pyro.sample(
                f"condition-{covariate}-{batch_idx}",
                condition_distr,
                infer={"is_global": True},
            )
        condition_scale = 1e-99
        # ^ HACK: Pyro requires scale > 0
    else:
        condition_onehot = to_device(torch.eye(len(conditions)))[
            np.isin(conditions, condition)
        ][0]
        condition_scale = 1.0 / dataset.size(
            covariate=covariate, condition=condition
        )

    # Select the condition-specific rows of the variational parameters with
    # the (hard) one-hot vector.
    mu_rate_g_condition = condition_onehot @ get_param(
        f"mu_rate_g_condition-{covariate}",
        lambda: torch.zeros(len(conditions), len(self._allocated_genes)),
        lr_multiplier=2.0,
    )
    sd_rate_g_condition = condition_onehot @ get_param(
        f"sd_rate_g_condition-{covariate}",
        lambda: 1e-2 * torch.ones(len(conditions), len(self._allocated_genes)),
        constraint=constraints.positive,
        lr_multiplier=2.0,
    )
    with pyro.poutine.scale(scale=condition_scale):
        pyro.sample(
            f"rate_g_condition-{covariate}-{batch_idx}",
            Normal(mu_rate_g_condition, 1e-8 + sd_rate_g_condition),
            infer={"is_global": True},
        )

    mu_logits_g_condition = condition_onehot @ get_param(
        f"mu_logits_g_condition-{covariate}",
        lambda: torch.zeros(len(conditions), len(self._allocated_genes)),
        lr_multiplier=2.0,
    )
    sd_logits_g_condition = condition_onehot @ get_param(
        f"sd_logits_g_condition-{covariate}",
        lambda: 1e-2 * torch.ones(len(conditions), len(self._allocated_genes)),
        constraint=constraints.positive,
        lr_multiplier=2.0,
    )
    with pyro.poutine.scale(scale=condition_scale):
        pyro.sample(
            f"logits_g_condition-{covariate}-{batch_idx}",
            Normal(mu_logits_g_condition, 1e-8 + sd_logits_g_condition),
            infer={"is_global": True},
        )
import torch.nn.functional as F

from pyro.distributions import RelaxedOneHotCategoricalStraightThrough


def forward(self, s_t):
    hidden = self.s_to_hidden(s_t)
    probs = F.softmax(self.hidden_to_scores(hidden), dim=-1)
    # Hard one-hot action in the forward pass; gradients flow to `probs`
    # through the relaxed sample at temperature 0.5.
    a_t = RelaxedOneHotCategoricalStraightThrough(0.5, probs=probs).rsample()
    return a_t
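A sketch of a host module for this forward pass; the Policy class, layer sizes, and the stand-in objective are assumptions for illustration.

import torch
import torch.nn as nn


class Policy(nn.Module):  # hypothetical host module
    def __init__(self, state_dim=8, hidden_dim=32, n_actions=4):
        super().__init__()
        self.s_to_hidden = nn.Linear(state_dim, hidden_dim)
        self.hidden_to_scores = nn.Linear(hidden_dim, n_actions)

    forward = forward  # bind the free function above as the module's forward


policy = Policy()
a_t = policy(torch.randn(3, 8))          # (3, 4), rows are hard one-hot
loss = -(a_t * torch.randn(3, 4)).sum()  # stand-in objective
loss.backward()                          # gradients reach both linear layers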