예제 #1
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample from uniform noise.

     The result is ``(logit(u) + logit(p)) / self.temperature`` where ``u``
     is clamped uniform noise and ``p`` is the clamped, expanded
     ``self.probs``. No sigmoid is applied to the returned value.

     Args:
         sample_shape: Leading shape passed to ``self._extended_shape``.
     """
     full_shape = self._extended_shape(sample_shape)
     p = clamp_probs(self.probs.expand(full_shape))
     u = clamp_probs(torch.rand(full_shape, dtype=p.dtype, device=p.device))
     # logit(x) = log(x) - log(1 - x); log1p(-x) evaluates log(1 - x).
     noise_logit = u.log() - (-u).log1p()
     return (noise_logit + p.log() - (-p).log1p()) / self.temperature
예제 #2
0
    def sample(self, probs=None, logits=None, temperature=.67):
        """Draw a relaxed sample using Gumbel noise and a tempered softmax.

        Args:
            probs: Probabilities; clamped and converted to log space.
            logits: Unnormalized log-probabilities, used as given.
            temperature: Divisor applied to the noisy scores.

        Returns:
            A tensor normalized to sum to one along the last dimension.

        Raises:
            ValueError: If both or neither of ``probs`` / ``logits`` is given.
        """
        has_probs = probs is not None
        has_logits = logits is not None
        if has_probs == has_logits:
            raise ValueError("either probs or logits should be given")
        params = clamp_probs(probs).log() if has_probs else logits

        uniform = clamp_probs(torch.rand(params.size()).type_as(params))
        gumbel_noise = -torch.log(-torch.log(uniform))
        tempered = (params + gumbel_noise) / temperature
        # Subtracting the log-sum-exp normalizes in log space before exp().
        log_sample = tempered - torch.logsumexp(tempered, dim=-1, keepdim=True)
        return torch.exp(log_sample)
예제 #3
0
    def rsample(self, temperature=None, gumbel_noise=None):
        """Sample with Gumbel noise, either as a one-hot or as a relaxed sample.

        Args:
            temperature: When ``None``, a one-hot tensor marking the argmax of
                the masked, perturbed scores is built without gradient
                tracking. Otherwise ``self.logits + gumbel_noise`` is divided
                by ``temperature`` before ``Categorical.masked_softmax`` and
                the soft result is returned.
            gumbel_noise: Optional noise to reuse. It must have the same shape
                as ``self.probs``; fresh noise is drawn (without gradient
                tracking) when it is ``None``.

        Returns:
            A ``(sample, gumbel_noise)`` tuple.

        Raises:
            ValueError: If ``gumbel_noise`` does not match ``self.probs`` in
                shape.
        """
        if gumbel_noise is None:
            with torch.no_grad():
                uniforms = torch.empty_like(self.probs).uniform_()
                uniforms = distr_utils.clamp_probs(uniforms)
                gumbel_noise = -(-uniforms.log()).log()
        elif gumbel_noise.shape != self.probs.shape:
            # Report both shapes so a mismatch can be diagnosed from the message.
            raise ValueError(
                "gumbel_noise has shape {} but probs has shape {}".format(
                    tuple(gumbel_noise.shape), tuple(self.probs.shape)))

        if temperature is None:
            # Hard path: nothing here is differentiable, so skip autograd.
            with torch.no_grad():
                scores = Categorical.masked_softmax(
                    self.logits + gumbel_noise, self.mask)
                sample = torch.zeros_like(scores)
                sample.scatter_(-1, scores.argmax(dim=-1, keepdim=True), 1.0)
            return sample, gumbel_noise

        # Relaxed path: temperature scales the noisy logits before the softmax.
        scores = (self.logits + gumbel_noise) / temperature
        return Categorical.masked_softmax(scores, self.mask), gumbel_noise
예제 #4
0
 def __init__(self,
              probs=None,
              logits=None,
              lims=(0.499, 0.501),
              validate_args=None):
     """Initialize from exactly one of ``probs`` or ``logits``.

     Args:
         probs: Probability parameter; broadcast, optionally validated, then
             clamped with ``clamp_probs``.
         logits: Logit parameter; broadcast and stored unchanged.
         lims: Pair of limits stored on ``self._lims``.
         validate_args: Forwarded to the parent constructor.

     Raises:
         ValueError: If both or neither of ``probs`` / ``logits`` is given,
             or if ``probs`` fails its constraint check.
     """
     probs_given = probs is not None
     if probs_given == (logits is not None):
         raise ValueError(
             "Either `probs` or `logits` must be specified, but not both.")
     if probs_given:
         is_scalar = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
         # Check the raw values now: once clamped below they would always
         # satisfy the constraint.
         # NOTE(review): this runs for any non-None validate_args, including
         # False -- confirm that is intended.
         if validate_args is not None:
             probs_ok = self.arg_constraints['probs'].check(self.probs).all()
             if not probs_ok:
                 raise ValueError(
                     "The parameter {} has invalid values".format('probs'))
         self.probs = clamp_probs(self.probs)
         self._param = self.probs
     else:
         is_scalar = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
         self._param = self.logits
     # A plain Python number yields an empty batch shape.
     batch_shape = torch.Size() if is_scalar else self._param.size()
     self._lims = lims
     super(ContinuousBernoulli, self).__init__(batch_shape,
                                               validate_args=validate_args)
예제 #5
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample in log space.

     Gumbel noise is added to ``self.logits``, the sum is divided by
     ``self.temperature`` and then normalized by subtracting its
     log-sum-exp over the last dimension.
     """
     noise_shape = self._extended_shape(torch.Size(sample_shape))
     u = clamp_probs(self.logits.new(noise_shape).uniform_())
     # -log(-log(u)) turns uniform noise into standard Gumbel noise.
     gumbel_noise = -torch.log(-torch.log(u))
     tempered = (self.logits + gumbel_noise) / self.temperature
     return tempered - torch.logsumexp(tempered, dim=-1, keepdim=True)
