Exemplo n.º 1
0
 def _default_event_space_bijector(self):
   """Builds the default event-space bijector for this distribution.

   The returned `Chain` composes `Log` with `SoftmaxCentered`. tfp's `Chain`
   applies its list right-to-left in `forward`, so `SoftmaxCentered` runs
   first and `Log` is applied to its output.

   Returns:
     A `chain_bijector.Chain`; every bijector in it (and the chain itself)
     receives `self.validate_args`.
   """
   # TODO(b/145620027) Finalize choice of bijector.
   validate = self.validate_args
   # Constructed in the same order as the list below (Log first).
   log_step = exp_bijector.Log(validate_args=validate)
   simplex_step = softmax_centered_bijector.SoftmaxCentered(
       validate_args=validate)
   return chain_bijector.Chain(
       [log_step, simplex_step], validate_args=validate)
Exemplo n.º 2
0
 def _default_event_space_bijector(self):
   """Builds the default event-space bijector for this distribution.

   The returned `Chain` composes the inverse of `Square` (whose inverse is a
   square root) with `SoftmaxCentered`. tfp's `Chain` applies its list
   right-to-left in `forward`, so `SoftmaxCentered` runs first.

   Returns:
     A `chain_bijector.Chain`; every bijector in it (and the chain itself)
     receives `self.validate_args`.
   """
   # TODO(b/145620027) Finalize choice of bijector.
   validate = self.validate_args
   square = square_bijector.Square(validate_args=validate)
   sqrt_step = invert_bijector.Invert(square, validate_args=validate)
   simplex_step = softmax_centered_bijector.SoftmaxCentered(
       validate_args=validate)
   return chain_bijector.Chain(
       [sqrt_step, simplex_step], validate_args=validate)
Exemplo n.º 3
0
 def _compute_quantiles():
   """Builds the quantile grid from the closed-over `dist`.

   Evaluates `dist.quantile` at `quadrature_size + 1` evenly spaced interior
   probabilities in (0, 1), maps the result through `SoftmaxCentered`,
   rotates the leading (probability) axis to the end, and pins the static
   shape via `_get_final_shape(quadrature_size + 1)`.

   Relies on `dist`, `quadrature_size`, `batch_ndims` and `_get_final_shape`
   from the enclosing scope.
   """
   # Endpoints 0 and 1 are dropped: their quantiles might be Inf/NaN.
   num_points = quadrature_size + 3
   probs = tf.linspace(tf.zeros([], dtype=dist.dtype), 1., num_points)
   probs = probs[1:-1]
   # Reshape to [-1, 1, ..., 1] (batch_ndims trailing ones) so the
   # probabilities broadcast across every batch dimension.
   batch_ones = tf.ones([batch_ndims], dtype=tf.int32)
   broadcast_shape = tf.concat([[-1], batch_ones], axis=0)
   probs = tf.reshape(probs, shape=broadcast_shape)
   raw_quantiles = dist.quantile(probs)
   result = softmax_centered_bijector.SoftmaxCentered().forward(raw_quantiles)
   # Rotate axes left by one: axis 0 moves to the last position.
   rotate_left = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
   result = tf.transpose(result, rotate_left)
   result.set_shape(_get_final_shape(quadrature_size + 1))
   return result
Exemplo n.º 4
0
 def _default_event_space_bijector(self):
   """Returns the default event-space bijector: `SoftmaxCentered`.

   A single definition replaces the previous pair of back-to-back,
   behaviorally identical definitions, where the second silently shadowed
   the first (function redefinition).

   Returns:
     A `softmax_centered_bijector.SoftmaxCentered` constructed with
     `validate_args=self.validate_args`.
   """
   # TODO(b/145620027) Finalize choice of bijector.
   return softmax_centered_bijector.SoftmaxCentered(
       validate_args=self.validate_args)