def _parameter_properties(cls, dtype, num_classes=None):
    """Returns the per-parameter property specs for this model.

    Args:
      dtype: dtype of the parameters; used for the `eps` floor of the
        softplus constraint on the two noise-variance parameters.
      num_classes: Unused.

    Returns:
      A dict mapping each parameter name to its
      `parameter_properties.ParameterProperties` or
      `parameter_properties.BatchedComponentProperties`.
    """
    def _index_points_event_ndims(self):
        # Kernel feature dims plus one more (presumably the number-of-points
        # axis).
        return self.kernel.feature_ndims + 1

    def _drop_last_dim(sample_shape):
        return sample_shape[:-1]

    def _positive_with_floor():
        # Softplus with a lower bound of `eps(dtype)`.
        return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

    not_implemented = parameter_properties.SHAPE_FN_NOT_IMPLEMENTED

    properties = {}
    properties['index_points'] = parameter_properties.ParameterProperties(
        event_ndims=_index_points_event_ndims,
        shape_fn=not_implemented)
    properties['observations'] = parameter_properties.ParameterProperties(
        event_ndims=2,
        shape_fn=not_implemented)
    properties['observation_index_points'] = (
        parameter_properties.ParameterProperties(
            event_ndims=_index_points_event_ndims,
            shape_fn=not_implemented))
    properties['observations_is_missing'] = (
        parameter_properties.ParameterProperties(
            event_ndims=2,
            shape_fn=not_implemented))
    properties['kernel'] = parameter_properties.BatchedComponentProperties()
    # Both noise variances share the same scalar spec.
    for noise_name in ('observation_noise_variance',
                       'predictive_noise_variance'):
        properties[noise_name] = parameter_properties.ParameterProperties(
            event_ndims=0,
            shape_fn=_drop_last_dim,
            default_constraining_bijector_fn=_positive_with_floor)
    properties['_observation_scale'] = (
        parameter_properties.BatchedComponentProperties())
    return properties
예제 #2
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs: a vector `loc` plus batched precision parts.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     batched = parameter_properties.BatchedComponentProperties
     return {
         'loc': parameter_properties.ParameterProperties(event_ndims=1),
         'precision_factor': batched(),
         'precision': batched(),
         'nonzeros': batched(event_ndims=1),
     }
예제 #3
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for the mixture and components distributions.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its `BatchedComponentProperties`.
     """
     component = parameter_properties.BatchedComponentProperties
     return {
         'mixture_distribution': component(),
         # `event_ndims=1`: one rightmost batch dim of the components is not
         # part of this distribution's batch (presumably the component axis).
         'components_distribution': component(event_ndims=1),
     }
예제 #4
0
 def _parameter_properties(cls, dtype):
     """Returns property specs for `chain` and `transition_bijector`.

     Args:
       dtype: Unused.

     Returns:
       A dict from parameter name to its `BatchedComponentProperties`.
     """
     chain_props = parameter_properties.BatchedComponentProperties()
     # The transition bijector contributes no batch shape beyond that from
     # the chain itself.
     transition_props = parameter_properties.BatchedComponentProperties(
         event_ndims=None)
     return {'chain': chain_props, 'transition_bijector': transition_props}
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for the base `distribution` and `bijector`.

     The bijector's event ndims follow the (possibly nested) rank of the base
     distribution's event shape, both statically and as tensors.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its `BatchedComponentProperties`.
     """
     def _static_event_ndims(td):
         # Static rank of each part of the base distribution's event shape.
         return tf.nest.map_structure(
             tensorshape_util.rank, td.distribution.event_shape)

     def _dynamic_event_ndims(td):
         # Same quantity computed from the runtime event-shape tensor(s).
         return tf.nest.map_structure(
             ps.rank_from_shape, td.distribution.event_shape_tensor())

     return {
         'distribution': parameter_properties.BatchedComponentProperties(),
         'bijector': parameter_properties.BatchedComponentProperties(
             event_ndims=_static_event_ndims,
             event_ndims_tensor=_dynamic_event_ndims),
     }
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df`, `loc`, and the two scale operators.

     Args:
       dtype: dtype used for the `eps` lower bound on `df`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _df_bijector():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     props = parameter_properties.ParameterProperties
     batched = parameter_properties.BatchedComponentProperties
     return {
         'df': props(default_constraining_bijector_fn=_df_bijector),
         'loc': props(event_ndims=2),
         'scale_row': batched(),
         'scale_column': batched(),
     }
예제 #7
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df` and the batched `scale` component.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _df_shape(sample_shape):
         # `df` drops the two rightmost dims of a sample shape.
         return sample_shape[:-2]

     df_props = parameter_properties.ParameterProperties(
         shape_fn=_df_shape,
         default_constraining_bijector_fn=(
             parameter_properties.BIJECTOR_NOT_IMPLEMENTED))
     return {
         'df': df_props,
         'scale': parameter_properties.BatchedComponentProperties(),
     }
