def _parameter_properties(cls, dtype, num_classes=None):
  """Returns parameter metadata for the `df` and `scale_tril` parameters."""
  # pylint: disable=g-long-lambda
  def _make_scale_tril_bijector():
    # Epsilon shift keeps the triangular factor's diagonal strictly positive.
    return fill_scale_tril_bijector.FillScaleTriL(
        diag_shift=dtype_util.eps(dtype))

  df_properties = parameter_properties.ParameterProperties(
      # `df` is scalar per batch member, so strip the two trailing matrix dims
      # from the requested sample shape.
      shape_fn=lambda sample_shape: sample_shape[:-2],
      default_constraining_bijector_fn=(
          parameter_properties.BIJECTOR_NOT_IMPLEMENTED))
  scale_tril_properties = parameter_properties.ParameterProperties(
      event_ndims=2,
      default_constraining_bijector_fn=_make_scale_tril_bijector)
  return dict(df=df_properties, scale_tril=scale_tril_properties)
def _parameter_properties(cls, dtype, num_classes=None):
  """Returns parameter metadata for the `loc` and `scale_tril` parameters."""
  # pylint: disable=g-long-lambda
  def _scale_tril_shape(sample_shape):
    # The triangular factor is square: append the event dim once more.
    return ps.concat([sample_shape, sample_shape[-1:]], axis=0)

  def _make_scale_tril_bijector():
    # Epsilon shift keeps the triangular factor's diagonal strictly positive.
    return fill_scale_tril_bijector.FillScaleTriL(
        diag_shift=dtype_util.eps(dtype))

  return dict(
      loc=parameter_properties.ParameterProperties(event_ndims=1),
      scale_tril=parameter_properties.ParameterProperties(
          event_ndims=2,
          shape_fn=_scale_tril_shape,
          default_constraining_bijector_fn=_make_scale_tril_bijector))
def _parameter_properties(cls, dtype, num_classes=None):
  """Returns parameter metadata for `loc` and `covariance_matrix`."""
  # pylint: disable=g-long-lambda
  def _covariance_shape(sample_shape):
    # The covariance matrix is square: append the event dim once more.
    return ps.concat([sample_shape, sample_shape[-1:]], axis=0)

  def _make_covariance_bijector():
    # Map unconstrained reals -> lower-triangular factor with positive
    # diagonal -> symmetric positive-definite matrix (L @ L^T).
    return chain_bijector.Chain([
        cholesky_outer_product_bijector.CholeskyOuterProduct(),
        fill_scale_tril_bijector.FillScaleTriL(
            diag_shift=dtype_util.eps(dtype))
    ])

  return dict(
      loc=parameter_properties.ParameterProperties(event_ndims=1),
      covariance_matrix=parameter_properties.ParameterProperties(
          event_ndims=2,
          shape_fn=_covariance_shape,
          default_constraining_bijector_fn=_make_covariance_bijector))
def _trainable_linear_operator_tril(shape,
                                    scale_initializer=1e-2,
                                    diag_bijector=None,
                                    dtype=None,
                                    name=None):
  """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
  with tf.name_scope(name or 'trainable_linear_operator_tril'):
    if dtype is None:
      dtype = dtype_util.common_dtype([scale_initializer],
                                      dtype_hint=tf.float32)
    scale_initializer = tf.convert_to_tensor(scale_initializer, dtype=dtype)
    diag_bijector = diag_bijector or _DefaultScaleDiagonal()

    # Split `[b0, ..., bn, d]` into the batch dims and the diagonal length.
    batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])
    tril_bijector = fill_scale_tril.FillScaleTriL(
        diag_bijector, diag_shift=tf.zeros([], dtype=dtype))

    def _init_fn(seed):
      # Draw the d*(d+1)/2 unconstrained entries of the lower triangle, then
      # map them through the constraining bijector.
      unconstrained = samplers.normal(
          mean=0.,
          stddev=scale_initializer,
          shape=ps.concat([batch_shape, dim * (dim + 1) // 2], axis=0),
          seed=seed,
          dtype=dtype)
      return tril_bijector(unconstrained)

    scale_tril = yield trainable_state_util.Parameter(
        init_fn=_init_fn,
        name='scale_tril',
        constraining_bijector=tril_bijector)
    return tf.linalg.LinearOperatorLowerTriangular(
        tril=scale_tril, is_non_singular=True)
def _default_event_space_bijector(self):
  """Bijector mapping unconstrained reals to this distribution's support."""
  # TODO(b/145620027) Finalize choice of bijector.
  validate = self.validate_args
  # Unconstrained reals -> lower-triangular matrix with positive diagonal.
  softplus_diag = transform_diagonal_bijector.TransformDiagonal(
      diag_bijector=softplus_bijector.Softplus(validate_args=validate),
      validate_args=validate)
  tril_bijector = chain_bijector.Chain(
      [softplus_diag,
       fill_scale_tril_bijector.FillScaleTriL(validate_args=validate)],
      validate_args=validate)
  if self.input_output_cholesky:
    return tril_bijector
  # Otherwise also form the full matrix L @ L^T from the Cholesky factor.
  return chain_bijector.Chain(
      [cholesky_outer_product_bijector.CholeskyOuterProduct(
          validate_args=validate),
       tril_bijector],
      validate_args=validate)
def build_trainable_linear_operator_tril(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
  """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorLowerTriangular`.
  """
  with tf.name_scope(name or 'build_trainable_linear_operator_tril'):
    if dtype is None:
      dtype = dtype_util.common_dtype([scale_initializer],
                                      dtype_hint=tf.float32)
    scale_initializer = tf.convert_to_tensor(scale_initializer, dtype=dtype)
    diag_bijector = diag_bijector or _DefaultScaleDiagonal()

    # Split `[b0, ..., bn, d]` into the batch dims and the diagonal length.
    batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])
    tril_bijector = fill_scale_tril.FillScaleTriL(
        diag_bijector, diag_shift=tf.zeros([], dtype=dtype))

    # Draw the d*(d+1)/2 unconstrained entries of the lower triangle.
    unconstrained = samplers.normal(
        mean=0.,
        stddev=scale_initializer,
        shape=ps.concat([batch_shape, dim * (dim + 1) // 2], axis=0),
        seed=seed,
        dtype=dtype)
    # The variable lives in unconstrained space; the bijector keeps the
    # realized `tril` in the constrained (positive-diagonal) space.
    trainable_tril = tfp_util.TransformedVariable(
        tril_bijector.forward(unconstrained),
        bijector=tril_bijector,
        name='tril')
    return tf.linalg.LinearOperatorLowerTriangular(
        tril=trainable_tril, is_non_singular=True)