# Example 1
 def _parameter_properties(cls, dtype, num_classes=None):
     """Return per-parameter metadata for the `df` and `scale_tril` parameters."""
     # pylint: disable=g-long-lambda
     # `df` is scalar per batch member: drop the two trailing event dims.
     df_properties = parameter_properties.ParameterProperties(
         shape_fn=lambda sample_shape: sample_shape[:-2],
         default_constraining_bijector_fn=(
             parameter_properties.BIJECTOR_NOT_IMPLEMENTED))
     # `scale_tril` is a matrix-valued parameter constrained to be
     # lower-triangular with a strictly positive diagonal.
     scale_tril_properties = parameter_properties.ParameterProperties(
         event_ndims=2,
         default_constraining_bijector_fn=(
             lambda: fill_scale_tril_bijector.FillScaleTriL(
                 diag_shift=dtype_util.eps(dtype))))
     return {'df': df_properties, 'scale_tril': scale_tril_properties}
# Example 2
 def _parameter_properties(cls, dtype, num_classes=None):
     """Return per-parameter metadata for the `loc` and `scale_tril` parameters."""
     # pylint: disable=g-long-lambda
     # Append the trailing dimension once more so `scale_tril` is square.
     square_shape_fn = lambda sample_shape: ps.concat(
         [sample_shape, sample_shape[-1:]], axis=0)
     # Constrain `scale_tril` to lower-triangular w/ positive diagonal.
     scale_tril_properties = parameter_properties.ParameterProperties(
         event_ndims=2,
         shape_fn=square_shape_fn,
         default_constraining_bijector_fn=(
             lambda: fill_scale_tril_bijector.FillScaleTriL(
                 diag_shift=dtype_util.eps(dtype))))
     return dict(
         loc=parameter_properties.ParameterProperties(event_ndims=1),
         scale_tril=scale_tril_properties)
 def _parameter_properties(cls, dtype, num_classes=None):
     """Return per-parameter metadata for `loc` and `covariance_matrix`."""
     # pylint: disable=g-long-lambda
     def _psd_bijector():
         # Map unconstrained reals to positive-definite matrices: build a
         # lower-triangular factor with positive diagonal, then form L @ L^T.
         return chain_bijector.Chain([
             cholesky_outer_product_bijector.CholeskyOuterProduct(),
             fill_scale_tril_bijector.FillScaleTriL(
                 diag_shift=dtype_util.eps(dtype)),
         ])
     return dict(
         loc=parameter_properties.ParameterProperties(event_ndims=1),
         covariance_matrix=parameter_properties.ParameterProperties(
             event_ndims=2,
             # Append the trailing dimension so the matrix is square.
             shape_fn=lambda sample_shape: ps.concat(
                 [sample_shape, sample_shape[-1:]], axis=0),
             default_constraining_bijector_fn=_psd_bijector))
def _trainable_linear_operator_tril(shape,
                                    scale_initializer=1e-2,
                                    diag_bijector=None,
                                    dtype=None,
                                    name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

    Args:
      shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
        `b0...bn` are batch dimensions and `d` is the length of the diagonal.
      scale_initializer: Variables are initialized with samples from
        `Normal(0, scale_initializer)`.
      diag_bijector: Bijector to apply to the diagonal of the operator.
      dtype: `tf.dtype` of the `LinearOperator`.
      name: str, name for `tf.name_scope`.

    Yields:
      *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
        These are intended to be consumed by
        `trainable_state_util.as_stateful_builder` and
        `trainable_state_util.as_stateless_builder` to define stateful and
        stateless variants respectively.
    """
    with tf.name_scope(name or 'trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        # Split `[b0, ..., bn, d]` into batch dims and the diagonal length d.
        batch_dims, side = ps.split(shape, num_or_size_splits=[-1, 1])

        tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))

        def _init_fn(seed):
            # d*(d+1)/2 free entries parameterize the lower triangle.
            flat = samplers.normal(
                mean=0.,
                stddev=scale_initializer,
                shape=ps.concat([batch_dims, side * (side + 1) // 2], axis=0),
                seed=seed,
                dtype=dtype)
            return tril_bijector(flat)

        tril = yield trainable_state_util.Parameter(
            init_fn=_init_fn,
            name='scale_tril',
            constraining_bijector=tril_bijector)
        return tf.linalg.LinearOperatorLowerTriangular(tril=tril,
                                                       is_non_singular=True)
# Example 5
 def _default_event_space_bijector(self):
   """Bijector from unconstrained reals to this distribution's support."""
   # TODO(b/145620027) Finalize choice of bijector.
   validate = self.validate_args
   # Lower-triangular factor with a positive (softplus-constrained) diagonal.
   tril_bijector = chain_bijector.Chain([
       transform_diagonal_bijector.TransformDiagonal(
           diag_bijector=softplus_bijector.Softplus(validate_args=validate),
           validate_args=validate),
       fill_scale_tril_bijector.FillScaleTriL(validate_args=validate),
   ], validate_args=validate)
   if not self.input_output_cholesky:
     # Samples are full matrices, not Cholesky factors: additionally map the
     # factor L to L @ L^T.
     return chain_bijector.Chain([
         cholesky_outer_product_bijector.CholeskyOuterProduct(
             validate_args=validate),
         tril_bijector,
     ], validate_args=validate)
   return tril_bijector
def build_trainable_linear_operator_tril(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

    Args:
      shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
        `b0...bn` are batch dimensions and `d` is the length of the diagonal.
      scale_initializer: Variables are initialized with samples from
        `Normal(0, scale_initializer)`.
      diag_bijector: Bijector to apply to the diagonal of the operator.
      dtype: `tf.dtype` of the `LinearOperator`.
      seed: Python integer to seed the random number generator.
      name: str, name for `tf.name_scope`.

    Returns:
      operator: Trainable instance of `tf.linalg.LinearOperatorLowerTriangular`.
    """
    with tf.name_scope(name or 'build_trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        # Split `[b0, ..., bn, d]` into batch dims and the diagonal length d.
        batch_dims, side = ps.split(shape, num_or_size_splits=[-1, 1])

        # d*(d+1)/2 free entries parameterize the lower triangle.
        num_tril_entries = side * (side + 1) // 2
        flat_scale = samplers.normal(
            mean=0.,
            stddev=scale_initializer,
            shape=ps.concat([batch_dims, num_tril_entries], axis=0),
            seed=seed,
            dtype=dtype)
        tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
        # The variable lives in unconstrained space; the bijector keeps its
        # constrained (lower-triangular, positive-diagonal) view.
        tril_variable = tfp_util.TransformedVariable(
            tril_bijector.forward(flat_scale),
            bijector=tril_bijector,
            name='tril')
        return tf.linalg.LinearOperatorLowerTriangular(
            tril=tril_variable, is_non_singular=True)