def bijectors(draw, bijector_name=None, batch_shape=None, event_dim=None,
              enable_vars=False, allowed_bijectors=None, validate_args=True,
              return_duplicate=False):
  """Strategy for drawing Bijectors.

  The emitted bijector may be a basic bijector or an `Invert` of a basic
  bijector, but not a compound like `Chain`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector_name: Optional Python `str`.  If given, the produced bijectors
      will all have this type.  If omitted, Hypothesis chooses one from
      `allowed_bijectors`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      bijector.  Hypothesis will pick one if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
      `tfp.util.TransformedVariable`}
    allowed_bijectors: Optional collection of `str` Bijector names to sample
      from (a list, tuple, set, or a mapping keyed by name). Names are sorted
      before sampling so draws are reproducible. Bijectors not in this
      collection will not be returned or instantiated as part of a
      meta-bijector (Chain, Invert, etc.). Defaults to
      `bhps.INSTANTIABLE_BIJECTORS`.
    validate_args: Python `bool`; whether to enable runtime checks.
    return_duplicate: Python `bool`: If `False` return a single bijector. If
      `True` return a tuple of two bijectors of the same type, instantiated
      with the same parameters (parameter objects such as the underlying
      bijector of an `Invert` are shared between the two).

  Returns:
    bijectors: A strategy for drawing bijectors with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
  if allowed_bijectors is None:
    allowed_bijectors = bhps.INSTANTIABLE_BIJECTORS
  if bijector_name is None:
    # `hps.sampled_from` only accepts ordered sequences (it rejects plain
    # dicts/sets so that saved examples replay identically). The default
    # `INSTANTIABLE_BIJECTORS` is a mapping keyed by name, so sort the names,
    # exactly like the `Invert` / `TransformDiagonal` branches below do.
    bijector_name = draw(hps.sampled_from(sorted(allowed_bijectors)))
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())
  if event_dim is None:
    event_dim = draw(hps.integers(min_value=2, max_value=6))

  if bijector_name == 'Invert':
    # Never nest `Invert` directly inside `Invert`.
    underlying_name = draw(
        hps.sampled_from(sorted(set(allowed_bijectors) - {'Invert'})))
    underlying = draw(
        bijectors(bijector_name=underlying_name,
                  batch_shape=batch_shape,
                  event_dim=event_dim,
                  enable_vars=enable_vars,
                  allowed_bijectors=allowed_bijectors,
                  validate_args=validate_args))
    bijector_params = {'bijector': underlying}
    msg = 'Forming Invert bijector with underlying bijector {}.'
    hp.note(msg.format(underlying))
  elif bijector_name == 'TransformDiagonal':
    # The diagonal bijector must be both allowed by the caller and on the
    # TransformDiagonal allowlist.
    underlying_name = draw(
        hps.sampled_from(
            sorted(
                set(allowed_bijectors) &
                set(bhps.TRANSFORM_DIAGONAL_ALLOWLIST))))
    # The diagonal bijector is drawn with an empty batch shape.
    underlying = draw(
        bijectors(bijector_name=underlying_name,
                  batch_shape=(),
                  event_dim=event_dim,
                  enable_vars=enable_vars,
                  allowed_bijectors=allowed_bijectors,
                  validate_args=validate_args))
    bijector_params = {'diag_bijector': underlying}
    msg = 'Forming TransformDiagonal bijector with underlying bijector {}.'
    hp.note(msg.format(underlying))
  elif bijector_name == 'Inline':
    # Build an `Inline` bijector whose callables all delegate to a `Scale`.
    scale = draw(
        tfp_hps.maybe_variable(
            hps.sampled_from(np.float32([1., -1., 2, -2.])), enable_vars))
    b = tfb.Scale(scale=scale)

    bijector_params = dict(
        forward_fn=CallableModule(b.forward, b),
        inverse_fn=b.inverse,
        forward_log_det_jacobian_fn=lambda x: b.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
            x, event_ndims=b.forward_min_event_ndims),
        forward_min_event_ndims=b.forward_min_event_ndims,
        is_constant_jacobian=b.is_constant_jacobian,
        is_increasing=b._internal_is_increasing,  # pylint: disable=protected-access
    )
  elif bijector_name == 'DiscreteCosineTransform':
    dct_type = hps.integers(min_value=2, max_value=3)
    bijector_params = {'dct_type': draw(dct_type)}
  elif bijector_name == 'GeneralizedPareto':
    concentration = hps.floats(min_value=-200., max_value=200)
    scale = hps.floats(min_value=1e-2, max_value=200)
    loc = hps.floats(min_value=-200, max_value=200)
    bijector_params = {
        'concentration': draw(concentration),
        'scale': draw(scale),
        'loc': draw(loc)
    }
  elif bijector_name == 'PowerTransform':
    power = hps.floats(min_value=1e-6, max_value=10.)
    bijector_params = {'power': draw(power)}
  elif bijector_name == 'Permute':
    event_ndims = draw(hps.integers(min_value=1, max_value=2))
    axis = hps.integers(min_value=-event_ndims, max_value=-1)
    # This is a permutation of dimensions within an axis.
    # (Contrast with `Transpose` below.)
    bijector_params = {
        'axis': draw(axis),
        'permutation': draw(
            tfp_hps.maybe_variable(hps.permutations(np.arange(event_dim)),
                                   enable_vars,
                                   dtype=tf.int32))
    }
  elif bijector_name == 'Reshape':
    event_shape_out = draw(tfp_hps.shapes(min_ndims=1))
    # TODO(b/142135119): Wanted to draw general input and output shapes like the
    # following, but Hypothesis complained about filtering out too many things.
    # event_shape_in = draw(tfp_hps.shapes(min_ndims=1))
    # hp.assume(event_shape_out.num_elements() == event_shape_in.num_elements())
    # Instead, flatten: the input event is a vector of the same total size.
    event_shape_in = [event_shape_out.num_elements()]
    bijector_params = {
        'event_shape_out': event_shape_out,
        'event_shape_in': event_shape_in
    }
  elif bijector_name == 'Transpose':
    event_ndims = draw(hps.integers(min_value=0, max_value=2))
    # This is a permutation of axes.
    # (Contrast with `Permute` above.)
    bijector_params = {
        'perm': draw(hps.permutations(np.arange(event_ndims)))
    }
  else:
    # Generic path: draw parameters that broadcast against `batch_shape`,
    # restricted per-parameter by `constraint_for`. `MUTEX_PARAMS` presumably
    # lists parameter groups of which only one may be supplied -- confirm.
    params_event_ndims = bhps.INSTANTIABLE_BIJECTORS[
        bijector_name].params_event_ndims
    bijector_params = draw(
        tfp_hps.broadcasting_params(
            batch_shape,
            params_event_ndims,
            event_dim=event_dim,
            enable_vars=enable_vars,
            constraint_fn_for=lambda param: constraint_for(bijector_name, param),  # pylint:disable=line-too-long
            mutex_params=MUTEX_PARAMS))
    bijector_params = constrain_params(bijector_params, bijector_name)

  ctor = getattr(tfb, bijector_name)
  hp.note('Forming {} bijector with params {}.'.format(
      bijector_name, bijector_params))
  bijector = ctor(validate_args=validate_args, **bijector_params)
  if not return_duplicate:
    return bijector
  # Same parameter dict for both instances, so any parameter objects
  # (e.g. an underlying bijector) are shared rather than copied.
  return (bijector, ctor(validate_args=validate_args, **bijector_params))
def bijectors(draw, bijector_name=None, batch_shape=None, event_dim=None,
              enable_vars=False):
  """Hypothesis strategy that draws a single, non-compound Bijector.

  The result is either a basic bijector or an `Invert` wrapping one; compound
  bijectors such as `Chain` are never produced.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector_name: Optional Python `str`. When given, every drawn bijector has
      this type; otherwise the type is sampled from the whitelist
      `TF2_FRIENDLY_BIJECTORS`.
    batch_shape: Optional `TensorShape` giving the batch shape of the drawn
      bijector. Hypothesis picks one when omitted.
    event_dim: Optional Python int: the size of every event dimension of the
      generated parameters. Sharing it across parameters keeps event matrices
      square and location/scale Tensors compatible. Hypothesis picks one when
      omitted.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. When `False`, parameters are plain
      `tf.Tensor`s rather than {`tf.Variable`, `tfp.util.DeferredTensor`
      `tfp.util.TransformedVariable`}.

  Returns:
    bijectors: A strategy for drawing bijectors with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
  # Fill in unspecified arguments. The order of `draw` calls determines how
  # Hypothesis replays and shrinks examples, so it is kept fixed.
  if bijector_name is None:
    bijector_name = draw(hps.sampled_from(TF2_FRIENDLY_BIJECTORS))
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())
  if event_dim is None:
    event_dim = draw(hps.integers(min_value=2, max_value=6))

  if bijector_name == 'Invert':
    inner_choices = sorted(set(TF2_FRIENDLY_BIJECTORS) - {'Invert'})
    inner_name = draw(hps.sampled_from(inner_choices))
    inner = draw(bijectors(bijector_name=inner_name,
                           batch_shape=batch_shape,
                           event_dim=event_dim,
                           enable_vars=enable_vars))
    params = dict(bijector=inner)
  elif bijector_name == 'TransformDiagonal':
    inner_name = draw(hps.sampled_from(sorted(TRANSFORM_DIAGONAL_WHITELIST)))
    # The diagonal bijector is drawn with an empty batch shape.
    inner = draw(bijectors(bijector_name=inner_name,
                           batch_shape=(),
                           event_dim=event_dim,
                           enable_vars=enable_vars))
    params = dict(diag_bijector=inner)
  elif bijector_name == 'Inline':
    # An `Inline` bijector whose callables all delegate to a `Scale`.
    scale_value = draw(
        tfp_hps.maybe_variable(
            hps.sampled_from(np.float32([1., -1., 2, -2.])), enable_vars))
    scale_bij = tfb.Scale(scale=scale_value)
    params = dict(
        forward_fn=CallableModule(scale_bij.forward, scale_bij),
        inverse_fn=scale_bij.inverse,
        forward_log_det_jacobian_fn=(
            lambda x: scale_bij.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
                x, event_ndims=scale_bij.forward_min_event_ndims)),
        forward_min_event_ndims=scale_bij.forward_min_event_ndims,
        is_constant_jacobian=scale_bij.is_constant_jacobian,
        is_increasing=scale_bij._internal_is_increasing,  # pylint: disable=protected-access
    )
  elif bijector_name == 'DiscreteCosineTransform':
    params = {'dct_type': draw(hps.integers(min_value=2, max_value=3))}
  elif bijector_name == 'PowerTransform':
    params = {'power': draw(hps.floats(min_value=1e-6, max_value=10.))}
  elif bijector_name == 'Permute':
    # Shuffles the entries along a single axis; compare `Transpose`, which
    # reorders the axes themselves.
    ndims = draw(hps.integers(min_value=1, max_value=2))
    chosen_axis = draw(hps.integers(min_value=-ndims, max_value=-1))
    chosen_perm = draw(
        tfp_hps.maybe_variable(hps.permutations(np.arange(event_dim)),
                               enable_vars,
                               dtype=tf.int32))
    params = {'axis': chosen_axis, 'permutation': chosen_perm}
  elif bijector_name == 'Reshape':
    out_shape = draw(tfp_hps.shapes(min_ndims=1))
    # TODO(b/142135119): Ideally both input and output event shapes would be
    # drawn independently and filtered to equal element counts, but Hypothesis
    # rejected too many examples. Use a flat input of the same size instead.
    params = {
        'event_shape_out': out_shape,
        'event_shape_in': [out_shape.num_elements()],
    }
  elif bijector_name == 'Transpose':
    # Reorders axes; compare `Permute`, which shuffles entries within one axis.
    ndims = draw(hps.integers(min_value=0, max_value=2))
    params = {'perm': draw(hps.permutations(np.arange(ndims)))}
  else:
    params = draw(broadcasting_params(bijector_name,
                                      batch_shape,
                                      event_dim=event_dim,
                                      enable_vars=enable_vars))

  return getattr(tfb, bijector_name)(validate_args=True, **params)