Code Example #1
    def trans_dist(self, value=None):
        assert self.word_len > 0

        if value is None:
            value = jnp.eye(self.word_len)  # transition-probability matrix
            value = jnp.roll(value, 1, axis=1)
        self._trans_dist = distrax.Categorical(probs=value)
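For intuition, here is a minimal sketch (assuming word_len = 3, an illustrative value) of what the rolled identity matrix above looks like: each row puts all probability mass on the next position, so the default transition distribution is a deterministic shift.

import jax.numpy as jnp

T = jnp.roll(jnp.eye(3), 1, axis=1)
# T == [[0., 1., 0.],
#       [0., 0., 1.],
#       [1., 0., 0.]]
# i.e. state 0 -> 1, 1 -> 2, 2 -> 0, each with probability 1.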
Code Example #2
 def model(self, value):
     mixing_coeffs, means, covariances = value
     components_distribution = distrax.as_distribution(
         tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(loc=means,
                                                                           covariance_matrix=covariances,
                                                                           validate_args=True))
     self._model = MixtureSameFamily(mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
                                     components_distribution=components_distribution)
Code Example #3
File: distributions.py  Project: deepmind/rlax
def categorical_sample(key, probs):
    """Sample from a set of discrete probabilities."""
    warnings.warn(
        "Rlax categorical_sample will be deprecated. "
        "Please use distrax.Categorical.sample instead.",
        PendingDeprecationWarning,
        stacklevel=2)
    return distrax.Categorical(probs=probs).sample(seed=key)
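As a follow-up to the deprecation notice above, a minimal sketch of the recommended replacement call (the key and probabilities are illustrative only):

import jax
import jax.numpy as jnp
import distrax

key = jax.random.PRNGKey(0)            # illustrative PRNG key
probs = jnp.array([0.1, 0.2, 0.7])     # illustrative probabilities
action = distrax.Categorical(probs=probs).sample(seed=key)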
Code Example #4
 def model(self, value):
     mixing_coeffs_logits, probs_logits = value
     self._model = MixtureSameFamily(
         mixture_distribution=distrax.Categorical(
             logits=mixing_coeffs_logits),
         components_distribution=distrax.Independent(
             distrax.Bernoulli(logits=probs_logits),
             reinterpreted_batch_ndims=1))
Code Example #5
    def test_posterior_marginal(self):
        mix_dist_probs = jnp.array([0.1, 0.9])
        component_dist_probs = jnp.array([[.2, .3, .5], [.7, .2, .1]])
        bm = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
            components_distribution=distrax.Categorical(
                probs=component_dist_probs))

        marginal_dist = bm.posterior_marginal(jnp.array([0., 1., 2.]))
        marginals = marginal_dist.probs

        self.assertEqual((3, 2), marginals.shape)

        expected_marginals = jnp.array([[(.1 * .2) / (.1 * .2 + .9 * .7),
                                         (.9 * .7) / (.1 * .2 + .9 * .7)],
                                        [(.1 * .3) / (.1 * .3 + .9 * .2),
                                         (.9 * .2) / (.1 * .3 + .9 * .2)],
                                        [(.1 * .5) / (.1 * .5 + .9 * .1),
                                         (.9 * .1) / (.1 * .5 + .9 * .1)]])

        self.assertAllClose(marginals, expected_marginals)
Code Example #6
File: distributions.py  Project: deepmind/rlax
def categorical_importance_sampling_ratios(pi_logits_t: Array,
                                           mu_logits_t: Array,
                                           a_t: Array) -> Array:
    """Compute importance sampling ratios from logits.

  Args:
    pi_logits_t: unnormalized logits at time t for the target policy.
    mu_logits_t: unnormalized logits at time t for the behavior policy.
    a_t: actions at time t.

  Returns:
    importance sampling ratios.
  """
    warnings.warn(
        "Rlax categorical_importance_sampling_ratios will be deprecated. "
        "Please use distrax.importance_sampling_ratios instead.",
        PendingDeprecationWarning,
        stacklevel=2)
    return distrax.importance_sampling_ratios(distrax.Categorical(pi_logits_t),
                                              distrax.Categorical(mu_logits_t),
                                              a_t)
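For reference, the quantity computed here is the ratio rho_t = pi(a_t) / mu(a_t). A minimal sketch of calling the distrax replacement directly (the logits and actions below are illustrative only):

import jax.numpy as jnp
import distrax

pi_logits = jnp.array([[1.0, 0.0, -1.0]])   # target-policy logits (illustrative)
mu_logits = jnp.array([[0.0, 0.0, 0.0]])    # behaviour-policy logits (illustrative)
actions = jnp.array([0])
rhos = distrax.importance_sampling_ratios(distrax.Categorical(logits=pi_logits),
                                          distrax.Categorical(logits=mu_logits),
                                          actions)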
Code Example #7
    def test_posterior_mode(self):
        mix_dist_probs = jnp.array([[0.5, 0.5], [0.01, 0.99]])
        locs = jnp.array([[-1., 1.], [-1., 1.]])
        scale = jnp.array([1.])

        gm = MixtureSameFamily(
            mixture_distribution=distrax.Categorical(probs=mix_dist_probs),
            components_distribution=distrax.Normal(loc=locs, scale=scale))

        mode = gm.posterior_mode(jnp.array([[1.], [-1.], [-6.]]))

        self.assertEqual((3, 2), mode.shape)
        self.assertAllClose(jnp.array([[1, 1], [0, 1], [0, 0]]), mode)
Code Example #8
File: distributions.py  Project: deepmind/rlax
def categorical_cross_entropy(labels: Array, logits: Array) -> Array:
    """Computes the softmax cross entropy between sets of logits and labels.

  See "Deep Learning" by Goodfellow et al.
  (http://www.deeplearningbook.org/contents/prob.html). The computation is
  equivalent to:

                  -sum_i (labels_i * log_softmax(logits_i))

  Args:
    labels: a valid probability distribution (non-negative, sum to 1).
    logits: unnormalized log probabilities.

  Returns:
    a scalar loss.
  """
    warnings.warn(
        "Rlax categorical_cross_entropy will be deprecated. "
        "Please use distrax.Categorical.cross_entropy instead.",
        PendingDeprecationWarning,
        stacklevel=2)
    return distrax.Categorical(probs=labels).cross_entropy(
        distrax.Categorical(logits=logits))
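A small sketch (with illustrative values) checking the docstring formula against the distrax-based implementation:

import jax
import jax.numpy as jnp
import distrax

labels = jnp.array([0.0, 1.0, 0.0])      # one-hot target distribution (illustrative)
logits = jnp.array([2.0, 0.5, -1.0])     # unnormalized log probabilities (illustrative)

manual = -jnp.sum(labels * jax.nn.log_softmax(logits))
via_distrax = distrax.Categorical(probs=labels).cross_entropy(
    distrax.Categorical(logits=logits))
# manual and via_distrax should agree up to numerical precision.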
Code Example #9
    def posterior_marginal(self, observations):
        '''Compute the marginal posterior distribution for a batch of observations.
        https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MixtureSameFamily?version=nightly#posterior_marginal

        Parameters
        ----------
          observations:
            An array representing observations from the mixture. Must
            be broadcastable with the mixture's batch shape.

        Returns
        -------
          * distrax.Categorical
            Posterior marginals, returned as a `Categorical` distribution object representing
            the marginal probability of the components of the mixture. The batch
            shape of the `Categorical` will be the broadcast shape of `observations`
            and the mixture batch shape; the number of classes will equal the
            number of mixture components.
        '''
        return distrax.Categorical(logits=self._per_mixture_component_log_prob(observations))
Code Example #10
 def model(self, value):
     mixing_coeffs, probs = value
     self._model = MixtureSameFamily(mixture_distribution=distrax.Categorical(probs=mixing_coeffs),
                                     components_distribution=distrax.Independent(distrax.Bernoulli(probs=probs)))
Code Example #11
# state transition matrix
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

params_numpy = HMMNumpy(A, B, init_state_dist)
params_jax = HMMJax(A, B, init_state_dist)
hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
                  obs_dist=distrax.Categorical(probs=B),
                  init_dist=distrax.Categorical(probs=init_state_dist))

z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)

start = time.time()
alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(
    params_numpy, x_hist, len(x_hist))
print(
    f'Time taken by numpy version of forwards backwards : {time.time()-start}s'
)

start = time.time()
alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(
    params_jax, jnp.array(x_hist), len(x_hist))
Code Example #12
# state transition matrix
n_hidden, n_obs = 100, 10
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
          obs_dist=distrax.Categorical(probs=B),
          init_dist=distrax.Categorical(probs=init_state_dist))

hmm_distrax = distrax.HMM(trans_dist=distrax.Categorical(probs=A),
                          obs_dist=distrax.Categorical(probs=B),
                          init_dist=distrax.Categorical(probs=init_state_dist))

z_hist, x_hist = hmm_sample_log(hmm, n_samples, rng_key)

start = time.time()
alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(
    x_hist, len(x_hist))
print(
    f'Time taken by Forwards Backwards function of HMM general: {time.time()-start}s'
)
Code Example #13
# state transition matrix
n_hidden, n_obs = 100, 10
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
rng_key = PRNGKey(seed)

params = HMM(distrax.Categorical(probs=A), distrax.Categorical(probs=B),
             distrax.Categorical(probs=init_state_dist))
params_jax = HMMJax(A, B, init_state_dist)

z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)

start = time.time()
alphas, _, gammas, loglikelihood = hmm_forwards_backwards(
    params, x_hist, len(x_hist))
print(f'Time taken by HMM general forwards : {time.time()-start}s')
print(f'Loglikelihood of HMM general : {loglikelihood}')

start = time.time()
alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(
    params_jax, jnp.array(x_hist), len(x_hist))
print(f'Time taken by HMM discrete forwards backwards: {time.time()-start}s')
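A minimal consistency check one could append here, assuming both implementations compute the same smoothed posteriors over the arrays produced above:

assert jnp.allclose(gammas, gammas_jax, atol=1e-3)
assert jnp.allclose(loglikelihood, loglikelihood_jax, atol=1e-3)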
Code Example #14
 def init_dist(self, value=None):
     if value is None:
         value = jnp.append(jnp.ones((1, )), jnp.zeros(
             (self.word_len - 1, )))
     self._init_dist = distrax.Categorical(probs=value)
Code Example #15
File: distributions.py  Project: deepmind/rlax
 def sample_fn(key: Array, logits: Array):
     probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     return distrax.Categorical(
         probs=_mix_with_uniform(probs, epsilon)).sample(seed=key)
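The helper `_mix_with_uniform` is not shown in this snippet; a rough sketch of what such a step could look like (an assumption for illustration, not the rlax implementation) is the usual epsilon-blend with a uniform distribution over actions:

def mix_with_uniform_sketch(probs, epsilon):
    # Blend the given probabilities with a uniform distribution (assumed behaviour).
    num_actions = probs.shape[-1]
    return (1.0 - epsilon) * probs + epsilon / num_actions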
Code Example #16
    initial_probs = jnp.array([0.3, 0.2, 0.5])

    # transition matrix
    A = jnp.array([[0.3, 0.4, 0.3], [0.1, 0.6, 0.3], [0.2, 0.3, 0.5]])

    S1 = jnp.array([[1.1, 0], [0, 0.3]])

    S2 = jnp.array([[0.3, -0.5], [-0.5, 1.3]])

    S3 = jnp.array([[0.8, 0.4], [0.4, 0.5]])

    cov_collection = jnp.array([S1, S2, S3]) / 60
    mu_collection = jnp.array([[0.3, 0.3], [0.8, 0.5], [0.3, 0.8]])

    hmm = HMM(
        trans_dist=distrax.Categorical(probs=A),
        init_dist=distrax.Categorical(probs=initial_probs),
        obs_dist=distrax.as_distribution(
            tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                loc=mu_collection, covariance_matrix=cov_collection)))
    n_samples, seed = 50, 100
    samples_state, samples_obs = hmm_sample(hmm, n_samples, PRNGKey(seed))

    xmin, xmax = 0, 1
    ymin, ymax = 0, 1.2
    colors = ["tab:green", "tab:blue", "tab:red"]

    fig, ax = plt.subplots()
    _, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax,
                                 xmin, xmax, ymin, ymax)
    pml.savefig("hmm_lillypad_2d.pdf")
Code Example #17
 def class_priors(self, value):
     self._class_priors = distrax.Categorical(probs=value)
Code Example #18
File: distributions.py  Project: deepmind/rlax
 def log_prob_fn(sample: Array, logits: Array):
     probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     return distrax.Categorical(
         probs=_mix_with_uniform(probs, epsilon)).log_prob(sample)
Code Example #19
File: distributions.py  Project: deepmind/rlax
 def entropy_fn(logits: Array):
     probs = distrax.Softmax(logits=logits, temperature=temperature).probs
     return distrax.Categorical(
         probs=_mix_with_uniform(probs, epsilon)).entropy()