コード例 #1
0
def gumbel_max_sampler(logits, temperature, rng):
    """Sample from a categorical distribution using the Gumbel-Max trick.

    Gumbel-Max trick:
    https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    https://arxiv.org/abs/1411.0030

    Args:
      logits: Unnormalized logits for the categorical distribution, shape
        [batch x n_mutations_to_sample x n_mutation_types].
      temperature: Temperature parameter for Gumbel-Max. The lower the
        temperature, the closer the relaxed (softmax) sample is to a
        one-hot encoding.
      rng: JAX random number generator key.

    Returns:
      class_assignments: Sampled class indices, shape
        [batch x n_mutations_to_sample].

    Raises:
      ValueError: If ``logits`` is not rank 3.
    """
    # Normalize the logits to log-probabilities along the class axis.
    logits = logsoftmax(logits)
    # Explicit check (not ``assert``) so it is not stripped under ``python -O``.
    if logits.ndim != 3:
        raise ValueError(
            "logits must have shape [batch x n_mutations_to_sample x "
            f"n_mutation_types], got shape {logits.shape}")

    gumbel_noise = jrand.gumbel(rng, logits.shape)
    # NOTE: for temperature > 0, neither the division nor the softmax changes
    # the argmax below; they only shape the relaxed sample.
    softmax_logits = (logits + gumbel_noise) / temperature
    soft_assignments = softmax(softmax_logits, -1)
    class_assignments = jnp.argmax(soft_assignments, -1)
    # Output shape: [batch x n_mutations_to_sample]
    return class_assignments
コード例 #2
0
 def log_sample_jax(self, x, u, n, alpha=1, xg=None, ug=None):
     """Draw ``n`` Gumbel-Max samples over the rows of ``[x | u]``.

     Each row of the column-wise concatenation of ``x`` and ``u`` is scored
     with ``self.c.cost_jax`` (mapped over rows with ``vmap``). The scores,
     scaled by ``alpha``, act as unnormalized log-probabilities for a
     categorical draw.

     Args:
       x, u: Arrays concatenated along axis 1 (presumably states and
         controls -- TODO confirm against callers).
       n: Number of samples to draw.
       alpha: Multiplier applied to the costs before sampling.
       xg, ug: Unused.

     Returns:
       (choices, scaled_costs): ``choices`` holds ``n`` row indices;
       ``scaled_costs`` is the cost column vector multiplied by ``alpha``.
     """
     stacked_inputs = np.concatenate((x, u), 1)
     cost_column = vmap(self.c.cost_jax)(stacked_inputs).reshape(-1, 1)
     # NOTE(review): self.key is used as-is (it is not split here), so calls
     # that see the same key draw identical noise -- confirm this is intended.
     noise = gumbel(self.key, (cost_column.shape[0], n))
     picked = np.argmax(noise + alpha * cost_column, 0)
     return picked, alpha * cost_column
コード例 #3
0
ファイル: random_test.py プロジェクト: haokeqiudu/jax
  def testGumbel(self, dtype):
    key = random.PRNGKey(0)
    rand = lambda key: random.gumbel(key, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf)
コード例 #4
0
def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None):  # pylint: disable=unused-argument
  """Draws `num_samples` categorical indices per distribution via Gumbel-max.

  Args:
    logits: Unnormalized log-probabilities; the last axis indexes the
      categories. Values lacking `shape` or `dtype` are converted to a
      float32 array first.
    num_samples: Number of independent draws per distribution.
    dtype: Output dtype, passed through `utils.numpy_dtype`; int64 is used
      when it is falsy.
    seed: JAX PRNG key. Required.
    name: Ignored.

  Returns:
    Indices of shape `logits.shape[:-1] + (num_samples,)`, cast to the
    output dtype.

  Raises:
    ValueError: If `seed` is None.
  """
  out_dtype = utils.numpy_dtype(dtype or np.int64)
  is_array_like = hasattr(logits, 'shape') and hasattr(logits, 'dtype')
  if not is_array_like:
    logits = np.array(logits, np.float32)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  noise_shape = logits.shape + (num_samples,)
  gumbel_noise = jaxrand.gumbel(key=seed, shape=noise_shape, dtype=logits.dtype)
  # Add a trailing sample axis to logits; the category axis becomes -2.
  perturbed = np.expand_dims(logits, -1) + gumbel_noise
  return np.argmax(perturbed, axis=-2).astype(out_dtype)
