def assert_values_compatible_with_distribution_shape(
    scoped_name: str, values: Any, dist_shape: tf.TensorShape
) -> None:
    """Assert that the supplied values are compatible with a distribution's TensorShape.

    The values are considered consistent with the distribution if both of the
    following conditions hold:

    1) They have at least as many dimensions as the distribution:
       ``len(values.shape) >= len(dist_shape)``
    2) The trailing ``len(dist_shape)`` dimensions of the values' shape are
       compatible with the distribution's shape:
       ``dist_shape.is_compatible_with(values.shape[(len(values.shape) - len(dist_shape)):])``

    If either shape has an unknown rank, the check cannot be carried out and
    the values are accepted. This mirrors ``tf.TensorShape.is_compatible_with``,
    which treats an unknown-rank shape as compatible with every shape.

    Parameters
    ----------
    scoped_name: str
        The variable's scoped name (used in the error message).
    values: Any
        The supplied values.
    dist_shape: tf.TensorShape
        The ``tf.TensorShape`` instance of the distribution.

    Returns
    -------
    None

    Raises
    ------
    EvaluationError
        When the ``values`` shape is not compatible with the ``dist_shape``.
    """
    # NOTE(review): presumably ``get_observed_tensor_shape`` returns a
    # ``tf.TensorShape`` (the original code relies on ``.rank``, ``len`` and
    # slicing on it) — confirm against its definition.
    value_shape = get_observed_tensor_shape(values)
    # ``TensorShape.rank`` is ``None`` for unknown-rank shapes. Comparing
    # ``None`` with an int, or calling ``len`` on such a shape, would raise an
    # unrelated TypeError/ValueError instead of performing the check.
    if value_shape.rank is None or dist_shape.rank is None:
        return
    # Only the trailing dimensions are matched against the distribution's
    # shape; any extra leading dimensions of ``values`` are allowed.
    offset = value_shape.rank - dist_shape.rank
    if offset < 0 or not dist_shape.is_compatible_with(value_shape[offset:]):
        raise EvaluationError(
            EvaluationError.INCOMPATIBLE_VALUE_AND_DISTRIBUTION_SHAPE.format(
                scoped_name, dist_shape, value_shape
            )
        )
def _assert_compatible_shape(shape: tf.TensorShape, example_shape):
    """Raise ``ValueError`` unless ``example_shape`` is compatible with ``shape``.

    Compatibility is decided by ``shape.is_compatible_with(example_shape)``.

    Parameters
    ----------
    shape: tf.TensorShape
        The feature shape to check against.
    example_shape
        The example's shape; anything ``shape.is_compatible_with`` accepts.

    Raises
    ------
    ValueError
        If the two shapes are not compatible.
    """
    if shape.is_compatible_with(example_shape):
        return
    message = (
        f"example shape {example_shape} is incompatible with "
        f"feature shape {shape}"
    )
    raise ValueError(message)