Exemplo n.º 1
0
def _validate_elem_length(max_num_levels, elems_flat):
  """Checks that elems all have the same length, and returns that length.

  Two conditions are verified. Each one is checked eagerly when
  `tf.get_static_value` can resolve it, and is otherwise deferred to an
  assertion op that is returned to the caller:

    1. The first-axis size of `elems_flat[0]` is strictly less than
       `2**(max_num_levels + 1)`.
    2. Every other element has the same first-axis size as `elems_flat[0]`.

  Args:
    max_num_levels: Integer-like value; the first-axis size must be less than
      `2**(max_num_levels + 1)`.
    elems_flat: Non-empty sequence of elements exposing a `.shape` attribute
      (the first element is indexed unconditionally). Only the size of the
      first axis of each element is compared.

  Returns:
    elem_length: First-axis size of `elems_flat[0]`, as produced by
      `prefer_static.shape`.
    assertions: Python `list` of assertion ops for the checks that could not be
      resolved statically; empty when both checks were resolved.

  Raises:
    ValueError: If either condition is statically known to be violated.
  """
  assertions = []

  elem_length = prefer_static.shape(elems_flat[0])[0]

  # The default size limit will overflow a 32-bit int, so make sure we're
  # using 64-bit.
  size_limit = 2**(prefer_static.cast(max_num_levels, np.int64) + 1)
  enough_levels = prefer_static.less(
      prefer_static.cast(elem_length, np.int64), size_limit)

  def too_long_message():
    # `size_limit` is `2**(max_num_levels + 1)`, so the exponent shown to the
    # user must include the `+ 1` for the printed equality to hold.
    return ('Input `Tensor`s must have first axis dimension less than'
            ' `2**(max_num_levels + 1)`'
            ' (saw: {} which is not less than 2**({} + 1) == {})'.format(
                elem_length,
                max_num_levels,
                size_limit))

  enough_levels_ = tf.get_static_value(enough_levels)
  if enough_levels_ is None:
    # Not statically known: defer the check to run time.
    assertions.append(
        tf.debugging.assert_equal(
            enough_levels, True, message=too_long_message()))
  elif not enough_levels_:
    raise ValueError(too_long_message())

  is_consistent = prefer_static.reduce_all([
      prefer_static.equal(
          prefer_static.shape(elem)[0], elem_length)
      for elem in elems_flat[1:]])

  def inconsistent_message():
    return ('Input `Tensor`s must have the same first dimension.'
            ' (saw: {})'.format([elem.shape for elem in elems_flat]))

  is_consistent_ = tf.get_static_value(is_consistent)
  if is_consistent_ is None:
    # Not statically known: defer the check to run time.
    assertions.append(
        tf.debugging.assert_equal(
            is_consistent, True, message=inconsistent_message()))
  elif not is_consistent_:
    raise ValueError(inconsistent_message())
  return elem_length, assertions
 def _needs_rotation(self, override_event_shape, override_batch_shape,
                     base_is_scalar_batch):
     """Reports whether newly-added event dims must be cyclically rotated.

     A scalar distribution is turned into a multivariate one by permuting
     dims taken from the (otherwise iid) sample dims. That is straightforward
     unless the base distribution has a nonscalar batch while only the event
     shape is being overridden: the event dims then land to the left of the
     batch dims and have to be cyclically permuted left (see
     `_maybe_rotate_dims`).

     Args:
       override_event_shape: Event-shape override; tested with
         `self._has_nonzero_rank`.
       override_batch_shape: Batch-shape override; tested with
         `self._has_nonzero_rank`.
       base_is_scalar_batch: Boolean-like flag saying whether the base
         distribution has a scalar batch.

     Returns:
       The conjunction (via `prefer_static.reduce_all`) of: the event shape is
       overridden, the batch shape is not overridden, and the base
       distribution is not scalar-batch. True means rotation is required.
     """
     event_overridden = self._has_nonzero_rank(override_event_shape)
     batch_not_overridden = prefer_static.logical_not(
         self._has_nonzero_rank(override_batch_shape))
     base_has_batch = prefer_static.logical_not(base_is_scalar_batch)
     conditions = [event_overridden, batch_not_overridden, base_has_batch]
     return prefer_static.reduce_all(conditions)
