예제 #1
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw a reparameterized sample using Gumbel-perturbed logits.

     Uniform noise is allocated with the extended shape derived from
     ``sample_shape``, passed through ``clamp_probs``, and converted to
     Gumbel noise. The noise is added to the logits, the sum is divided
     by the temperature, and ``log_sum_exp`` of the resulting scores is
     subtracted from them before returning.
     """
     noise_shape = self._extended_shape(torch.Size(sample_shape))
     noise = clamp_probs(self.logits.new(noise_shape).uniform_())
     # Gumbel noise from uniform noise: -log(-log(u)).
     inner = -(noise.log())
     gumbel_noise = -(inner.log())
     tempered = (self.logits + gumbel_noise) / self.temperature
     return tempered - log_sum_exp(tempered)
예제 #2
0
    def apply(self, samples, sampler):
        """Compute an energy value for every sample in the batch.

        The samples are converted with ``to_pm1``; every call to
        ``sampler.rbm_module.effective_energy`` receives the ``to_01``
        version of the (possibly spin-flipped) samples.

        The result combines a transverse-field term, scaled by ``self.h``,
        with a nearest-neighbour interaction term that also couples the
        first and last sites. When ``self.density`` is truthy the energy
        is divided by the number of sites.
        """
        samples = to_pm1(samples)
        rbm = sampler.rbm_module
        num_sites = samples.shape[-1]

        log_psis = rbm.effective_energy(to_01(samples)) / 2.

        # One column per site, holding the value obtained with that single
        # spin flipped.
        log_flipped_psis = torch.zeros(*log_psis.shape, num_sites,
                                       dtype=torch.double,
                                       device=rbm.device)

        for site in range(num_sites):
            # ``_flip_spin`` is applied twice: once before evaluating and
            # once afterwards to restore the original configuration.
            self._flip_spin(site, samples)
            log_flipped_psis[:, site] = rbm.effective_energy(
                to_01(samples)) / 2.
            self._flip_spin(site, samples)

        # Collapse the per-site columns into one value per sample.
        log_flipped_psis = log_sum_exp(log_flipped_psis,
                                       keepdim=True).squeeze()

        # Products of adjacent sites, plus the first/last wrap-around pair.
        neighbour_products = (samples[:, :-1] * samples[:, 1:]).sum(1)
        wraparound_product = samples[:, 0] * samples[:, num_sites - 1]
        interaction_terms = neighbour_products + wraparound_product

        # Exponentiating the log-difference yields a ratio.
        transverse_field_terms = (log_flipped_psis - log_psis).exp()

        energy = (transverse_field_terms * self.h + interaction_terms) * -1.

        return energy / num_sites if self.density else energy
예제 #3
0
 def log_prob(self, value):
     """Return the log-probability score of ``value``.

     ``value`` is checked with ``_validate_log_prob_arg`` and broadcast
     against the logits. The returned tensor is the score summed over
     the last dimension plus a log-scale term that depends only on the
     temperature and the number of events ``K``.
     """
     num_events = self._categorical._num_events
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits, value)
     temperature = self.temperature

     # lgamma of a temperature-shaped tensor filled with K.
     log_gamma_k = temperature.new(temperature.shape).fill_(num_events).lgamma()
     scaled_log_temperature = temperature.log().mul(-(num_events - 1))
     log_scale = log_gamma_k - scaled_log_temperature

     raw_score = logits - value.mul(temperature)
     normalized_score = raw_score - log_sum_exp(raw_score)
     return normalized_score.sum(-1) + log_scale
예제 #4
0
 def log_prob(self, value):
     """Return the log-probability score of ``value``.

     When ``self._validate_args`` is truthy, ``value`` is first checked
     with ``_validate_sample``. The value is then broadcast against the
     logits, and the result is the score summed over the last dimension
     plus a log-scale term built from the temperature and the number of
     events ``K``.
     """
     event_count = self._categorical._num_events
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits, value)

     # Temperature-shaped tensor filled with K, passed through lgamma.
     filled_k = self.temperature.new(self.temperature.shape).fill_(event_count)
     log_temperature_term = self.temperature.log().mul(-(event_count - 1))
     log_scale = filled_k.lgamma() - log_temperature_term

     score = logits - value.mul(self.temperature)
     score = score - log_sum_exp(score)
     return score.sum(-1) + log_scale
예제 #5
0
 def __init__(self, probs=None, logits=None):
     """Initialize from exactly one of ``probs`` or ``logits``.

     ``probs`` is divided by its sum over the last dimension; ``logits``
     has ``log_sum_exp(logits)`` subtracted from it. The last dimension
     of the chosen parameter gives the number of events, and the leading
     dimensions are passed to the parent constructor as the batch shape.

     Raises:
         ValueError: if both or neither of ``probs`` and ``logits`` are
             given.
     """
     have_probs = probs is not None
     have_logits = logits is not None
     if have_probs == have_logits:
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")

     if have_probs:
         self.probs = probs / probs.sum(-1, keepdim=True)
         self._param = self.probs
     else:
         self.logits = logits - log_sum_exp(logits)
         self._param = self.logits

     param_size = self._param.size()
     self._num_events = param_size[-1]
     super(Categorical, self).__init__(param_size[:-1])
예제 #6
0
 def __init__(self, probs=None, logits=None):
     """Build the distribution from either ``probs`` or ``logits``.

     Exactly one of the two arguments must be supplied. ``probs`` is
     divided by its sum along the last dimension, while ``logits`` is
     shifted by ``log_sum_exp(logits)``. The size of the last dimension
     of the stored parameter is recorded as the number of events; the
     remaining leading dimensions form the batch shape handed to the
     parent constructor.

     Raises:
         ValueError: when ``probs`` and ``logits`` are both given or
             both omitted.
     """
     use_probs = probs is not None
     if use_probs == (logits is not None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")

     if not use_probs:
         self.logits = logits - log_sum_exp(logits)
         self._param = self.logits
     else:
         self.probs = probs / probs.sum(-1, keepdim=True)
         self._param = self.probs

     full_size = self._param.size()
     self._num_events = full_size[-1]
     batch_shape = full_size[:-1]
     super(Categorical, self).__init__(batch_shape)
예제 #7
0
    def apply(self, samples, sampler):
        r"""Computes the energy of each sample given a batch of
        samples.

        :param samples: A batch of samples to calculate the observable on.
                        Must be using the :math:`\sigma_i = 0, 1` convention.
        :type samples: torch.Tensor
        :param sampler: The sampler that drew the samples. Must implement
                        the function :func:`effective_energy`, giving the
                        log probability of its inputs (up to an additive
                        constant).
        :type sampler: qucumber.samplers.Sampler

        :returns: The energy of each sample, divided by the number of
                  sites when ``self.density`` is truthy.
        """
        samples = to_pm1(samples)
        num_sites = samples.shape[-1]
        log_psis = sampler.effective_energy(to_01(samples)).div(2.)

        # One column per site, holding the value obtained with that single
        # spin flipped.
        shape = log_psis.shape + (num_sites,)
        log_flipped_psis = torch.zeros(*shape,
                                       dtype=torch.double,
                                       device=sampler.device)

        for i in range(num_sites):  # sum over spin sites
            self._flip_spin(i, samples)  # flip the spin at site i
            log_flipped_psis[:, i] = sampler.effective_energy(
                to_01(samples)
            ).div(2.)
            self._flip_spin(i, samples)  # flip it back

        log_flipped_psis = log_sum_exp(
            log_flipped_psis, keepdim=True).squeeze()

        if self.periodic_bcs:
            # Pair every site with its right-hand neighbour, wrapping the
            # last site around to the first. The permutation indexes the
            # site axis of `samples`, so its length must come from
            # `samples`, not from the sampler object.
            perm_indices = list(range(1, num_sites)) + [0]
            interaction_terms = ((samples * samples[:, perm_indices])
                                 .sum(1))
        else:
            interaction_terms = ((samples[:, :-1] * samples[:, 1:])
                                 .sum(1))      # sum over spin sites

        transverse_field_terms = (log_flipped_psis
                                  .sub(log_psis)
                                  .exp())  # convert to ratio of probabilities

        energy = (transverse_field_terms
                  .mul(self.h)
                  .add(interaction_terms)
                  .mul(-1.))

        if self.density:
            return energy.div(num_sites)
        else:
            return energy
예제 #8
0
 def rsample(self, sample_shape=torch.Size()):
     """Return a reparameterized sample built from Gumbel-perturbed logits.

     Steps: draw uniform noise with the extended shape for
     ``sample_shape``, clamp it with ``clamp_probs``, turn it into Gumbel
     noise, add it to the logits, divide by the temperature, and finally
     subtract ``log_sum_exp`` of the scores from the scores.
     """
     sample_shape = torch.Size(sample_shape)
     extended_shape = self._extended_shape(sample_shape)
     raw_uniforms = self.logits.new(extended_shape).uniform_()
     safe_uniforms = clamp_probs(raw_uniforms)
     # Double negative-log transform: uniform -> Gumbel noise.
     neg_log_uniforms = -(safe_uniforms.log())
     gumbels = -(neg_log_uniforms.log())
     perturbed_logits = self.logits + gumbels
     scores = perturbed_logits / self.temperature
     return scores - log_sum_exp(scores)