def transformed_distributions(draw, batch_shape=None, event_dim=None,
                              enable_vars=False):
  """Draws a `TransformedDistribution` together with its intended batch shape.

  A bijector is drawn from `bijector_hps.unconstrained_bijectors()`. An
  underlying distribution is then drawn from `distributions`, restricted by
  `bijector_hps.distribution_filter_for(bijector)`. When Hypothesis draws
  `True` for the override flag, the underlying distribution is drawn with
  scalar batch shape and `batch_shape` is instead passed as the `batch_shape`
  argument of `TransformedDistribution`.

  Args:
    draw: Hypothesis sampling function (presumably supplied by an
      `@hps.composite` decorator outside this view -- confirm).
    batch_shape: Optional `TensorShape` for the result; drawn from
      `tfp_hps.batch_shapes()` when `None`.
    event_dim: Optional event dimension forwarded to `distributions`.
    enable_vars: Python `bool` forwarded to `distributions`.

  Returns:
    A tuple `(dist, batch_shape)` where `dist` is a
    `tfd.TransformedDistribution` built with `validate_args=True`.
  """
  chosen_bijector = draw(bijector_hps.unconstrained_bijectors())
  logging.info('TD bijector: %s', chosen_bijector)

  if batch_shape is None:
    batch_shape = draw(tfp_hps.batch_shapes())

  if draw(hps.booleans()):
    # Use the `batch_shape` override of `TransformedDistribution` on top of a
    # scalar-batch underlying distribution.
    inner_batch_shape, override_shape = tf.TensorShape([]), batch_shape
  else:
    inner_batch_shape, override_shape = batch_shape, None

  def _first_of_pair(dist_and_batch_shape):
    # Strategy elements are presumably `(dist, batch_shape)` pairs; keep the
    # distribution only.
    return dist_and_batch_shape[0]

  candidate_dists = distributions(
      batch_shape=inner_batch_shape,
      event_dim=event_dim,
      enable_vars=enable_vars).map(_first_of_pair)
  base_dist = draw(
      candidate_dists.filter(
          bijector_hps.distribution_filter_for(chosen_bijector)))

  used_params = []
  for param_name, param_value in six.iteritems(base_dist.parameters):
    if param_value is not None:
      used_params.append(param_name)
  logging.info('TD underlying distribution: %s; parameters used: %s',
               base_dist, used_params)

  transformed = tfd.TransformedDistribution(
      bijector=chosen_bijector,
      distribution=base_dist,
      batch_shape=override_shape,
      validate_args=True)
  return transformed, batch_shape
def transformed_distributions(draw, batch_shape=None, event_dim=None):
  """Draws a `TransformedDistribution` together with its intended batch shape.

  A bijector is drawn via `bijector_hps.unconstrained_bijectors(draw)`. An
  underlying distribution (never itself a `TransformedDistribution`) is drawn
  via `distributions(draw, ...)`. Draws are repeated until
  `bijector_hps.distribution_filter_for(bijector)` accepts the result. When
  Hypothesis draws `True` for the override flag, the underlying distribution
  gets scalar batch shape and `batch_shape` is passed as the `batch_shape`
  argument of `TransformedDistribution`.

  Args:
    draw: Hypothesis sampling function.
    batch_shape: Optional `TensorShape` for the result; drawn from
      `batch_shapes()` when `None`.
    event_dim: Optional event dimension forwarded to `distributions`.

  Returns:
    A tuple `(dist, batch_shape)` where `dist` is a
    `tfd.TransformedDistribution` built with `validate_args=True`.
  """
  bijector = bijector_hps.unconstrained_bijectors(draw)
  logging.info('TD bijector: %s', bijector)
  if batch_shape is None:
    batch_shape = draw(batch_shapes())
  underlying_batch_shape = batch_shape
  batch_shape_arg = None
  if draw(hps.booleans()):
    # Use batch_shape overrides.
    underlying_batch_shape = tf.TensorShape([])  # scalar underlying batch
    batch_shape_arg = batch_shape

  # TODO(b/128974935): Replace the manual draw-and-retry loop below with the
  # composite strategy `distributions(...).map(lambda p: p[0]).filter(
  # bijector_hps.distribution_filter_for(bijector))` once that bug is fixed.

  def _not_transformed(name):
    # Keep nested TransformedDistributions out of the underlying draw.
    return name != 'TransformedDistribution'

  def _draw_underlying():
    # Every attempt, including retries, uses the same eligibility filter so a
    # retry can never produce a nested TransformedDistribution.
    dist, _ = distributions(
        draw,
        batch_shape=underlying_batch_shape,
        event_dim=event_dim,
        eligibility_filter=_not_transformed)
    return dist

  to_transform = _draw_underlying()
  is_compatible = bijector_hps.distribution_filter_for(bijector)
  while not is_compatible(to_transform):
    to_transform = _draw_underlying()

  logging.info(
      'TD underlying distribution: %s; parameters used: %s', to_transform,
      [k for k, v in six.iteritems(to_transform.parameters) if v is not None])
  return (tfd.TransformedDistribution(
      bijector=bijector,
      distribution=to_transform,
      batch_shape=batch_shape_arg,
      validate_args=True), batch_shape)