def fit_one_step(
    model_matrix,
    response,
    model,
    model_coefficients_start=None,
    predicted_linear_response_start=None,
    l2_regularizer=None,
    dispersion=None,
    offset=None,
    learning_rate=None,
    fast_unsafe_numerics=True,
    l2_regularization_penalty_factor=None,
    name=None):
  """Runs one step of Fisher scoring.

  The step is implemented as one iteration of iteratively reweighted
  least-squares: the adjusted response `z` and per-sample weights `w` are
  computed from the model's mean, variance and mean-gradient, and the weighted
  least-squares problem is solved with `tf.linalg.lstsq`.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model: `tfp.glm.ExponentialFamily`-like instance used to construct the
      negative log-likelihood loss, gradient, and expected Hessian (i.e., the
      Fisher information matrix). It is called on the linear response and must
      return the triple `(mean, variance, grad_mean)`.
    model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
      representing the initial model coefficients, one for each column in
      `model_matrix`. Must have same `dtype` as `model_matrix`.
      Default value: Zeros.
    predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype`
      matching `response`; represents `offset` shifted initial linear
      predictions based on `model_coefficients_start`.
      Default value: `offset` if `model_coefficients_start is None`, and
      `tf.linalg.matvec(model_matrix, model_coefficients_start) + offset`
      otherwise.
    l2_regularizer: Optional scalar `Tensor` representing L2 regularization
      penalty, i.e.,
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
      Default value: `None` (i.e., no L2 regularization).
    dispersion: Optional (batch of) `Tensor` representing `response` dispersion,
      i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
      Must broadcast with rows of `model_matrix`.
      Default value: `None` (i.e., "no dispersion").
    offset: Optional `Tensor` representing constant shift applied to
      `predicted_linear_response`.  Must broadcast to `response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative
      progress. Typically only needed if optimization diverges, should be no
      larger than `1` and typically very close to `1`. It is indexed with
      `[..., tf.newaxis]`, so it must support that indexing (e.g., a `Tensor`
      or `np.ndarray`).
      Default value: `None` (i.e., `1`).
    fast_unsafe_numerics: Optional Python `bool` indicating if solve should be
      based on Cholesky or QR decomposition.
      Default value: `True` (i.e., "prefer speed via Cholesky decomposition").
    l2_regularization_penalty_factor: Optional (batch of) vector-shaped
      `Tensor`, representing a separate penalty factor to apply to each model
      coefficient, length equal to columns in `model_matrix`. Each penalty
      factor multiplies l2_regularizer to allow differential regularization. Can
      be 0 for some coefficients, which implies no regularization. Default is 1
      for all coefficients.
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w *
        l2_regularization_penalty_factor||_2^2`
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"fit_one_step"`.

  Returns:
    model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
      next estimate of the model coefficients, one for each column in
      `model_matrix`.
    predicted_linear_response: `response`-shaped `Tensor` representing linear
      predictions based on new `model_coefficients`, i.e.,
      `tf.linalg.matvec(model_matrix, model_coefficients_next) + offset`.
  """
  with tf.name_scope(name or 'fit_one_step'):

    [
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset,
    ] = prepare_args(
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset)

    # Compute: mean, grad(mean, predicted_linear_response_start), and variance.
    mean, variance, grad_mean = model(predicted_linear_response_start)

    # If either `grad_mean` or `variance` is non-finite or zero, then we'll
    # replace it with a value such that the row is zeroed out. Although this
    # procedure may seem circuitous, it is necessary to ensure this algorithm is
    # itself differentiable.
    is_valid = (
        tf.math.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.)
        & tf.math.is_finite(variance) & (variance > 0.))

    def mask_if_invalid(x, mask):
      """Returns `x` where the row is valid and the constant `mask` elsewhere."""
      return tf.where(
          is_valid, x, np.array(mask, dtype_util.as_numpy_dtype(x.dtype)))

    # Run one step of iteratively reweighted least-squares.
    # Compute "`z`", the adjusted predicted linear response.
    # z = predicted_linear_response_start
    #     + learning_rate * (response - mean) / grad_mean
    #     - offset
    z = (response - mean) / mask_if_invalid(grad_mean, 1.)
    # TODO(jvdillon): Rather than use learning rate, we should consider using
    # backtracking line search.
    if learning_rate is not None:
      z *= learning_rate[..., tf.newaxis]
    z += predicted_linear_response_start
    if offset is not None:
      # The least-squares target excludes the offset; it is re-applied when the
      # next linear response is computed below.
      z -= offset

    # Compute "`w`", the per-sample weight.
    if dispersion is not None:
      # For convenience, we'll now scale the variance by the dispersion factor.
      variance *= dispersion
    # Invalid rows get `0 * rsqrt(inf) == 0`, i.e., zero weight.
    w = (
        mask_if_invalid(grad_mean, 0.) *
        tf.math.rsqrt(mask_if_invalid(variance, np.inf)))

    a = model_matrix * w[..., tf.newaxis]
    b = z * w
    # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }`
    # where `@` denotes `matmul`.

    if l2_regularizer is None:
      l2_regularizer = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
    else:
      # Prefer the static value when available so the `cond` below can be
      # resolved in Python.
      l2_regularizer_ = distribution_util.maybe_get_static_value(
          l2_regularizer, dtype_util.as_numpy_dtype(a.dtype))
      if l2_regularizer_ is not None:
        l2_regularizer = l2_regularizer_

    def _embed_l2_regularization():
      """Adds synthetic observations to implement L2 regularization."""
      # `tf.matrix_solve_ls` does not respect the `l2_regularization` argument
      # when `fast_unsafe_numerics` is `False`. This function  adds synthetic
      # observations to the data to implement the regularization instead.
      # Adding observations `sqrt(l2_regularizer) * I` is mathematically
      # equivalent to adding the term
      # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood.
      num_model_coefficients = num_cols(model_matrix)
      batch_shape = tf.shape(model_matrix)[:-2]
      if l2_regularization_penalty_factor is None:
        eye = tf.eye(
            num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype)
      else:
        # `tf.linalg.diag` places the innermost axis on the diagonal and keeps
        # leading axes as batch dims, so a (batch of) penalty vector(s) becomes
        # a (batch of) diagonal matri(x|ces). `tf.linalg.tensor_diag` would
        # instead produce a `[B, n, B, n]` tensor for a `[B, n]` input, which
        # cannot be broadcast below. For a plain vector both agree.
        eye = tf.linalg.diag(
            tf.cast(l2_regularization_penalty_factor, dtype=a.dtype))
        broadcasted_shape = prefer_static.concat(
            [batch_shape, [num_model_coefficients, num_model_coefficients]],
            axis=0)
        eye = tf.broadcast_to(eye, broadcasted_shape)
      a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2)
      b_ = distribution_util.pad(
          b, count=num_model_coefficients, axis=-1, back=True)
      # Return l2_regularizer=0 since its now embedded.
      l2_regularizer_ = np.array(0, dtype_util.as_numpy_dtype(a.dtype))
      return a_, b_, l2_regularizer_

    # Embed the penalty as extra rows only when there is a positive penalty and
    # either the QR path is used or per-coefficient factors were requested.
    a, b, l2_regularizer = prefer_static.cond(
        prefer_static.reduce_all([
            prefer_static.logical_or(
                not fast_unsafe_numerics,
                l2_regularization_penalty_factor is not None),
            l2_regularizer > 0.
        ]),
        _embed_l2_regularization,
        lambda: (a, b, l2_regularizer))

    model_coefficients_next = tf.linalg.lstsq(
        a,
        b[..., tf.newaxis],
        fast=fast_unsafe_numerics,
        l2_regularizer=l2_regularizer,
        name='model_coefficients_next')
    # Drop the trailing singleton axis added for the solve.
    model_coefficients_next = model_coefficients_next[..., 0]

    # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made
    # faster by avoiding explicitly forming Q and instead keeping the
    # factorization in 'implicit' form with stacked (rescaled) Householder
    # vectors underneath the 'R' and then applying the (accumulated)
    # reflectors in the appropriate order to apply Q'. However, we don't
    # presently do this because we lack core TF functionality. For reference,
    # the vanilla QR approach is:
    #   q, r = tf.linalg.qr(a)
    #   c = tf.matmul(q, b, adjoint_a=True)
    #   model_coefficients_next = tf.matrix_triangular_solve(
    #       r, c, lower=False, name='model_coefficients_next')

    predicted_linear_response_next = compute_predicted_linear_response(
        model_matrix,
        model_coefficients_next,
        offset,
        name='predicted_linear_response_next')

    return model_coefficients_next, predicted_linear_response_next
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Builds a distribution obtained by transforming a base distribution.

        Args:
          distribution: Base distribution instance to be transformed; typically
            an instance of `Distribution`.
          bijector: Object that computes the transformation; typically an
            instance of `Bijector`.
          batch_shape: `integer` vector `Tensor` overriding the `batch_shape`
            of `distribution`; valid only if `distribution.is_scalar_batch()`.
          event_shape: `integer` vector `Tensor` overriding the `event_shape`
            of `distribution`; valid only if `distribution.is_scalar_event()`.
          kwargs_split_fn: Python `callable` taking a kwargs `dict` and
            returning a tuple of kwargs `dict`s, one for the `distribution`
            and one for the `bijector`. `None` falls back to the default.
            Default value: `_default_kwargs_split_fn` (i.e.,
                `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                                 kwargs.get('bijector_kwargs', {}))`)
          validate_args: Python `bool`, default `False`. When `True`,
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False`, invalid inputs may
            silently render incorrect outputs.
          parameters: Locals dict captured by a subclass constructor, used for
            copy/slice re-instantiation. When `None`, this constructor's own
            arguments are captured instead.
          name: Python `str` name prefixed to Ops created by this class.
            Default: `bijector.name + distribution.name`.
        """
        # Capture the arguments before any other local name is bound, so the
        # snapshot contains constructor arguments only.
        if parameters is None:
            parameters = dict(locals())

        if not name:
            bijector_name = "" if bijector is None else bijector.name
            name = bijector_name + (distribution.name or "")

        with tf.name_scope(name) as name:
            split_fn = kwargs_split_fn
            if split_fn is None:
                split_fn = _default_kwargs_split_fn
            self._kwargs_split_fn = split_fn

            # Handy constants reused below.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # Both a (possibly dynamic) and a purely static notion of
            # "is overridden" are tracked for batch and event shapes, so that
            # as much as possible can be decided - and Python exceptions
            # raised - before graph execution.

            # --- Batch shape override.
            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            batch_override_rank = prefer_static.rank_from_shape(
                self._override_batch_shape)
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(batch_override_rank, self._zero))
            static_batch_override = tf.get_static_value(
                self._override_batch_shape)
            # Unknown statically counts as "maybe overridden".
            self._is_maybe_batch_override = bool(
                static_batch_override is None
                or static_batch_override.size != 0)

            # --- Event shape override.
            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            event_override_rank = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(event_override_rank, self._zero))
            static_event_override = tf.get_static_value(
                self._override_event_shape)
            # Unknown statically counts as "maybe overridden".
            self._is_maybe_event_override = bool(
                static_event_override is None
                or static_event_override.size != 0)

            # A scalar distribution becomes multivariate by drawing dims from
            # the (otherwise iid) sample dims. That is simple unless the base
            # distribution has batch dims while the event shape is overridden:
            # the event dims then sit to the left of the batch dims, and the
            # new dims are cyclically permuted left to compensate.
            batch_not_overridden = prefer_static.logical_not(
                self._is_batch_override)
            base_has_batch = prefer_static.logical_not(
                distribution.is_scalar_batch())
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                batch_not_overridden,
                base_has_batch,
            ])

            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)

            # The head dims are the ones reduced (if any); this range is empty
            # when no reduction is needed.
            reduce_start = self._rotate_ndims - override_event_ndims
            self._reduce_event_indices = tf.range(
                reduce_start, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector

        # TransformedDistribution is allowed to reach into `_graph_parents`
        # because this class behaves more like a base class than a derived one.
        graph_parents = (
            distribution._graph_parents +  # pylint: disable=protected-access
            bijector.graph_parents)
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            graph_parents=graph_parents,
            name=name)
Exemplo n.º 5
0
  def __init__(self,
               bijectors,
               block_sizes=None,
               validate_args=False,
               maybe_changes_size=True,
               name=None):
    """Creates the bijector.

    The bijector is assembled as a chain of three parts: a split of the input
    vector into blocks, a joint map applying one bijector per block, and an
    inverted split that concatenates the transformed blocks.

    Args:
      bijectors: A non-empty list of bijectors.
      block_sizes: A 1-D integer `Tensor` with each element signifying the
        length of the block of the input vector to pass to the corresponding
        bijector. The length of `block_sizes` must be equal to the length of
        `bijectors`. If left as None, a vector of 1's is used.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      maybe_changes_size: Python `bool` indicating that this bijector might
        change the event size. If this is known to be false and set
        appropriately, then this will lead to improved static shape inference
        when the block sizes are not statically known.
      name: Python `str`, name given to ops managed by this object. Default:
        E.g., `Blockwise([Exp(), Softplus()]).name ==
        'blockwise_of_exp_and_softplus'`.

    Raises:
      ValueError: If `bijectors` list is empty.
      ValueError: If any bijector is multi-part (nested
        `forward_min_event_ndims` or `inverse_min_event_ndims`).
      ValueError: If any bijector with an `int` `forward_min_event_ndims`
        changes rank or has `forward_min_event_ndims > 1`.
      ValueError: If size of `block_sizes` does not equal to the length of
        bijectors or is not a vector (presumably raised by
        `_validate_block_sizes`; confirm against its definition).
    """
    # Must stay the first statement so that only constructor arguments are
    # captured.
    parameters = dict(locals())

    # Enforce the documented contract up front rather than failing later with
    # an unrelated error from the split/joint-map construction.
    if not bijectors:
      raise ValueError('`bijectors` must not be empty.')

    if not name:
      name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
      # Component names may carry scope separators, which are dropped here.
      name = name.replace('/', '')

    with tf.name_scope(name) as name:
      for b in bijectors:
        if (nest.is_nested(b.forward_min_event_ndims)
            or nest.is_nested(b.inverse_min_event_ndims)):
          raise ValueError('Bijectors must all be single-part.')
        elif isinstance(b.forward_min_event_ndims, int):
          if b.forward_min_event_ndims != b.inverse_min_event_ndims:
            raise ValueError('Rank-changing bijectors are not supported.')
          elif b.forward_min_event_ndims > 1:
            raise ValueError('Only scalar and vector event-shape '
                             'bijectors are supported at this time.')

      b_joint = joint_map.JointMap(list(bijectors), name='jointmap')

      block_sizes = (
          np.ones(len(bijectors), dtype=np.int32)
          if block_sizes is None else
          _validate_block_sizes(block_sizes, bijectors, validate_args))
      b_split = split.Split(
          block_sizes, name='split', validate_args=validate_args)

      if maybe_changes_size:
        # Push the input block sizes through the joint map to obtain the
        # output block sizes.
        i_block_sizes = _validate_block_sizes(
            ps.concat(b_joint.forward_event_shape_tensor(
                ps.split(block_sizes, len(bijectors))), axis=0),
            bijectors, validate_args)
        # Only a statically-known `True` comparison proves the size is
        # unchanged; an unknown (`None`) or `False` result keeps the flag set.
        maybe_changes_size = not tf.get_static_value(
            ps.reduce_all(block_sizes == i_block_sizes))
      # `i_block_sizes` is only referenced when `maybe_changes_size` is still
      # true, in which case it was defined above.
      b_concat = invert.Invert(
          (split.Split(i_block_sizes, name='isplit')
           if maybe_changes_size else b_split),
          name='concat')

      self._maybe_changes_size = maybe_changes_size
      super(Blockwise, self).__init__(
          bijectors=[b_concat, b_joint, b_split],
          validate_args=validate_args,
          parameters=parameters,
          name=name)