def _default_event_space_bijector(self):
  """Build the default event-space bijector for this object.

  Returns:
    A `chain_bijector.Chain` wrapping, in list order, an `exp_bijector.Log`
    and a `softmax_centered_bijector.SoftmaxCentered`. Every bijector,
    including the chain itself, is constructed with this object's
    `validate_args` setting.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate_args = self.validate_args
  log_part = exp_bijector.Log(validate_args=validate_args)
  softmax_part = softmax_centered_bijector.SoftmaxCentered(
      validate_args=validate_args)
  # NOTE(review): per the TFP `Chain` docs, the last list entry is applied
  # first on `forward` -- confirm this ordering is the intended one.
  return chain_bijector.Chain(
      [log_part, softmax_part], validate_args=validate_args)
def _default_event_space_bijector(self):
  """Build the default event-space bijector for this object.

  Returns:
    A `chain_bijector.Chain` wrapping, in list order, the inverse of a
    `square_bijector.Square` (via `invert_bijector.Invert`) and a
    `softmax_centered_bijector.SoftmaxCentered`. Every bijector, including
    the chain itself, is constructed with this object's `validate_args`
    setting.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate_args = self.validate_args
  square_part = square_bijector.Square(validate_args=validate_args)
  inverse_square_part = invert_bijector.Invert(
      square_part, validate_args=validate_args)
  softmax_part = softmax_centered_bijector.SoftmaxCentered(
      validate_args=validate_args)
  # NOTE(review): per the TFP `Chain` docs, the last list entry is applied
  # first on `forward` -- confirm this ordering is the intended one.
  return chain_bijector.Chain(
      [inverse_square_part, softmax_part], validate_args=validate_args)
def _compute_quantiles():
  """Evaluate `dist.quantile` on an interior grid and softmax-center it.

  Reads `dist`, `quadrature_size`, `batch_ndims` and `_get_final_shape` from
  the enclosing scope (not visible in this block).

  Returns:
    The transformed quantiles with the grid axis moved from the front to the
    back, and static shape set to `_get_final_shape(quadrature_size + 1)`.
  """
  # Build `quadrature_size + 3` evenly spaced points on [0, 1] and drop both
  # endpoints: quantiles at exactly 0 or 1 might be Inf/NaN. This leaves
  # `quadrature_size + 1` interior points.
  start = tf.zeros([], dtype=dist.dtype)
  grid = tf.linspace(start, 1., quadrature_size + 3)[1:-1]

  # Reshape to [-1, 1, ..., 1] (with `batch_ndims` trailing ones) so the grid
  # broadcasts across the batch dims.
  broadcastable_shape = tf.concat(
      [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0)
  grid = tf.reshape(grid, shape=broadcastable_shape)

  raw_quantiles = dist.quantile(grid)
  centered = softmax_centered_bijector.SoftmaxCentered().forward(
      raw_quantiles)

  # Rotate axes left by one: axis 0 (the grid axis) becomes the last axis.
  rotate_left = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
  result = tf.transpose(centered, perm=rotate_left)
  result.set_shape(_get_final_shape(quadrature_size + 1))
  return result
def _default_event_space_bijector(self):
  """Build the default event-space bijector for this object.

  Returns:
    A `softmax_centered_bijector.SoftmaxCentered` constructed with this
    object's `validate_args` setting.
  """
  # TODO(b/145620027) Finalize choice of bijector.
  validate_args = self.validate_args
  bijector = softmax_centered_bijector.SoftmaxCentered(
      validate_args=validate_args)
  return bijector
def _default_event_space_bijector(self):
  """Build the default event-space bijector for this object.

  Returns:
    A `softmax_centered_bijector.SoftmaxCentered` constructed with this
    object's `validate_args` setting.
  """
  validate_args = self.validate_args
  bijector = softmax_centered_bijector.SoftmaxCentered(
      validate_args=validate_args)
  return bijector