Example #1
0
def codomain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors that lie in the codomain of a bijector.

  A constrained codomain is handled by drawing an unconstrained Tensor and
  then mapping it into the bijector's inverse support.  Supports are declared
  in `bijectors.hypothesis_testlib.bijector_supports`; the default mappings
  come from `tfp_hps.constrainer`.  `GeneralizedPareto`-style bijectors and
  `tfb.SoftClip` instead get a constraint built from their own parameters.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose codomain the Tensors will be.
    shape: An optional `TensorShape`.  The shape of the resulting
      Tensors.  Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing codomain Tensors for the desired bijector.
  """
  # NOTE(review): the recursive calls below pass only `(bijector, shape)`,
  # which relies on this function (and `domain_tensors`) being wrapped by
  # `@hps.composite`; no decorator is visible here -- confirm.
  # The codomain of an inverted bijector is the domain of the wrapped one.
  if is_invert(bijector):
    return draw(domain_tensors(bijector.bijector, shape))
  # A diagonal transform delegates to its elementwise `diag_bijector`.
  if is_transform_diagonal(bijector):
    return draw(codomain_tensors(bijector.diag_bijector, shape))

  shape = draw(tfp_hps.shapes()) if shape is None else shape

  # Resolved up front for every bijector, including the special cases below.
  codomain_support = bhps.bijector_supports()[type(bijector).__name__].inverse

  if is_generalized_pareto(bijector):
    # Constraint is built from the bijector's own parameters.
    constraint_fn = bhps.generalized_pareto_constraint(
        bijector.loc, bijector.scale, bijector.concentration)
  elif isinstance(bijector, tfb.SoftClip):
    constraint_fn = bhps.softclip_constraint(bijector.low, bijector.high)
  else:
    constraint_fn = tfp_hps.constrainer(codomain_support)

  return draw(tfp_hps.constrained_tensors(constraint_fn, shape))
def domain_tensors(draw, bijector, shape=None):
  """Strategy for drawing Tensors that lie in the domain of a bijector.

  A constrained domain is handled by drawing an unconstrained Tensor and then
  mapping it into the bijector's forward support.  Supports are declared in
  `bijectors.hypothesis_testlib.bijector_supports`; the default mappings come
  from `tfp_hps.constrainer`.  `tfb.PowerTransform` instead gets a constraint
  built from its `power`.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    bijector: A `Bijector` in whose domain the Tensors will be.
    shape: An optional `TensorShape`.  The shape of the resulting
      Tensors.  Hypothesis will pick one if omitted.

  Returns:
    tensors: A strategy for drawing domain Tensors for the desired bijector.
  """
  # The domain of an inverted bijector is the codomain of the wrapped one.
  # NOTE(review): this two-argument call relies on `codomain_tensors` being
  # wrapped by `@hps.composite` -- confirm.
  if is_invert(bijector):
    return draw(codomain_tensors(bijector.bijector, shape))

  shape = draw(tfp_hps.shapes()) if shape is None else shape

  # Resolved for every bijector, including the PowerTransform special case.
  supports_table = bijector_hps.bijector_supports()
  domain_support = supports_table[type(bijector).__name__].forward

  constraint_fn = (
      bijector_hps.power_transform_constraint(bijector.power)
      if isinstance(bijector, tfb.PowerTransform)
      else tfp_hps.constrainer(domain_support))

  return draw(tfp_hps.constrained_tensors(constraint_fn, shape))