コード例 #5
0
 def update_policy(self, particles, weights):
     """Resample each particle set by its weights and refit the policy.

     For every ``(p, w)`` pair, ``len(p)`` rows of ``p`` are drawn with the
     Gumbel-Max trick, treating ``w`` as unnormalized log-weights (row ``i``
     is picked with probability proportional to ``exp(w[i])``). The draws
     from all sets are concatenated along axis 0 and passed to
     ``self.policy.update_parameters``, together with a zero vector that has
     one entry per resampled row.

     Side effects:
       Splits and overwrites ``self.key`` once per particle set, then calls
       ``self.policy.update_parameters``.
     """
     def _resample(p, w):
         # Advance the stored key so each particle set gets fresh noise.
         self.key, subkey = random.split(self.key)
         count = len(p)
         noise = random.gumbel(subkey, (count, count))
         picks = np.argmax(noise + w.reshape(-1, 1), 0)
         return np.take(p, picks, 0)

     stacked = np.concatenate(
         [_resample(p, w) for p, w in zip(particles, weights)], 0)
     self.policy.update_parameters(stacked, np.zeros_like(stacked[:, 1]))
コード例 #6
0
ファイル: pixelcnn.py プロジェクト: yueyedeai/jaxnet
def _gumbel_max(rng, logit_probs):
    """Return Gumbel-Max samples taken along axis 0 of ``logit_probs``.

    Gumbel noise with the same shape and dtype as ``logit_probs`` is added,
    and the argmax over the leading axis is returned. The result therefore
    has shape ``logit_probs.shape[1:]``.
    """
    noise = random.gumbel(rng, logit_probs.shape, logit_probs.dtype)
    perturbed = noise + logit_probs
    return np.argmax(perturbed, axis=0)
コード例 #7
0
def categorical_logits(key, logits, shape=()):
    """Sample category indices from unnormalized logits via Gumbel-Max.

    Args:
        key: JAX PRNG key.
        logits: Unnormalized log-probabilities; the last axis indexes the
            categories.
        shape: Batch shape of the returned samples. Any sequence of ints is
            accepted (tuple or list). When empty or ``None``,
            ``logits.shape[:-1]`` is used. ``logits`` must broadcast against
            ``shape + (n_categories,)``.

    Returns:
        Integer array of shape ``shape`` holding the sampled indices.
    """
    # Normalise to a tuple: a list shape would otherwise fail to concatenate
    # with the tuple-valued category dimension below (TypeError).
    shape = tuple(shape) if shape else tuple(logits.shape[:-1])
    noise = random.gumbel(key, shape + logits.shape[-1:], logits.dtype)
    return np.argmax(noise + logits, axis=-1)
コード例 #8
0
def sample_gumbel(rng, shape, n=0):
    """Draw standard Gumbel noise and hand back an advanced PRNG key.

    Args:
        rng: JAX PRNG key. It is split; the first half is returned for
            further use and the second half drives this draw.
        shape: Shape (tuple) of a single draw.
        n: When non-zero, ``n`` is prepended to ``shape`` so that ``n``
            draws are stacked along a new leading axis.

    Returns:
        ``(next_rng, sample)``.
    """
    next_rng, draw_key = random.split(rng)
    if n != 0:
        shape = (n,) + shape
    sample = random.gumbel(draw_key, shape=shape)
    return next_rng, sample
コード例 #9
0
ファイル: random.py プロジェクト: PKU-NIP-Lab/BrainPy
def gumbel(loc=0.0, scale=1.0, size=None):
    """Draw samples from the standard Gumbel distribution.

    Only the standard distribution is supported: ``loc`` must be 0 and
    ``scale`` must be 1.

    Args:
        loc: Location parameter; must equal 0.
        scale: Scale parameter; must equal 1.
        size: Output size, converted to a shape with ``_size2shape``.

    Returns:
        A ``JaxArray`` of samples drawn with the key returned by
        ``DEFAULT.split_key()``.

    Raises:
        NotImplementedError: If ``loc != 0`` or ``scale != 1``.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently ignore non-standard loc/scale.
    if loc != 0.:
        raise NotImplementedError(
            f"gumbel() only supports loc=0.0, got loc={loc!r}")
    if scale != 1.:
        raise NotImplementedError(
            f"gumbel() only supports scale=1.0, got scale={scale!r}")
    return JaxArray(jr.gumbel(DEFAULT.split_key(), shape=_size2shape(size)))
コード例 #10
0
ファイル: continuous.py プロジェクト: kumsh/numpyro
 def sample(self, key, sample_shape=()):
     """Draw samples by shifting and scaling standard Gumbel noise.

     Standard Gumbel noise of shape
     ``sample_shape + self.batch_shape + self.event_shape`` is drawn with
     ``key`` and mapped to ``self.loc + self.scale * noise``.
     """
     full_shape = sample_shape + self.batch_shape + self.event_shape
     noise = random.gumbel(key, shape=full_shape)
     return self.loc + self.scale * noise