예제 #8
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `distribution` and `reinterpreted_batch_ndims`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _distribution_event_ndims(self):
         # TODO(davmre): replace with `self.reinterpreted_batch_ndims` once
         # support for `reinterpreted_batch_ndims=None` has been removed.
         return self._get_reinterpreted_batch_ndims()  # pylint: disable=protected-access

     distribution_props = parameter_properties.BatchedComponentProperties(
         event_ndims=_distribution_event_ndims)
     ndims_props = parameter_properties.ShapeParameterProperties()
     return {
         'distribution': distribution_props,
         'reinterpreted_batch_ndims': ndims_props,
     }
예제 #9
0
 def _parameter_properties(cls, dtype):
     """Returns property specs for the wrapped `kernel` and `scale_diag`.

     Args:
       dtype: dtype used for the `eps` lower bound on `scale_diag`.

     Returns:
       A dict from parameter name to its properties object.
     """
     from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top

     def _scale_diag_event_ndims(self):
         # One event dim per kernel feature dim.
         return self.kernel.feature_ndims

     def _scale_diag_bijector():
         return softplus.Softplus(low=dtype_util.eps(dtype))

     properties = {}
     properties['kernel'] = parameter_properties.BatchedComponentProperties()
     properties['scale_diag'] = parameter_properties.ParameterProperties(
         event_ndims=_scale_diag_event_ndims,
         default_constraining_bijector_fn=_scale_diag_bijector)
     return properties