def transformed_distributions(draw,
                              batch_shape=None,
                              event_dim=None,
                              enable_vars=False,
                              depth=None):
  """Hypothesis strategy producing `TransformedDistribution` instances.

  A bijector is sampled from
  `bijectors.hypothesis_testlib.unconstrained_bijectors`. An underlying
  distribution is sampled from the `distributions` strategy, filtered through
  `bijectors.hypothesis_testlib.distribution_filter_for(bijector)` so the two
  are compatible (these filters generally reject, e.g., vector bijectors paired
  with scalar distributions).

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: Optional `TensorShape` the resulting distribution must have.
      Depending on a drawn boolean, either the underlying distribution has this
      batch shape itself, or it has scalar batch shape and `batch_shape` is
      passed as the `batch_shape` override of `TransformedDistribution`.
      Drawn from `tfp_hps.shapes()` when omitted.
    event_dim: Optional Python `int` size shared by the event dimensions of
      the underlying distribution's parameters (permitting square event
      matrices, compatible location and scale Tensors, etc.). Chosen by
      Hypothesis when omitted.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` maximum nesting depth of compound Distributions. The
      underlying distribution is drawn with `depth - 1`. Drawn from `depths()`
      when omitted.

  Returns:
    dists: A strategy for drawing `TransformedDistribution`s with the
      specified `batch_shape` (or an arbitrary one if omitted).

  Raises:
    AssertionError: If the constructed distribution's `batch_shape` differs
      from the requested `batch_shape`.
  """
  depth = draw(depths()) if depth is None else depth

  bij = draw(bijector_hps.unconstrained_bijectors())
  hp.note('Drawing TransformedDistribution with bijector {}'.format(bij))

  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  use_override = draw(hps.booleans())
  if use_override:
    # Scalar underlying batch, with the requested shape supplied through the
    # `batch_shape` override of `TransformedDistribution`.
    inner_shape = tf.TensorShape([])
    override_shape = batch_shape
  else:
    inner_shape = batch_shape
    override_shape = None

  base_strategy = distributions(
      batch_shape=inner_shape,
      event_dim=event_dim,
      enable_vars=enable_vars,
      depth=depth - 1)
  base_dist = draw(
      base_strategy.filter(bijector_hps.distribution_filter_for(bij)))

  note_template = ('Forming TransformedDistribution with '
                   'underlying distribution {}; parameters {}')
  hp.note(note_template.format(base_dist, params_used(base_dist)))

  # TODO(bjp): Add test coverage for `event_shape` argument of
  # `TransformedDistribution`.
  dist = tfd.TransformedDistribution(
      bijector=bij,
      distribution=base_dist,
      batch_shape=override_shape,
      validate_args=True)

  # Keep `!=`: `TensorShape.__ne__` is not simply the negation of `__eq__`.
  if batch_shape != dist.batch_shape:
    raise AssertionError(
        'TransformedDistribution strategy generated a bad batch shape '
        'for {}, should have been {}.'.format(dist, batch_shape))
  return dist
def transformed_distributions(draw,
                              batch_shape=None,
                              event_dim=None,
                              enable_vars=False,
                              depth=None,
                              eligibility_filter=lambda name: True,
                              validate_args=True):
  """Strategy for drawing `TransformedDistribution`s.

  The transforming bijector is drawn from the
  `bijectors.hypothesis_testlib.unconstrained_bijectors` strategy.

  The underlying distribution is drawn from the `distributions` strategy,
  except that it must be compatible with the bijector. Its class name must be
  accepted by both `eligibility_filter` and
  `bijectors.hypothesis_testlib.distribution_eligilibility_filter_for`. The
  drawn instance must also pass
  `bijectors.hypothesis_testlib.distribution_filter_for` (these generally
  check that vector bijectors are not combined with scalar distributions,
  etc).

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`. The batch shape of the resulting
      `TransformedDistribution`. The underlying distribution is drawn with
      this same `batch_shape`. Hypothesis will pick a `batch_shape` if
      omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions. This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`,
      `tfp.util.TransformedVariable`}.
    depth: Python `int` giving maximum nesting depth of compound
      Distributions. The underlying distribution is drawn with `depth - 1`.
    eligibility_filter: Optional Python callable. Blocks some Distribution
      class names so they will not be drawn as the underlying distribution.
    validate_args: Python `bool`; whether to enable runtime assertions. Used
      for both the underlying and the transformed distribution.

  Returns:
    dists: A strategy for drawing `TransformedDistribution`s with the
      specified `batch_shape` (or an arbitrary one if omitted).

  Raises:
    AssertionError: If the constructed distribution's `batch_shape` differs
      from the requested `batch_shape`.
  """
  if depth is None:
    depth = draw(depths())
  bijector = draw(bijector_hps.unconstrained_bijectors())
  hp.note('Drawing TransformedDistribution with bijector {}'.format(bijector))
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  # Build the bijector-specific name filter once instead of on every call of
  # `eligibility_fn`.
  # NOTE(review): the helper name is spelled `eligilibility` at this call
  # site; leave it unless the helper itself is renamed.
  bijector_name_filter = bijector_hps.distribution_eligilibility_filter_for(
      bijector)

  def eligibility_fn(name):
    # A class name is eligible only if both the caller's filter and the
    # bijector-specific filter accept it.
    if not eligibility_filter(name):
      return False
    return bijector_name_filter(name)

  underlyings = distributions(
      batch_shape=batch_shape,
      event_dim=event_dim,
      enable_vars=enable_vars,
      depth=depth - 1,
      eligibility_filter=eligibility_fn,
      validate_args=validate_args).filter(
          bijector_hps.distribution_filter_for(bijector))
  to_transform = draw(underlyings)
  hp.note('Forming TransformedDistribution with '
          'underlying distribution {}; parameters {}'.format(
              to_transform, params_used(to_transform)))
  result_dist = tfd.TransformedDistribution(
      bijector=bijector, distribution=to_transform, validate_args=validate_args)
  # Keep `!=`: `TensorShape.__ne__` is not simply the negation of `__eq__`.
  if batch_shape != result_dist.batch_shape:
    msg = ('TransformedDistribution strategy generated a bad batch shape '
           'for {}, should have been {}.').format(result_dist, batch_shape)
    raise AssertionError(msg)
  return result_dist