def codomain_tensors(draw, bijector, shape=None):
  """Draws a Tensor that lies in the codomain of `bijector`.

  Wrapper bijectors are unwrapped first: for `Invert(b)` the codomain is the
  domain of `b`, and for a `TransformDiagonal` the draw is delegated to its
  `diag_bijector`. For any other bijector, an unconstrained Tensor is drawn
  and then pushed into the bijector's codomain. The per-bijector supports are
  declared in `bijectors.hypothesis_testlib.bijector_supports`. The mapping
  from a support to a transformation comes from `tfp_hps.constrainer`, except
  for generalized Pareto and `SoftClip`, which use parameter-dependent
  constraints.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose codomain the Tensors will be.
    shape: An optional `TensorShape`. The shape of the resulting Tensors.
      Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing codomain Tensors for the desired bijector.
  """
  # Unwrap wrapper bijectors before consulting the support table.
  if is_invert(bijector):
    return draw(domain_tensors(bijector.bijector, shape))
  if is_transform_diagonal(bijector):
    return draw(codomain_tensors(bijector.diag_bijector, shape))

  # The shape is drawn (only when not supplied) before the support lookup, so
  # the sequence of Hypothesis draws stays the same.
  target_shape = draw(tfp_hps.shapes()) if shape is None else shape
  # The lookup happens for every bijector, including those that override the
  # constraint below, so a missing table entry still raises here.
  codomain_support = bhps.bijector_supports()[type(bijector).__name__].inverse

  if is_generalized_pareto(bijector):
    # Constraint is built from the bijector's loc/scale/concentration.
    constrain = bhps.generalized_pareto_constraint(
        bijector.loc, bijector.scale, bijector.concentration)
  elif isinstance(bijector, tfb.SoftClip):
    # Constraint is built from the bijector's low/high bounds.
    constrain = bhps.softclip_constraint(bijector.low, bijector.high)
  else:
    constrain = tfp_hps.constrainer(codomain_support)
  return draw(tfp_hps.constrained_tensors(constrain, target_shape))
def domain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors in the domain of a bijector.

  If the bijector's domain is constrained, this proceeds by drawing an
  unconstrained Tensor and then transforming it to fit. The constraints are
  declared in `bijectors.hypothesis_testlib.bijector_supports`. The
  transformations are defined by `tfp_hps.constrainer`.

  Wrapper bijectors are unwrapped first, in the same way `codomain_tensors`
  handles them. For `Invert(b)` the domain is the codomain of `b`. For a
  `TransformDiagonal` the draw is delegated to the domain of its
  `diag_bijector`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose domain the Tensors will be.
    shape: An optional `TensorShape`. The shape of the resulting Tensors.
      Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing domain Tensors for the desired bijector.
  """
  if is_invert(bijector):
    # The domain of Invert(b) is the codomain of b.
    return draw(codomain_tensors(bijector.bijector, shape))
  elif is_transform_diagonal(bijector):
    # Mirror codomain_tensors: a TransformDiagonal delegates to the bijector
    # it applies along the diagonal rather than to a flat support-table entry.
    return draw(domain_tensors(bijector.diag_bijector, shape))
  if shape is None:
    shape = draw(tfp_hps.shapes())
  bijector_name = type(bijector).__name__
  # NOTE(review): codomain_tensors refers to this same hypothesis_testlib
  # module as `bhps`; confirm which alias this file actually imports.
  support = bijector_hps.bijector_supports()[bijector_name].forward
  if isinstance(bijector, tfb.PowerTransform):
    # PowerTransform gets a dedicated constraint built from its `power`
    # parameter instead of the generic support-based constrainer.
    constraint_fn = bijector_hps.power_transform_constraint(bijector.power)
  else:
    constraint_fn = tfp_hps.constrainer(support)
  return draw(tfp_hps.constrained_tensors(constraint_fn, shape))