def __init__(self,
             distribution,
             bijector=None,
             batch_shape=None,
             event_shape=None,
             validate_args=False,
             name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`. `None` means `Identity()`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch` and
        `distribution.is_scalar_event`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_batch` and
        `distribution.is_scalar_event`. Each entry of the vector is one
        event dimension.
      validate_args: Python `Boolean`. Whether to validate input with
        asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for the distribution. Default:
        `bijector.name + distribution.name`.
    """
    # Snapshot the constructor arguments before any of them is rebound.
    parameters = locals()
    parameters.pop("self")
    if bijector is None:
        bijector = bijectors.Identity(validate_args=validate_args)
    name = name or bijector.name + distribution.name
    with ops.name_scope(name, values=[event_shape, batch_shape]):
        # Both overrides share the same precondition, so build it once and
        # only when at least one override was actually supplied.
        if batch_shape is not None or event_shape is not None:
            is_scalar_batch_and_scalar_event = _logical_and(
                distribution.is_scalar_batch, distribution.is_scalar_event)
        if batch_shape is not None:
            batch_shape = self._maybe_validate_shape_override(
                ops.convert_to_tensor(batch_shape, name="batch_shape"),
                is_scalar_batch_and_scalar_event, validate_args)
        self._override_batch_shape = batch_shape
        if event_shape is not None:
            event_shape = self._maybe_validate_shape_override(
                ops.convert_to_tensor(event_shape, name="event_shape"),
                is_scalar_batch_and_scalar_event, validate_args)
            # `event_shape` is a shape *vector*: the number of event dims it
            # describes is its length, not its rank (the rank of a vector is
            # always 1, which silently mishandled overrides such as [2, 3]).
            # Use the static length when it is known, else compute it in-graph.
            static_event_ndims = event_shape.get_shape().num_elements()
            event_ndims = (
                static_event_ndims if static_event_ndims is not None
                else array_ops.size(event_shape, name="event_ndims"))
            # Indices of the trailing `event_ndims` dims, e.g. [-2, -1].
            # NOTE(review): this attribute is only set when `event_shape` is
            # given; confirm readers guard on `_override_event_shape`.
            self._reduce_event_indices = math_ops.range(-event_ndims, 0)
        self._override_event_shape = event_shape
    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        is_continuous=self._distribution.is_continuous,
        is_reparameterized=self._distribution.is_reparameterized,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(
            distribution._graph_parents +  # pylint: disable=protected-access
            bijector.graph_parents),
        name=name)
def __init__(self,
             distribution,
             bijector=None,
             batch_shape=None,
             event_shape=None,
             validate_args=False,
             name=None):
    """Build a distribution by transforming `distribution` with `bijector`.

    Args:
      distribution: Base distribution instance to be transformed; typically
        an instance of `Distribution`.
      bijector: Object that computes the transformation; typically an
        instance of `Bijector`. `None` is treated as `Identity()`.
      batch_shape: `integer` vector `Tensor` overriding the `batch_shape` of
        `distribution`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` overriding the `event_shape` of
        `distribution`; valid only if `distribution.is_scalar_event()`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
      name: Python `str` name prefixed to Ops created by this class.
        Default: `bijector.name + distribution.name`.
    """
    # Must run first so that only the constructor arguments are captured.
    parameters = locals()

    if not name:
        bijector_prefix = "" if bijector is None else bijector.name
        name = bijector_prefix + distribution.name

    with ops.name_scope(name, values=[event_shape, batch_shape]):
        # Handy constants kept on the instance.
        self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
        self._empty = constant_op.constant(
            [], dtype=dtypes.int32, name="empty")

        if bijector is None:
            bijector = bijectors.Identity(validate_args=validate_args)

        # For each override we track both a (possibly dynamic) "is override"
        # flag and a static Python "maybe override" flag, so that as much as
        # possible can be decided -- or rejected with a Python exception --
        # before the graph is executed.

        # Batch shape override.
        self._override_batch_shape = self._maybe_validate_shape_override(
            batch_shape,
            distribution.is_scalar_batch(),
            validate_args,
            "batch_shape")
        batch_override_ndims = _ndims_from_shape(self._override_batch_shape)
        self._is_batch_override = _logical_not(
            _logical_equal(batch_override_ndims, self._zero))
        static_batch_override = tensor_util.constant_value(
            self._override_batch_shape)
        self._is_maybe_batch_override = bool(
            static_batch_override is None or
            static_batch_override.size != 0)

        # Event shape override.
        self._override_event_shape = self._maybe_validate_shape_override(
            event_shape,
            distribution.is_scalar_event(),
            validate_args,
            "event_shape")
        event_override_flag_ndims = _ndims_from_shape(
            self._override_event_shape)
        self._is_event_override = _logical_not(
            _logical_equal(event_override_flag_ndims, self._zero))
        static_event_override = tensor_util.constant_value(
            self._override_event_shape)
        self._is_maybe_event_override = bool(
            static_event_override is None or
            static_event_override.size != 0)

        # Turning a scalar distribution into a multivariate one borrows dims
        # from the (otherwise iid) sample dims. That is straightforward
        # unless the base distribution has batch dims and only the event
        # shape is overridden: the new event dims then land to the left of
        # the batch dims and have to be cyclically permuted left.
        batch_not_overridden = _logical_not(self._is_batch_override)
        base_has_batch_dims = _logical_not(distribution.is_scalar_batch())
        self._needs_rotation = _logical_and(
            self._is_event_override,
            batch_not_overridden,
            base_has_batch_dims)

        override_event_ndims = _ndims_from_shape(self._override_event_shape)
        self._rotate_ndims = _pick_scalar_condition(
            self._needs_rotation, override_event_ndims, 0)

        # The head dims are the ones reduced (if any); this range is empty
        # when no reduction is needed.
        first_reduce_index = self._rotate_ndims - override_event_ndims
        self._reduce_event_indices = math_ops.range(
            first_reduce_index, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector

    # We let TransformedDistribution access _graph_parents since this class
    # is more like a baseclass than derived.
    all_graph_parents = (
        distribution._graph_parents +  # pylint: disable=protected-access
        bijector.graph_parents)
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        is_continuous=self._distribution.is_continuous,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        graph_parents=all_graph_parents,
        name=name)
def __init__(self,
             distribution,
             bijector=None,
             batch_shape=None,
             event_shape=None,
             validate_args=False,
             name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`. `None` means `Identity()`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event`. Each
        entry of the vector is one event dimension.
      validate_args: Python Boolean. Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
      name: The name for the distribution. Default:
        `bijector.name + distribution.name`.
    """
    # Snapshot the constructor arguments before any of them is rebound.
    parameters = locals()
    parameters.pop("self")
    if bijector is None:
        bijector = bijectors.Identity(validate_args=validate_args)
    name = name or bijector.name + distribution.name
    with ops.name_scope(name, values=[event_shape, batch_shape]):
        if batch_shape is not None:
            batch_shape = self._maybe_validate_shape_override(
                ops.convert_to_tensor(batch_shape, name="batch_shape"),
                distribution.is_scalar_batch, validate_args)
        self._override_batch_shape = batch_shape

        if event_shape is not None:
            event_shape = self._maybe_validate_shape_override(
                ops.convert_to_tensor(event_shape, name="event_shape"),
                distribution.is_scalar_event, validate_args)
            # `event_shape` is a shape *vector*: the number of event dims it
            # describes is its length, not its rank (the rank of a vector is
            # always 1, which silently mishandled overrides such as [2, 3]).
            # Use the static length when it is known, else compute it in-graph.
            static_event_ndims = event_shape.get_shape().num_elements()
            self._override_event_ndims = (
                static_event_ndims if static_event_ndims is not None
                else array_ops.size(event_shape, name="event_ndims"))
        else:
            self._override_event_ndims = 0
        self._override_event_shape = event_shape

        # To convert a scalar distribution into a multivariate distribution we
        # will draw dims from the sample dims, which are otherwise iid. This
        # is easy to do except in the case that:
        #   batch_shape is None and
        #   event_shape is not None and
        #   not distribution.is_scalar_batch.
        # When that case happens the event dims will incorrectly be to the
        # left of the batch dims. In this case we'll cyclically permute left
        # the new dims.
        if batch_shape is None and event_shape is not None:
            self._needs_rotation = ops.convert_to_tensor(
                _logical_not(distribution.is_scalar_batch),
                name="needs_rotation")
            n = _pick_scalar_condition(
                self._needs_rotation, self._override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be
            # [] if we don't need to reduce.
            self._reduce_event_indices = math_ops.range(
                n - self._override_event_ndims, n)
        else:
            self._needs_rotation = ops.convert_to_tensor(
                False, name="needs_rotation")
            # We'll be reducing the tail dims (if at all), i.e., this will be
            # [] if we don't need to reduce.
            self._reduce_event_indices = (
                math_ops.range(-self._override_event_ndims, 0)
                if event_shape is not None else [])
    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        is_continuous=self._distribution.is_continuous,
        is_reparameterized=self._distribution.is_reparameterized,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(
            distribution._graph_parents +  # pylint: disable=protected-access
            bijector.graph_parents),
        name=name)