def codomain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors in the codomain of a bijector.

  If the bijector's codomain is constrained, this proceeds by drawing an
  unconstrained Tensor and then transforming it to fit.  The constraints are
  declared in `bijectors.hypothesis_testlib.bijector_supports`.  The
  transformations are defined by `tfp_hps.constrainer`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose codomain the Tensors will be.
    shape: An optional `TensorShape`.  The shape of the resulting
      Tensors.  Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing codomain Tensors for the desired bijector.
  """
  # Invert swaps domain and codomain; TransformDiagonal's codomain coincides
  # with that of its diagonal bijector.  Both cases delegate recursively.
  if is_invert(bijector):
    return draw(domain_tensors(bijector.bijector, shape))
  if is_transform_diagonal(bijector):
    return draw(codomain_tensors(bijector.diag_bijector, shape))
  if shape is None:
    shape = draw(tfp_hps.shapes())
  support = bhps.bijector_supports()[type(bijector).__name__].inverse
  # A few bijectors need parameter-dependent constraints that the generic
  # support table cannot express.
  if is_generalized_pareto(bijector):
    constraint = bhps.generalized_pareto_constraint(
        bijector.loc, bijector.scale, bijector.concentration)
  elif isinstance(bijector, tfb.SoftClip):
    constraint = bhps.softclip_constraint(bijector.low, bijector.high)
  else:
    constraint = tfp_hps.constrainer(support)
  return draw(tfp_hps.constrained_tensors(constraint, shape))
def domain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors in the domain of a bijector.

  If the bijector's domain is constrained, this proceeds by drawing an
  unconstrained Tensor and then transforming it to fit.  The constraints are
  declared in `bijectors.hypothesis_testlib.bijector_supports`.  The
  transformations are defined by `tfp_hps.constrainer`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose domain the Tensors will be.
    shape: An optional `TensorShape`.  The shape of the resulting
      Tensors.  Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing domain Tensors for the desired bijector.
  """
  # An inverted bijector's domain is the underlying bijector's codomain.
  if is_invert(bijector):
    return draw(codomain_tensors(bijector.bijector, shape))
  if shape is None:
    shape = draw(tfp_hps.shapes())
  support = bijector_hps.bijector_supports()[type(bijector).__name__].forward
  # PowerTransform's domain depends on its `power` parameter, which the
  # generic support table cannot express.
  if isinstance(bijector, tfb.PowerTransform):
    constraint = bijector_hps.power_transform_constraint(bijector.power)
  else:
    constraint = tfp_hps.constrainer(support)
  return draw(tfp_hps.constrained_tensors(constraint, shape))
def codomain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors in the codomain of a bijector.

  If the bijector's codomain is constrained, this proceeds by drawing an
  unconstrained Tensor and then transforming it to fit.  The constraints are
  declared in `bijectors.hypothesis_testlib.bijector_supports`.  The
  transformations are defined by `tfp_hps.constrainer`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose codomain the Tensors will be.
    shape: An optional `TensorShape`.  The shape of the resulting
      Tensors.  Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing codomain Tensors for the desired bijector.
  """
  # An inverted bijector's codomain is the underlying bijector's domain.
  if is_invert(bijector):
    return draw(domain_tensors(bijector.bijector, shape))
  if shape is None:
    shape = draw(tfp_hps.batch_shapes())
  bijector_name = type(bijector).__name__
  support = bijector_hps.bijector_supports()[bijector_name].inverse
  # Qualify `constrainer` with its module, consistent with the sibling
  # `domain_tensors`/`codomain_tensors` strategies in this file.
  constraint_fn = tfp_hps.constrainer(support)
  return draw(tfp_hps.constrained_tensors(constraint_fn, shape))
def constrain_forward_shape(bijector, shape):
  """Constrain the shape so it is compatible with bijector.forward."""
  # Invert swaps directions, so constrain against the inverse instead.
  if is_invert(bijector):
    return constrain_inverse_shape(bijector.bijector, shape=shape)
  bijector_name = type(bijector).__name__
  support = bijector_hps.bijector_supports()[bijector_name].forward
  if support == tfp_hps.Support.VECTOR_SIZE_TRIANGULAR:
    # The forward input must hold the lower triangle of a matrix, so the
    # trailing dimension is adjusted to a triangular number.
    dim = shape[-1]
    shape[-1] = int(dim * (dim + 1) / 2)
  return shape
def constrain_forward_shape(bijector, shape):
  """Constrain the shape so it is compatible with bijector.forward.

  Args:
    bijector: A `Bijector`.
    shape: A TensorShape or compatible, giving the desired event shape.

  Returns:
    shape: A TensorShape, giving an event shape compatible with
      `bijector.forward`, loosely inspired by the input `shape`.
  """
  # Invert swaps directions, so constrain against the inverse instead.
  if is_invert(bijector):
    return constrain_inverse_shape(bijector.bijector, shape=shape)
  bijector_name = type(bijector).__name__
  support = bijector_hps.bijector_supports()[bijector_name].forward
  if support == tfp_hps.Support.VECTOR_SIZE_TRIANGULAR:
    # The forward input must hold the lower triangle of a matrix, so the
    # trailing dimension is adjusted to a triangular number.
    dim = shape[-1]
    shape[-1] = int(dim * (dim + 1) / 2)
  if isinstance(bijector, tfb.Reshape):
    # Note: This relies on the out event shape being fully determined
    shape = tf.get_static_value(bijector._event_shape_in)
  return tf.TensorShape(shape)