コード例 #1
0
def rq_splines(draw, batch_shape=None, dtype=tf.float32):
    """Strategy for drawing `tfb.RationalQuadraticSpline` bijectors.

    Args:
      draw: Hypothesis strategy sampler (presumably supplied by
        `@hps.composite`; the decorator is not visible here).
      batch_shape: Optional `TensorShape`; Hypothesis picks one if omitted.
      dtype: dtype forwarded to the bin-size and slope constraint functions.

    Returns:
      A `tfb.RationalQuadraticSpline` whose `range_min` is the drawn lower
      bound and whose `validate_args` flag is drawn at random.
    """
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())

    # Draw two overlapping endpoints, order them, and pad the upper one by .2
    # so the [lo, hi] interval always has strictly positive width.
    first = draw(hps.floats(min_value=-5, max_value=.5))
    second = draw(hps.floats(min_value=-.5, max_value=5))
    lo = min(first, second)
    hi = max(first, second) + .2
    hp.note('lo, hi: {!r}'.format((lo, hi)))

    # Bin widths and bin heights share one constraint; knot slopes get their
    # own. NOTE(review): presumably the bin-size constraint scales the bins to
    # cover [lo, hi] -- confirm in `bijector_hps`.
    bin_size_fn = functools.partial(bijector_hps.spline_bin_size_constraint,
                                    hi=hi,
                                    lo=lo,
                                    dtype=dtype)
    slope_fn = functools.partial(bijector_hps.spline_slope_constraint,
                                 dtype=dtype)
    constraints = {
        'bin_widths': bin_size_fn,
        'bin_heights': bin_size_fn,
        'knot_slopes': slope_fn,
    }
    # Every spline parameter carries exactly one event dimension.
    event_ndims = {
        name: 1 for name in ('bin_widths', 'bin_heights', 'knot_slopes')
    }

    params = draw(
        tfp_hps.broadcasting_params(batch_shape,
                                    params_event_ndims=event_ndims,
                                    constraint_fn_for=constraints.get))
    hp.note('params: {!r}'.format(params))
    validate = draw(hps.booleans())
    return tfb.RationalQuadraticSpline(range_min=lo,
                                       validate_args=validate,
                                       **params)