예제 #6
0
    def log_density(self, sample, probs=None, logits=None, temperature=.67):
        """Evaluate the log-density of a relaxed sample.

        Args:
            sample: Relaxed sample; clamped with ``clamp_probs`` before use.
            probs: Probabilities; clamped and converted to log space.
            logits: Unnormalized log-probabilities, used as given.
            temperature: Relaxation temperature.

        Returns:
            The log-density, reduced over the last dimension.

        Raises:
            ValueError: If both or neither of ``probs`` / ``logits`` is given.
        """
        temperature = torch.tensor([temperature]).type_as(sample)
        # Normalizer: lgamma(K) + (K - 1) * log(temperature).
        k_tensor = torch.full_like(temperature, float(self.K))
        log_scale = k_tensor.lgamma() + temperature.log().mul(self.K - 1)

        probs_given = probs is not None
        logits_given = logits is not None
        if probs_given == logits_given:
            raise ValueError("either probs or logits should be given")
        params = clamp_probs(probs).log() if probs_given else logits

        log_sample = clamp_probs(sample).log()
        score = params - log_sample.mul(temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale - log_sample.sum(-1)
예제 #7
0
 def log_prob(self, value):
     """Return the log-probability of ``value`` successes.

     Combines the log binomial coefficient with ``value * self.logits`` and
     ``self.total_count * log(1 - p)``, where ``p`` is the clamped
     ``self.probs``.
     """
     self._validate_log_prob_arg(value)
     n = self.total_count
     safe_probs = clamp_probs(self.probs)
     # log C(n, k) = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
     log_binom = (math.lgamma(n + 1) - torch.lgamma(value + 1)
                  - torch.lgamma(n - value + 1))
     return log_binom + value * self.logits + n * torch.log1p(-safe_probs)
예제 #8
0
    def sample_continous(self, logits):
        """Draw ``self.k`` Gumbel-perturbed softmaxes and take their maximum.

        Noise of shape ``(logits.shape[0], self.k, logits.shape[2])`` is added
        to ``logits``, the sum is divided by ``self.T`` and passed through a
        softmax over the last dimension; the result is reduced with an
        element-wise maximum over dimension 1.
        """
        noise_shape = (logits.shape[0], self.k, logits.shape[2])
        uniform = clamp_probs(torch.rand(noise_shape, device=logits.device))
        gumbel_noise = -(-uniform.log()).log()
        relaxed = F.softmax((gumbel_noise + logits) / self.T, dim=-1)
        return relaxed.max(dim=1).values
예제 #9
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample in log space, normalized over dim 1.

     Gumbel noise matching the dtype and device of ``self.logits`` is added
     to the logits; the sum is divided by ``self.temperature`` and its
     log-sum-exp along dimension 1 is subtracted.
     """
     noise_shape = self._extended_shape(sample_shape)
     u = clamp_probs(
         torch.rand(noise_shape, dtype=self.logits.dtype, device=self.logits.device))
     gumbel_noise = -torch.log(-torch.log(u))
     tempered = (self.logits + gumbel_noise) / self.temperature
     return tempered - torch.logsumexp(tempered, dim=1, keepdim=True)
예제 #10
0
 def rsample(self):
     """Draw a one-hot sample with Gumbel noise, without gradient tracking.

     Returns:
         A ``(sample, gumbel_noise)`` tuple. ``sample`` is a one-hot tensor
         marking the argmax (last dimension) of the masked, perturbed scores.
     """
     with torch.no_grad():
         u = distr_utils.clamp_probs(
             torch.empty_like(self.cat_distr.probs).uniform_())
         gumbel_noise = -torch.log(-torch.log(u))
         masked = Categorical.masked_softmax(
             self.cat_distr.logits + gumbel_noise, self.mask)
         winner = masked.argmax(dim=-1, keepdim=True)
         # Write 1.0 at the winning index; every other entry stays 0.
         one_hot = torch.zeros_like(masked).scatter_(-1, winner, 1.0)
         return one_hot, gumbel_noise
예제 #11
0
 def continuous_topk(self, w, separate=False):
     """Relaxed top-k selection through ``self.k`` successive softmaxes.

     Each round adds the log of the clamped complement of the previous
     softmax to the running scores, then takes a softmax of the scores
     divided by ``self.T`` over the last dimension.

     Args:
         w: Score tensor. It is not modified; a running copy is updated.
         separate: If true, return the list of per-round softmax outputs.

     Returns:
         The list of ``self.k`` softmax outputs when ``separate`` is true,
         otherwise their sum with dimension 1 squeezed.
     """
     khot_list = []
     onehot_approx = torch.zeros_like(w, dtype=torch.float32)
     for _ in range(self.k):
         khot_mask = clamp_probs(1.0 - onehot_approx)
         # Out-of-place add: an in-place ``+=`` would overwrite the caller's
         # tensor and raises on leaf tensors that require grad.
         w = w + torch.log(khot_mask)
         onehot_approx = F.softmax(w / self.T, dim=-1)
         khot_list.append(onehot_approx)
     if separate:
         return khot_list
     return torch.stack(khot_list, dim=-1).sum(-1).squeeze(1)
예제 #12
0
def conditional_gumbel_rsample(
    hard_sample: torch.Tensor, distr: Distribution, temperature,
) -> torch.Tensor:
    r"""
    Conditionally re-samples from the distribution given the hard sample.
    This samples z \sim p(z|b), where b is the hard sample and p(z) is a gumbel distribution.

    Args:
        hard_sample: The discrete sample ``b``. The Bernoulli branch compares
            it against 0 and 1; the other branch treats its last dimension as
            a one-hot encoding.
        distr: Distribution whose ``probs`` parameterize the relaxation.
        temperature: Divisor applied to the conditional scores.

    Returns:
        The relaxed sample: a sigmoid of the conditional logits in the
        Bernoulli case, otherwise a softmax over the last dimension.
    """
    # Adapted from torch.distributions.relaxed_bernoulli and torch.distributions.relaxed_categorical
    shape = hard_sample.shape
    probs = (
        distr.probs
        if not isinstance(hard_sample, storch.Tensor)
        else distr.probs._tensor
    )
    probs = clamp_probs(probs.expand_as(hard_sample))
    v = clamp_probs(torch.rand(shape, dtype=probs.dtype, device=probs.device))
    if isinstance(distr, Bernoulli):
        is_one = hard_sample == 1
        is_zero = hard_sample == 0
        pos_probs = probs[is_one]
        # Use the dtype of probs so an integer-typed hard sample cannot
        # truncate the conditional uniforms.
        v_prime = torch.zeros_like(hard_sample, dtype=probs.dtype)
        # See https://arxiv.org/abs/1711.00123
        v_prime[is_one] = v[is_one] * pos_probs + (1 - pos_probs)
        v_prime[is_zero] = v[is_zero] * (1 - probs[is_zero])
        # Keep v' strictly inside (0, 1) so both logs below stay finite.
        v_prime = clamp_probs(v_prime)
        # z = logit(p) + logit(v'), with logit(x) = log(x) - log1p(-x). This
        # matches the relaxed-Bernoulli sampler this code was adapted from.
        log_sample = (
            probs.log() - (-probs).log1p() + v_prime.log() - (-v_prime).log1p()
        ) / temperature
        return log_sample.sigmoid()
    # b=argmax(hard_sample)
    b = hard_sample.max(-1).indices

    # See https://arxiv.org/abs/1711.00123
    log_v = v.log()
    # i != b (indexing could maybe be improved here, but i doubt it'd be more efficient)
    log_v_b = torch.gather(log_v, -1, b.unsqueeze(-1))
    cond_gumbels = -(-(log_v / probs) - log_v_b).log()
    # i = b
    index_sample = hard_sample.bool()
    cond_gumbels[index_sample] = -(-log_v[index_sample]).log()
    scores = cond_gumbels / temperature
    return (scores - scores.logsumexp(dim=-1, keepdim=True)).exp()
예제 #13
0
 def rsample(self, sample_shape=torch.Size()):
     """Quantize a clamped relaxed sample from the parent class.

     The parent's ``rsample`` output is clamped with ``clamp_probs`` and
     passed to ``QuantizeCategorical.apply``.
     """
     relaxed = clamp_probs(
         super(RelaxedOneHotCategoricalStraightThrough, self).rsample(sample_shape))
     return QuantizeCategorical.apply(relaxed)
예제 #14
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample normalized with ``_log_sum_exp``.

     Gumbel noise is added to ``self.logits`` and the sum is divided by
     ``self.temperature``; ``_log_sum_exp`` of the result is subtracted.
     """
     noise_shape = self._extended_shape(torch.Size(sample_shape))
     u = clamp_probs(self.logits.new(noise_shape).uniform_())
     gumbel_noise = -torch.log(-torch.log(u))
     tempered = (self.logits + gumbel_noise) / self.temperature
     # NOTE(review): presumably _log_sum_exp reduces over the class
     # dimension and keeps it -- confirm against the helper's definition.
     return tempered - _log_sum_exp(tempered)
예제 #15
0
 def probs(self):
     """Return binary probabilities for ``self.logits``, clamped by ``clamp_probs``."""
     raw_probs = logits_to_probs(self.logits, is_binary=True)
     return clamp_probs(raw_probs)
예제 #16
0
파일: util.py 프로젝트: xlchan/BNN
 def rsample(self, sample_shape=torch.Size()):
     """Return the parent class's ``rsample`` output passed through ``clamp_probs``."""
     relaxed = super(StableRelaxedBernoulli, self).rsample(sample_shape)
     return clamp_probs(relaxed)
예제 #17
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample from uniform noise.

     The result is ``(logit(u) + logit(p)) / self.temperature`` where ``u``
     is clamped uniform noise allocated from ``self.probs`` and ``p`` is the
     clamped, expanded ``self.probs``.
     """
     full_shape = self._extended_shape(sample_shape)
     p = clamp_probs(self.probs.expand(full_shape))
     u = clamp_probs(self.probs.new(full_shape).uniform_())
     # logit(x) = log(x) - log(1 - x); log1p(-x) evaluates log(1 - x).
     noise_logit = u.log() - (-u).log1p()
     return (noise_logit + p.log() - (-p).log1p()) / self.temperature
예제 #18
0
def cross_entropy_multiple_class(input: torch.FloatTensor,
                                 target: torch.FloatTensor) -> torch.Tensor:
    """Cross-entropy of ``input`` against ``target``.

    ``input`` is clamped with ``clamp_probs`` before the log. The loss is
    summed over dimension 1 and then averaged over the remaining elements.
    """
    log_probs = torch.log(clamp_probs(input))
    per_sample_loss = torch.sum(-target * log_probs, dim=1)
    return torch.mean(per_sample_loss)
예제 #19
0
def _dirichlet_sample_nograd(concentration):
    """Sample by normalizing standard-gamma draws over the last dimension.

    The normalized draws are passed through ``clamp_probs`` before being
    returned.
    """
    gamma_draws = torch._standard_gamma(concentration)
    # In-place division by the per-row total (keepdim=True for broadcasting).
    gamma_draws.div_(gamma_draws.sum(-1, keepdim=True))
    return clamp_probs(gamma_draws)
예제 #20
0
 def inject_noise(self, logits):
     """Return ``logits`` perturbed by Gumbel noise of the same shape."""
     uniform = clamp_probs(torch.rand_like(logits))
     # -log(-log(u)) turns uniform noise into standard Gumbel noise.
     gumbel_noise = -(-uniform.log()).log()
     return logits + gumbel_noise
예제 #21
0
파일: BNN_CDropout.py 프로젝트: xlchan/BNN
 def dropout_rate(self):
     """Return the sigmoid of ``self.p_logit``, clamped by ``clamp_probs``."""
     rate = torch.sigmoid(self.p_logit)
     return clamp_probs(rate)
예제 #22
0
def _dirichlet_sample_nograd(concentration):
    """Draw standard-gamma variates and normalize them along the last axis.

    Returns the normalized values after ``clamp_probs`` has been applied.
    """
    draws = torch._standard_gamma(concentration)
    row_totals = draws.sum(-1, keepdim=True)
    draws.div_(row_totals)  # in-place, like the original ``/=``
    return clamp_probs(draws)
예제 #23
0
 def rsample(self, sample_shape=torch.Size()):
     """Quantize a clamped relaxed sample from the parent class.

     The parent's ``rsample`` output is clamped with ``clamp_probs`` and
     passed to ``QuantizeCategorical2D.apply``.
     """
     relaxed = clamp_probs(super().rsample(sample_shape))
     return QuantizeCategorical2D.apply(relaxed)
예제 #24
0
 def rsample(self, sample_shape=torch.Size()):
     """Quantize a clamped relaxed sample from the parent class.

     The parent's ``rsample`` output is clamped with ``clamp_probs`` and
     passed to ``QuantizeBernoulli.apply``.
     """
     relaxed = clamp_probs(
         super(RelaxedBernoulliStraightThrough, self).rsample(sample_shape))
     return QuantizeBernoulli.apply(relaxed)
예제 #25
0
 def variational_posterior(self, logits: torch.Tensor):
     """Build a ``Bernoulli`` whose probabilities are the clamped sigmoid of ``logits``."""
     clamped = clamp_probs(torch.sigmoid(logits))
     return Bernoulli(probs=clamped)
 def rsample(self, sample_shape=torch.Size()):
     """Quantize a clamped relaxed sample from the parent class.

     The parent's ``rsample`` output is clamped with ``clamp_probs`` and
     passed to ``QuantizeBernoulli.apply``.
     """
     relaxed = clamp_probs(super().rsample(sample_shape))
     return QuantizeBernoulli.apply(relaxed)