예제 #10
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `amplitudes` and the wrapped `kernel`.

     Args:
       dtype: dtype used for the `eps` lower bound on `amplitudes`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     from tensorflow_probability.python.bijectors import softplus as softplus_bijector  # pylint:disable=g-import-not-at-top
     return dict(
         amplitudes=parameter_properties.ParameterProperties(
             event_ndims=1,
             # `default_constraining_bijector_fn` must be a zero-argument
             # factory, like every other spec here. Previously a `Softplus`
             # instance was passed directly. A bijector instance is not a
             # zero-arg factory, and it was also being built eagerly on every
             # call.
             default_constraining_bijector_fn=(
                 lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
         kernel=parameter_properties.BatchedComponentProperties(event_ndims=1))
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `distribution`, `low`, and `high`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     props = parameter_properties.ParameterProperties
     distribution_props = parameter_properties.BatchedComponentProperties()
     low_props = props()
     # TODO(b/169874884): Support decoupled parameterization.
     high_props = props(
         default_constraining_bijector_fn=(
             parameter_properties.BIJECTOR_NOT_IMPLEMENTED))
     return {
         'distribution': distribution_props,
         'low': low_props,
         'high': high_props,
     }
예제 #12
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df`, `index_points`, and `kernel`.

     Args:
       dtype: dtype used for the `eps` lower bound on `df`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _df_bijector():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     def _index_points_event_ndims(self):
         return self.kernel.feature_ndims + 1

     props = parameter_properties.ParameterProperties
     return {
         'df': props(default_constraining_bijector_fn=_df_bijector),
         'index_points': props(
             event_ndims=_index_points_event_ndims,
             shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED),
         'kernel': parameter_properties.BatchedComponentProperties(),
     }
예제 #13
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df`, `schur_complement`, and observations.

     Args:
       dtype: dtype used to build the lower bound of 2 on `df`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _df_bijector():
         # Softplus with `low=2` (cast to `dtype`) lower-bounds `df`.
         two = dtype_util.as_numpy_dtype(dtype)(2.)
         return softplus_bijector.Softplus(low=two)

     return {
         'df': parameter_properties.ParameterProperties(
             default_constraining_bijector_fn=_df_bijector),
         'schur_complement': parameter_properties.BatchedComponentProperties(),
         'fixed_inputs_observations': parameter_properties.ParameterProperties(
             event_ndims=1,
             shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED),
     }
예제 #14
0
 def _parameter_properties(cls, dtype):
     """Returns property specs for the base kernel and its conditioning inputs.

     Args:
       dtype: dtype used for the `eps` lower bound on `diag_shift`.

     Returns:
       A dict from parameter name to its properties object.
     """
     from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top

     def _fixed_inputs_event_ndims(self):
         return self.base_kernel.feature_ndims + 1

     def _diag_shift_bijector():
         return softplus.Softplus(low=dtype_util.eps(dtype))

     props = parameter_properties.ParameterProperties
     return {
         'base_kernel': parameter_properties.BatchedComponentProperties(),
         'fixed_inputs': props(event_ndims=_fixed_inputs_event_ndims),
         'diag_shift': props(
             default_constraining_bijector_fn=_diag_shift_bijector),
         '_precomputed_divisor_matrix_cholesky': props(event_ndims=2),
     }
예제 #15
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df`, `loc`, and `scale`.

     Only `MultivariateStudentTLinearOperator` itself gets a concrete spec.
     Subclasses must implement their own `_parameter_properties`. Any that
     don't are delegated to the base `Distribution` version, which is
     expected to raise `NotImplementedError`.

     Args:
       dtype: dtype used for the `eps` lower bound on `df`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     if cls is not MultivariateStudentTLinearOperator:
         return distribution.Distribution._parameter_properties(dtype=dtype)  # pylint: disable=protected-access

     def _df_bijector():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     return {
         'df': parameter_properties.ParameterProperties(
             default_constraining_bijector_fn=_df_bijector),
         'loc': parameter_properties.ParameterProperties(event_ndims=1),
         'scale': parameter_properties.BatchedComponentProperties(),
     }
예제 #16
0
    def _parameter_properties(cls, dtype):
        """Returns the property spec for the `bijectors` structure.

        Args:
          dtype: Unused.

        Returns:
          A dict with a single `bijectors` entry whose `event_ndims` is
          computed per bijector from the caller's `x_event_ndims`.
        """
        def _bijector_event_ndims(self, x_event_ndims):
            # Pull each bijector's `x_event_ndims` out of the metadata and
            # pack it into the same structure as `self.bijectors`.
            structure = self.bijectors
            with_metadata = self._get_bijectors_with_metadata(  # pylint: disable=protected-access
                event_ndims=x_event_ndims,
                pack_as_original_structure=True)
            # Recurse up to the BijectorWithMetadata tuples.
            return nest.map_structure_up_to(
                structure, lambda bm: bm.x_event_ndims, with_metadata)

        return {
            'bijectors': parameter_properties.BatchedComponentProperties(
                event_ndims=_bijector_event_ndims),
        }
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `index_points`, `kernel`, and the noise.

     Args:
       dtype: dtype used for the `eps` lower bound on the noise variance.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     from tensorflow_probability.python.bijectors import softplus as softplus_bijector  # pylint:disable=g-import-not-at-top

     def _index_points_event_ndims(self):
         return self.kernel.feature_ndims + 1

     def _noise_variance_shape(sample_shape):
         return sample_shape[:-1]

     def _noise_variance_bijector():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     properties = {}
     properties['index_points'] = parameter_properties.ParameterProperties(
         event_ndims=_index_points_event_ndims,
         shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED)
     properties['kernel'] = parameter_properties.BatchedComponentProperties()
     properties['observation_noise_variance'] = (
         parameter_properties.ParameterProperties(
             event_ndims=0,
             shape_fn=_noise_variance_shape,
             default_constraining_bijector_fn=_noise_variance_bijector))
     return properties
예제 #18
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for the base distribution and its transform.

     Args:
       dtype: dtype used for the `eps` lower bound on `scale`/`tailweight`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _positive():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     props = parameter_properties.ParameterProperties
     return {
         'distribution': parameter_properties.BatchedComponentProperties(),
         'shift': props(),
         'scale': props(default_constraining_bijector_fn=_positive),
         'tailweight': props(default_constraining_bijector_fn=_positive),
     }
예제 #19
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `kernels`, `locs`, and `slopes`.

     Args:
       dtype: dtype used for the `eps` lower bound on `slopes`.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     from tensorflow_probability.python.bijectors import ascending  # pylint:disable=g-import-not-at-top
     from tensorflow_probability.python.bijectors import softplus  # pylint:disable=g-import-not-at-top

     def _kernels_event_ndims(self):
         # One `0` per kernel in `self.kernels`.
         return [0 for _ in self.kernels]

     def _locs_bijector():
         return ascending.Ascending()

     def _slopes_bijector():
         return softplus.Softplus(low=dtype_util.eps(dtype))

     props = parameter_properties.ParameterProperties
     return {
         'kernels': parameter_properties.BatchedComponentProperties(
             event_ndims=_kernels_event_ndims),
         'locs': props(
             event_ndims=1,
             default_constraining_bijector_fn=_locs_bijector),
         'slopes': props(
             event_ndims=1,
             default_constraining_bijector_fn=_slopes_bijector),
     }
예제 #20
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `df`, `index_points`, `kernel`, and noise.

     Args:
       dtype: dtype used for the bounds on `df` (2) and noise (`eps`).
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _df_bijector():
         # Softplus with `low=2` (cast to `dtype`) lower-bounds `df`.
         return softplus_bijector.Softplus(
             low=dtype_util.as_numpy_dtype(dtype)(2.))

     def _index_points_event_ndims(self):
         return self.kernel.feature_ndims + 1

     def _noise_bijector():
         return softplus_bijector.Softplus(low=dtype_util.eps(dtype))

     not_implemented = parameter_properties.SHAPE_FN_NOT_IMPLEMENTED
     props = parameter_properties.ParameterProperties
     return {
         'df': props(default_constraining_bijector_fn=_df_bijector),
         'index_points': props(
             event_ndims=_index_points_event_ndims,
             shape_fn=not_implemented),
         'kernel': parameter_properties.BatchedComponentProperties(),
         'observation_noise_variance': props(
             default_constraining_bijector_fn=_noise_bijector,
             shape_fn=not_implemented),
     }
예제 #21
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for the component `distributions` and `axis`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     def _distributions_event_ndims(self):
         # One `0` per entry in `self.distributions`.
         return [0 for _ in self.distributions]

     return {
         'distributions': parameter_properties.BatchedComponentProperties(
             event_ndims=_distributions_event_ndims),
         'axis': parameter_properties.ShapeParameterProperties(),
     }
예제 #22
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `distribution` and `sample_shape`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     distribution_props = parameter_properties.BatchedComponentProperties()
     sample_shape_props = parameter_properties.ShapeParameterProperties()
     return {
         'distribution': distribution_props,
         'sample_shape': sample_shape_props,
     }
예제 #23
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns the property spec for the wrapped `distribution`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict with a single `distribution` entry.
     """
     def _reinterpreted_ndims(self):
         return self.reinterpreted_batch_ndims

     return {
         'distribution': parameter_properties.BatchedComponentProperties(
             event_ndims=_reinterpreted_ndims),
     }
예제 #24
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for a vector `loc` and a batched `scale`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     loc_props = parameter_properties.ParameterProperties(event_ndims=1)
     scale_props = parameter_properties.BatchedComponentProperties()
     return {'loc': loc_props, 'scale': scale_props}
예제 #25
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for the base `distribution` and `bijector`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its `BatchedComponentProperties`.
     """
     distribution_props = parameter_properties.BatchedComponentProperties()
     bijector_props = parameter_properties.BatchedComponentProperties(
         event_ndims=_count_bijector_batch_ndims_used_by_event)
     return {'distribution': distribution_props, 'bijector': bijector_props}
예제 #26
0
 def _parameter_properties(cls, dtype):
     """Returns the property spec for the wrapped `bijector`.

     Args:
       dtype: Unused.

     Returns:
       A dict with a single `bijector` entry.
     """
     def _wrapped_event_ndims(self, x_event_ndims):
         # Map this bijector's `x_event_ndims` through the wrapped bijector's
         # `inverse_event_ndims`.
         return self.bijector.inverse_event_ndims(x_event_ndims)

     return {
         'bijector': parameter_properties.BatchedComponentProperties(
             event_ndims=_wrapped_event_ndims),
     }
예제 #27
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `base_kernel` and `task_kernel_matrix_linop`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its `BatchedComponentProperties`.
     """
     batched = parameter_properties.BatchedComponentProperties
     return {
         'base_kernel': batched(),
         'task_kernel_matrix_linop': batched(),
     }
예제 #28
0
 def _parameter_properties(cls, dtype):
     """Returns the property spec for the `bijectors` list.

     Args:
       dtype: Unused.

     Returns:
       A dict with a single `bijectors` entry.
     """
     def _bijectors_event_ndims(self):
         # One `None` per entry in `self.bijectors`.
         return [None for _ in self.bijectors]

     return {
         'bijectors': parameter_properties.BatchedComponentProperties(
             event_ndims=_bijectors_event_ndims),
     }
예제 #29
0
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns the property spec for the wrapped `kernel`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict with a single `kernel` entry.
     """
     kernel_props = parameter_properties.BatchedComponentProperties()
     return {'kernel': kernel_props}
예제 #30
0
파일: masked.py 프로젝트: axch/probability
 def _parameter_properties(cls, dtype, num_classes=None):
     """Returns property specs for `distribution` and `validity_mask`.

     Args:
       dtype: Unused.
       num_classes: Unused.

     Returns:
       A dict from parameter name to its properties object.
     """
     distribution_props = parameter_properties.BatchedComponentProperties()
     mask_props = parameter_properties.ParameterProperties(
         shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED)
     return {'distribution': distribution_props, 'validity_mask': mask_props}