コード例 #2
0
def broadcasting_params(draw,
                        bijector_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Draws a dict of parameters which should yield the given batch shape.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      bijector_name: Python `str`; looked up in `BIJECTOR_PARAMS_NDIMS`. Names
        missing from that mapping fall back to an empty event-ndims dict.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments.
    """
    ndims_by_param = BIJECTOR_PARAMS_NDIMS.get(bijector_name, {})
    param_strategy = tfp_hps.broadcasting_params(
        batch_shape,
        ndims_by_param,
        event_dim=event_dim,
        enable_vars=enable_vars,
        # Per-parameter constraint lookup, keyed on this bijector's name.
        constraint_fn_for=lambda param: constraint_for(bijector_name, param),
        mutex_params=MUTEX_PARAMS)
    return draw(param_strategy)
コード例 #3
0
def broadcasting_params(draw,
                        dist_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Strategy for drawing parameters broadcasting to `batch_shape`.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      dist_name: Python `str`; must be a key of `INSTANTIABLE_BASE_DISTS`.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments.

    Raises:
      ValueError: if `dist_name` is not a key of `INSTANTIABLE_BASE_DISTS`.
        The name is validated up front so an unknown name produces a
        descriptive error rather than a bare `KeyError` from the lookup.
    """
    if dist_name not in INSTANTIABLE_BASE_DISTS:
        raise ValueError('Unknown Distribution name {}'.format(dist_name))

    params_event_ndims = INSTANTIABLE_BASE_DISTS[dist_name].params_event_ndims

    def _constraint(param):
        return constraint_for(dist_name, param)

    return draw(
        tfp_hps.broadcasting_params(batch_shape,
                                    params_event_ndims,
                                    event_dim=event_dim,
                                    enable_vars=enable_vars,
                                    constraint_fn_for=_constraint,
                                    mutex_params=MUTEX_PARAMS))
コード例 #4
0
def broadcasting_params(draw,
                        dist_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Draws a dict of parameters which should yield the given batch shape.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      dist_name: Python `str`; must be a key of `INSTANTIABLE_DISTS`. Each
        value there is unpacked as a 2-tuple whose second element is the
        per-parameter event-ndims mapping.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments.

    Raises:
      ValueError: if `dist_name` is not a key of `INSTANTIABLE_DISTS`. The
        name is validated up front so an unknown name produces a descriptive
        error rather than a bare `KeyError` from the lookup.
    """
    if dist_name not in INSTANTIABLE_DISTS:
        raise ValueError('Unknown Distribution name {}'.format(dist_name))

    _, params_event_ndims = INSTANTIABLE_DISTS[dist_name]

    def _constraint(param):
        return constraint_for(dist_name, param)

    return draw(
        tfp_hps.broadcasting_params(batch_shape,
                                    event_dim=event_dim,
                                    enable_vars=enable_vars,
                                    params_event_ndims=params_event_ndims,
                                    constraint_fn_for=_constraint,
                                    mutex_params=MUTEX_PARAMS))
コード例 #5
0
def generalized_paretos(draw, batch_shape=None):
  """Strategy for drawing `tfd.GeneralizedPareto` distributions.

  Args:
    draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
    batch_shape: An optional `TensorShape`. The batch shape of the resulting
      distribution. Hypothesis will pick one if omitted.

  Returns:
    A `tfd.GeneralizedPareto` whose `batch_shape` equals `batch_shape`, with a
    randomly drawn `validate_args` flag.

  Raises:
    AssertionError: if the constructed distribution's batch shape differs from
      the requested one.
  """
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  constraints = dict(
      loc=tfp_hps.identity_fn,
      scale=tfp_hps.softplus_plus_eps(),
      # tanh(x) * 0.24 keeps |concentration| < 0.24; <.25==safe for variance.
      concentration=lambda x: tf.math.tanh(x) * 0.24)

  params = draw(
      tfp_hps.broadcasting_params(
          batch_shape,
          params_event_ndims=dict(loc=0, scale=0, concentration=0),
          constraint_fn_for=constraints.get))
  dist = tfd.GeneralizedPareto(validate_args=draw(hps.booleans()), **params)
  if dist.batch_shape != batch_shape:
    # Report the two shapes side by side; the distribution is kept for context.
    raise AssertionError(
        'batch_shape mismatch: expect {} but got {} (distribution: {})'.format(
            batch_shape, dist.batch_shape, dist))
  return dist
コード例 #6
0
def broadcasting_params(draw,
                        kernel_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Draws a dict of parameters which should yield the given batch shape.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      kernel_name: Python `str`; must be a key of `INSTANTIABLE_BASE_KERNELS`.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments.

    Raises:
      ValueError: if `kernel_name` is not a key of `INSTANTIABLE_BASE_KERNELS`.
    """
    if kernel_name not in INSTANTIABLE_BASE_KERNELS:
        raise ValueError('Unknown Kernel name {}'.format(kernel_name))
    # Membership was checked above, so a direct lookup always succeeds.
    ndims_by_param = INSTANTIABLE_BASE_KERNELS[kernel_name]

    strategy_kwargs = dict(
        event_dim=event_dim,
        enable_vars=enable_vars,
        constraint_fn_for=lambda param: constraint_for(kernel_name, param),
        mutex_params=MUTEX_PARAMS)
    return draw(
        tfp_hps.broadcasting_params(batch_shape, ndims_by_param,
                                    **strategy_kwargs))
コード例 #7
0
def broadcasting_params(draw,
                        dist_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Strategy for drawing parameters broadcasting to `batch_shape`.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      dist_name: Python `str`; must be a key of `INSTANTIABLE_BASE_DISTS`.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments.

    Raises:
      ValueError: if `dist_name` is not a key of `INSTANTIABLE_BASE_DISTS`.
    """
    if dist_name in INSTANTIABLE_BASE_DISTS:
        dist_spec = INSTANTIABLE_BASE_DISTS[dist_name]
    else:
        raise ValueError('Unknown Distribution name {}'.format(dist_name))

    param_strategy = tfp_hps.broadcasting_params(
        batch_shape,
        dist_spec.params_event_ndims,
        event_dim=event_dim,
        enable_vars=enable_vars,
        constraint_fn_for=lambda param: constraint_for(dist_name, param),
        mutex_params=MUTEX_PARAMS)
    return draw(param_strategy)
コード例 #8
0
def broadcasting_params(draw,
                        process_name,
                        batch_shape,
                        event_dim=None,
                        enable_vars=False):
    """Strategy for drawing parameters broadcasting to `batch_shape`.

    Args:
      draw: Hypothesis strategy sampler (presumably from `@hps.composite`).
      process_name: Python `str`; must be a key of
        `PARAM_EVENT_NDIMS_BY_PROCESS_NAME`.
      batch_shape: Target batch shape for the drawn parameters.
      event_dim: Optional Python int forwarded to
        `tfp_hps.broadcasting_params`.
      enable_vars: Forwarded to `tfp_hps.broadcasting_params`.

    Returns:
      Whatever `tfp_hps.broadcasting_params` draws for these arguments, with
      `dtype=np.float64` requested.

    Raises:
      ValueError: if `process_name` is not a key of
        `PARAM_EVENT_NDIMS_BY_PROCESS_NAME`.
    """
    known_processes = PARAM_EVENT_NDIMS_BY_PROCESS_NAME
    if process_name not in known_processes:
        raise ValueError('Unknown Process name {}'.format(process_name))

    def constrain(param):
        return constraint_for(process_name, param)

    param_strategy = tfp_hps.broadcasting_params(
        batch_shape,
        known_processes[process_name],
        event_dim=event_dim,
        enable_vars=enable_vars,
        constraint_fn_for=constrain,
        mutex_params=MUTEX_PARAMS,
        dtype=np.float64)
    return draw(param_strategy)
コード例 #9
0
def bijectors(draw,
              bijector_name=None,
              batch_shape=None,
              event_dim=None,
              enable_vars=False,
              allowed_bijectors=None,
              validate_args=True,
              return_duplicate=False):
    """Strategy for drawing Bijectors.

    The emitted bijector is either a basic bijector or a wrapper (`Invert` or
    `TransformDiagonal`) around a bijector drawn by a recursive call to this
    strategy. Compounds such as `Chain` are not produced here.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      bijector_name: Optional Python `str`.  If given, the produced bijectors
        will all have this type.  If omitted, Hypothesis chooses one from
        `allowed_bijectors`.
      batch_shape: An optional `TensorShape`.  The batch shape of the resulting
        bijector.  Hypothesis will pick one if omitted.  Only the `Invert`
        branch and the generic (final `else`) branch use it; the other
        special-cased branches build their bijectors without it.
      event_dim: Optional Python int giving the size of each of the bijector
        parameters' event dimensions (also the permutation length for
        `Permute`).  This is shared across all parameters, permitting square
        event matrices, compatible location and scale Tensors, etc. If
        omitted, Hypothesis will choose one in [2, 6].
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  If `False`, the returned parameters
        are all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
        `tfp.util.TransformedVariable`}
      allowed_bijectors: Optional list of `str` Bijector names to sample from.
        Bijectors not in this list will not be returned or instantiated as
        part of a meta-bijector (Chain, Invert, etc.). Defaults to
        `bhps.INSTANTIABLE_BIJECTORS`.
      validate_args: Python `bool`; whether to enable runtime checks. Also
        forwarded to the recursive draws for wrapped bijectors.
      return_duplicate: Python `bool`: If `False` return a single bijector. If
        `True` return a tuple of two bijectors of the same type, constructed
        from the same `bijector_params` dict (so parameter objects such as
        Variables or wrapped bijectors are shared, not copied).  Not
        forwarded to recursive draws.

    Returns:
      bijectors: A strategy for drawing bijectors (see `batch_shape` above for
        which branches honour the requested batch shape).
    """
    if allowed_bijectors is None:
        allowed_bijectors = bhps.INSTANTIABLE_BIJECTORS
    if bijector_name is None:
        bijector_name = draw(hps.sampled_from(allowed_bijectors))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if bijector_name == 'Invert':
        # Exclude 'Invert' from the inner choice so the recursion cannot wrap
        # an Invert directly in another Invert.
        underlying_name = draw(
            hps.sampled_from(sorted(set(allowed_bijectors) - {'Invert'})))
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      allowed_bijectors=allowed_bijectors,
                      validate_args=validate_args))
        bijector_params = {'bijector': underlying}
        msg = 'Forming Invert bijector with underlying bijector {}.'
        hp.note(msg.format(underlying))
    elif bijector_name == 'TransformDiagonal':
        # The diagonal bijector must be both allowed and on the
        # TransformDiagonal allowlist.
        # NOTE(review): an empty intersection leaves nothing to sample from --
        # confirm callers never reach that case.
        underlying_name = draw(
            hps.sampled_from(
                sorted(
                    set(allowed_bijectors)
                    & set(bhps.TRANSFORM_DIAGONAL_ALLOWLIST))))
        # The wrapped bijector is drawn with an empty batch shape; the outer
        # `batch_shape` is not applied here.
        underlying = draw(
            bijectors(bijector_name=underlying_name,
                      batch_shape=(),
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      allowed_bijectors=allowed_bijectors,
                      validate_args=validate_args))
        bijector_params = {'diag_bijector': underlying}
        msg = 'Forming TransformDiagonal bijector with underlying bijector {}.'
        hp.note(msg.format(underlying))
    elif bijector_name == 'Inline':
        # Build an Inline bijector out of the pieces of a float32 `tfb.Scale`
        # whose scale is one of {1, -1, 2, -2} (optionally a variable).
        scale = draw(
            tfp_hps.maybe_variable(
                hps.sampled_from(np.float32([1., -1., 2, -2.])), enable_vars))
        b = tfb.Scale(scale=scale)

        bijector_params = dict(
            forward_fn=CallableModule(b.forward, b),
            inverse_fn=b.inverse,
            forward_log_det_jacobian_fn=lambda x: b.forward_log_det_jacobian(  # pylint: disable=g-long-lambda
                x,
                event_ndims=b.forward_min_event_ndims),
            forward_min_event_ndims=b.forward_min_event_ndims,
            is_constant_jacobian=b.is_constant_jacobian,
            is_increasing=b._internal_is_increasing,  # pylint: disable=protected-access
        )
    elif bijector_name == 'DiscreteCosineTransform':
        # Only DCT types 2 and 3 are drawn.
        dct_type = hps.integers(min_value=2, max_value=3)
        bijector_params = {'dct_type': draw(dct_type)}
    elif bijector_name == 'GeneralizedPareto':
        # Scalar Python-float parameters; scale is kept strictly positive.
        concentration = hps.floats(min_value=-200., max_value=200)
        scale = hps.floats(min_value=1e-2, max_value=200)
        loc = hps.floats(min_value=-200, max_value=200)
        bijector_params = {
            'concentration': draw(concentration),
            'scale': draw(scale),
            'loc': draw(loc)
        }
    elif bijector_name == 'PowerTransform':
        # Strictly positive power in [1e-6, 10].
        power = hps.floats(min_value=1e-6, max_value=10.)
        bijector_params = {'power': draw(power)}
    elif bijector_name == 'Permute':
        event_ndims = draw(hps.integers(min_value=1, max_value=2))
        axis = hps.integers(min_value=-event_ndims, max_value=-1)
        # This is a permutation of dimensions within an axis.
        # (Contrast with `Transpose` below.)
        bijector_params = {
            'axis':
            draw(axis),
            'permutation':
            draw(
                tfp_hps.maybe_variable(hps.permutations(np.arange(event_dim)),
                                       enable_vars,
                                       dtype=tf.int32))
        }
    elif bijector_name == 'Reshape':
        event_shape_out = draw(tfp_hps.shapes(min_ndims=1))
        # TODO(b/142135119): Wanted to draw general input and output shapes like the
        # following, but Hypothesis complained about filtering out too many things.
        # event_shape_in = draw(tfp_hps.shapes(min_ndims=1))
        # hp.assume(event_shape_out.num_elements() == event_shape_in.num_elements())
        # Instead, flatten: the input event shape is 1-D with the same element
        # count as `event_shape_out`.
        event_shape_in = [event_shape_out.num_elements()]
        bijector_params = {
            'event_shape_out': event_shape_out,
            'event_shape_in': event_shape_in
        }
    elif bijector_name == 'Transpose':
        event_ndims = draw(hps.integers(min_value=0, max_value=2))
        # This is a permutation of axes.
        # (Contrast with `Permute` above.)
        bijector_params = {
            'perm': draw(hps.permutations(np.arange(event_ndims)))
        }
    else:
        # Generic path: draw parameters that broadcast to `batch_shape`, using
        # the per-parameter event ndims registered for this bijector and the
        # per-parameter constraints from `constraint_for`.
        params_event_ndims = bhps.INSTANTIABLE_BIJECTORS[
            bijector_name].params_event_ndims
        bijector_params = draw(
            tfp_hps.broadcasting_params(
                batch_shape,
                params_event_ndims,
                event_dim=event_dim,
                enable_vars=enable_vars,
                constraint_fn_for=lambda param: constraint_for(
                    bijector_name, param),  # pylint:disable=line-too-long
                mutex_params=MUTEX_PARAMS))
        # NOTE(review): `constrain_params` is a project helper not shown here;
        # presumably it applies bijector-specific fix-ups -- confirm.
        bijector_params = constrain_params(bijector_params, bijector_name)

    ctor = getattr(tfb, bijector_name)
    hp.note('Forming {} bijector with params {}.'.format(
        bijector_name, bijector_params))
    bijector = ctor(validate_args=validate_args, **bijector_params)
    if not return_duplicate:
        return bijector
    # The duplicate reuses the very same `bijector_params` objects.
    return (bijector, ctor(validate_args=validate_args, **bijector_params))
コード例 #10
0
def changepoints(draw,
                 batch_shape=None,
                 event_dim=None,
                 feature_dim=None,
                 feature_ndims=None,
                 enable_vars=None,
                 depth=None):
    """Strategy for drawing `ChangePoint` kernels.

    Between 2 and 4 inner kernels are drawn from the `kernels` strategy (always
    with `enable_vars=False`). The changepoint `locs` and `slopes` (one fewer
    than the number of inner kernels) are then drawn and cast to float64.

    Args:
      draw: Hypothesis strategy sampler supplied by `@hps.composite`.
      batch_shape: An optional `TensorShape`.  The batch shape of the resulting
        Kernel.  Hypothesis will pick a batch shape if omitted.
      event_dim: Optional Python int forwarded to the inner-kernel draws as
        their parameters' event-dimension size. If omitted, Hypothesis will
        choose one in [2, 6].
      feature_dim: Optional Python int giving the size of each feature
        dimension. If omitted, Hypothesis will choose one in [2, 6].
      feature_ndims: Optional Python int stating the number of feature
        dimensions inputs will have. If omitted, Hypothesis will choose one
        in [2, 6].
      enable_vars: TODO(bjp): Make this `True` all the time and put variable
        initialization in slicing_test.  When truthy, Hypothesis may turn
        `locs` and `slopes` into `tf.Variable`s.
      depth: Python `int` giving maximum nesting depth of compound kernel.
        Inner kernels are drawn with `depth - 1`.

    Returns:
      A `(kernel, variable_names)` pair: the `ChangePoint` kernel and the list
      of variable names collected from the inner kernels plus, if created,
      'locs' and 'slopes'.
    """
    depth = draw(depths()) if depth is None else depth
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())

    dim_strategy = hps.integers(min_value=2, max_value=6)
    if event_dim is None:
        event_dim = draw(dim_strategy)
    if feature_dim is None:
        feature_dim = draw(dim_strategy)
    if feature_ndims is None:
        feature_ndims = draw(dim_strategy)

    num_kernels = draw(hps.integers(min_value=2, max_value=4))

    inner_kernels = []
    kernel_variable_names = []
    for _ in range(num_kernels):
        inner, inner_names = draw(
            kernels(batch_shape=batch_shape,
                    event_dim=event_dim,
                    feature_dim=feature_dim,
                    feature_ndims=feature_ndims,
                    enable_vars=False,
                    depth=depth - 1))
        inner_kernels.append(inner)
        kernel_variable_names.extend(inner_names)

    def _increasing_locs(x):
        # |x| + 1e-3 is strictly positive, so the cumulative sum is strictly
        # increasing along the last axis.
        return tf.cumsum(tf.math.abs(x) + 1e-3, axis=-1)

    constraints = {
        'locs': _increasing_locs,
        'slopes': tfp_hps.softplus_plus_eps(),
    }

    # N inner kernels are separated by N - 1 changepoints.
    raw_params = draw(
        tfp_hps.broadcasting_params(batch_shape,
                                    event_dim=num_kernels - 1,
                                    params_event_ndims=dict(locs=1, slopes=1),
                                    constraint_fn_for=constraints.get))
    params = {
        name: tf.cast(value, tf.float64) for name, value in raw_params.items()
    }

    if enable_vars and draw(hps.booleans()):
        for name in ('locs', 'slopes'):
            kernel_variable_names.append(name)
            params[name] = tf.Variable(params[name], name=name)

    changepoint_kernel = tfpk.ChangePoint(kernels=inner_kernels,
                                          validate_args=True,
                                          **params)
    return changepoint_kernel, kernel_variable_names