Example #1
def prefer_static_cond(predicate, true_fn, false_fn, name=None):
    """Identical to `tf.cond` but operates statically if possible."""
    with tf.name_scope(name, 'prefer_static_cond', [predicate]):
        predicate_ = distributions_util.maybe_get_static_value(predicate)
        if predicate_ is None:
            return tf.cond(predicate, true_fn, false_fn)
        return true_fn() if predicate_ else false_fn()
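A minimal usage sketch of the helper above (assuming TF1-style graph mode; `prefer_static_cond` and `distributions_util` are taken from the surrounding example): when the predicate's value is known at graph-construction time, the chosen branch runs as plain Python and no `tf.cond` op is created.

import tensorflow as tf

pred = tf.constant(True)  # value is statically known
result = prefer_static_cond(pred,
                            true_fn=lambda: tf.constant(1.),
                            false_fn=lambda: tf.constant(0.))
# `result` is the output of `true_fn()`; with a runtime-only predicate
# (e.g. a placeholder), the helper would fall back to `tf.cond` instead.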
Example #2
  def _maybe_get_static_event_ndims(self):
    if self.event_shape.ndims is not None:
      return self.event_shape.ndims

    event_ndims = array_ops.size(self.event_shape_tensor())
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if event_ndims_ is not None:
      return event_ndims_

    return event_ndims
Example #3
    def _maybe_get_static_event_ndims(self):
        if self.event_shape.ndims is not None:
            return self.event_shape.ndims

        event_ndims = array_ops.size(self.event_shape_tensor())
        event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

        if event_ndims_ is not None:
            return event_ndims_

        return event_ndims
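For orientation, a rough sketch of the pattern both methods above rely on (assuming TF1 graph mode; `maybe_get_static_value` behaves roughly like `tf.get_static_value` here): a value is returned as a NumPy scalar when it can be inferred at graph-construction time and as `None` otherwise, which is why the methods only fall back to returning a `Tensor` as a last resort.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

known = tf.size(tf.zeros([3, 4]))                    # shape fully defined
unknown = tf.size(tf.placeholder(tf.float32, None))  # shape unknown

print(tf.get_static_value(known))    # -> 12 (NumPy scalar, usable in Python)
print(tf.get_static_value(unknown))  # -> None, so the Tensor must be kept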
Example #4
    def _maybe_get_event_ndims_statically(self, event_ndims):
        """Helper which returns tries to return an integer static value."""
        event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

        if isinstance(event_ndims_, np.ndarray):
            if (event_ndims_.dtype not in (np.int32, np.int64)
                    or len(event_ndims_.shape)):
                raise ValueError(
                    "Expected a scalar integer, got {}".format(event_ndims_))
            event_ndims_ = event_ndims_.tolist()

        return event_ndims_
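A few hypothetical inputs (pure NumPy, illustrative only) showing what the validation above accepts and rejects:

import numpy as np

ok = np.array(2, dtype=np.int32)       # integer dtype, zero-rank: accepted
ok.tolist()                            # -> 2, a plain Python int

bad_rank = np.array([1, 2], np.int32)  # len(shape) == 1        -> ValueError
bad_dtype = np.array(2.0)              # float64, not int32/64  -> ValueError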
Example #5
def prefer_static_reduce_all(preds, name=None):
    """Identical to `tf.reduce_all` but operates statically if possible."""
    with tf.name_scope(name, 'prefer_static_reduce_all', [preds]):
        preds_ = [
            distributions_util.maybe_get_static_value(p, np.bool)
            for p in preds
        ]
        if any(p is False for p in preds_):
            return False
        if any(p is None for p in preds_):
            return tf.reduce_all(preds)
        return all(preds_)
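A minimal usage sketch (assuming TF1-style graph mode, with `prefer_static_reduce_all` and `distributions_util` as defined above): when every predicate's value is known statically, the result is a plain Python bool and no reduction op is added to the graph.

import tensorflow as tf

prefer_static_reduce_all([tf.constant(True), tf.constant(True)])   # -> True
prefer_static_reduce_all([tf.constant(True), tf.constant(False)])  # -> False
# If any predicate is only known at run time (e.g. fed via a placeholder),
# the helper falls back to tf.reduce_all(preds) and returns a Tensor.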
Example #6
  def _maybe_get_static_event_ndims(self, event_ndims):
    """Helper which tries to return an integer static value."""
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if isinstance(event_ndims_, (np.generic, np.ndarray)):
      if event_ndims_.dtype not in (np.int32, np.int64):
        raise ValueError("Expected integer dtype, got dtype {}".format(
            event_ndims_.dtype))

      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
        raise ValueError("Expected a scalar integer, got {}".format(
            event_ndims_))
      event_ndims_ = int(event_ndims_)

    return event_ndims_
Example #7
  def _maybe_get_static_event_ndims(self, event_ndims):
    """Helper which returns tries to return an integer static value."""
    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)

    if isinstance(event_ndims_, (np.generic, np.ndarray)):
      if event_ndims_.dtype not in (np.int32, np.int64):
        raise ValueError("Expected integer dtype, got dtype {}".format(
            event_ndims_.dtype))

      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
        raise ValueError("Expected a scalar integer, got {}".format(
            event_ndims_))
      event_ndims_ = int(event_ndims_)

    return event_ndims_
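Hypothetical values (pure NumPy, illustrative only) walking through the two checks in the method above:

import numpy as np

int(np.int64(3))                  # np.generic with integer dtype -> 3
int(np.array(3, dtype=np.int32))  # zero-rank integer ndarray     -> 3

np.array([3], dtype=np.int32)     # non-scalar ndarray -> would raise ValueError
np.array(3.0)                     # float64 dtype      -> would raise ValueError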
Example #8
def fit_one_step(
    model_matrix,
    response,
    model,
    model_coefficients_start=None,
    predicted_linear_response_start=None,
    l2_regularizer=None,
    dispersion=None,
    offset=None,
    learning_rate=None,
    fast_unsafe_numerics=True,
    name=None):
  """Runs one step of Fisher scoring.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model: `tfp.glm.ExponentialFamily`-like instance used to construct the
      negative log-likelihood loss, gradient, and expected Hessian (i.e., the
      Fisher information matrix).
    model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
      representing the initial model coefficients, one for each column in
      `model_matrix`. Must have same `dtype` as `model_matrix`.
      Default value: Zeros.
    predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype`
      matching `response`; represents `offset` shifted initial linear
      predictions based on `model_coefficients_start`.
      Default value: `offset` if `model_coefficients is None`, and
      `tfp.math.matvecmul(model_matrix, model_coefficients_start) + offset`
      otherwise.
    l2_regularizer: Optional scalar `Tensor` representing L2 regularization
      penalty, i.e.,
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
      Default value: `None` (i.e., no L2 regularization).
    dispersion: Optional (batch of) `Tensor` representing `response` dispersion,
      i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
      Must broadcast with rows of `model_matrix`.
      Default value: `None` (i.e., "no dispersion").
    offset: Optional `Tensor` representing constant shift applied to
      `predicted_linear_response`.  Must broadcast to `response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative
      progress. Typically only needed if optimization diverges; should be no
      larger than `1` and typically very close to `1`.
      Default value: `None` (i.e., `1`).
    fast_unsafe_numerics: Optional Python `bool` indicating if solve should be
      based on Cholesky or QR decomposition.
      Default value: `True` (i.e., "prefer speed via Cholesky decomposition").
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"fit_one_step"`.

  Returns:
    model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
      next estimate of the model coefficients, one for each column in
      `model_matrix`.
    predicted_linear_response: `response`-shaped `Tensor` representing linear
      predictions based on new `model_coefficients`, i.e.,
      `tfp.math.matvecmul(model_matrix, model_coefficients_next) + offset`.
  """
  graph_deps = [model_matrix, response, model_coefficients_start,
                predicted_linear_response_start, dispersion, learning_rate]
  with tf.name_scope(name, 'fit_one_step', graph_deps):

    [
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset,
    ] = prepare_args(
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset)

    # Compute: mean, grad(mean, predicted_linear_response_start), and variance.
    mean, variance, grad_mean = model(predicted_linear_response_start)

    # If either `grad_mean` or `variance` is non-finite or zero, then we'll
    # replace it with a value such that the row is zeroed out. Although this
    # procedure may seem circuitous, it is necessary to ensure this algorithm is
    # itself differentiable.
    is_valid = (tf.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.) &
                tf.is_finite(variance) & (variance > 0.))
    def mask_if_invalid(x, mask):
      mask = tf.fill(tf.shape(x), value=np.array(mask, x.dtype.as_numpy_dtype))
      return tf.where(is_valid, x, mask)

    # Run one step of iteratively reweighted least-squares.
    # Compute "`z`", the adjusted predicted linear response.
    # z = predicted_linear_response_start
    #     + learning_rate * (response - mean) / grad_mean
    z = (response - mean) / mask_if_invalid(grad_mean, 1.)
    # TODO(jvdillon): Rather than use learning rate, we should consider using
    # backtracking line search.
    if learning_rate is not None:
      z *= learning_rate[..., tf.newaxis]
    z += predicted_linear_response_start

    # Compute "`w`", the per-sample weight.
    if dispersion is not None:
      # For convenience, we'll now scale the variance by the dispersion factor.
      variance *= dispersion
    w = (mask_if_invalid(grad_mean, 0.) *
         tf.rsqrt(mask_if_invalid(variance, np.inf)))

    a = model_matrix * w[..., tf.newaxis]
    b = z * w
    # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }`
    # where `@` denotes `matmul`.

    if l2_regularizer is None:
      l2_regularizer = np.array(0, a.dtype.as_numpy_dtype)
    else:
      l2_regularizer_ = distributions_util.maybe_get_static_value(
          l2_regularizer, a.dtype.as_numpy_dtype)
      if l2_regularizer_ is not None:
        l2_regularizer = l2_regularizer_

    def _embed_l2_regularization():
      """Adds synthetic observations to implement L2 regularization."""
      # `tf.matrix_solve_ls` does not respect the `l2_regularization` argument
      # when `fast_unsafe_numerics` is `False`. This function adds synthetic
      # observations to the data to implement the regularization instead.
      # Adding observations `sqrt(l2_regularizer) * I` is mathematically
      # equivalent to adding the term
      # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood.
      num_model_coefficients = num_cols(model_matrix)
      batch_shape = tf.shape(model_matrix)[:-2]
      eye = tf.eye(
          num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype)
      a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2)
      b_ = distributions_util.pad(
          b, count=num_model_coefficients, axis=-1, back=True)
      # Return l2_regularizer=0 since it's now embedded.
      l2_regularizer_ = np.array(0, a.dtype.as_numpy_dtype)
      return a_, b_, l2_regularizer_

    a, b, l2_regularizer = smart_cond.smart_cond(
        smart_reduce_all([not(fast_unsafe_numerics),
                          l2_regularizer > 0.]),
        _embed_l2_regularization,
        lambda: (a, b, l2_regularizer))

    model_coefficients_next = tf.matrix_solve_ls(
        a, b[..., tf.newaxis],
        fast=fast_unsafe_numerics,
        l2_regularizer=l2_regularizer,
        name='model_coefficients_next')
    model_coefficients_next = model_coefficients_next[..., 0]

    # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made
    # faster by avoiding explicitly forming Q and instead keeping the
    # factorization in 'implicit' form with stacked (rescaled) Householder
    # vectors underneath the 'R' and then applying the (accumulated)
    # reflectors in the appropriate order to apply Q'. However, we don't
    # presently do this because we lack core TF functionality. For reference,
    # the vanilla QR approach is:
    #   q, r = tf.linalg.qr(a)
    #   c = tf.matmul(q, b, adjoint_a=True)
    #   model_coefficients_next = tf.matrix_triangular_solve(
    #       r, c, lower=False, name='model_coefficients_next')

    predicted_linear_response_next = calculate_linear_predictor(
        model_matrix,
        model_coefficients_next,
        offset,
        name='predicted_linear_response_next')

    return model_coefficients_next, predicted_linear_response_next
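A rough end-to-end sketch of how `fit_one_step` might be driven in a Fisher-scoring loop (assumptions: TF1-style graph mode, a `tfp.glm.Bernoulli` model from TensorFlow Probability, and synthetic data; the names `x`, `y`, and `coeffs` are illustrative only, and this is a sketch rather than the library's recommended entry point):

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tf.disable_eager_execution()

# Synthetic logistic-regression data (illustrative only).
x = tf.constant(np.random.randn(100, 3), dtype=tf.float32)  # model_matrix
true_coeffs = tf.constant([1., -2., 0.5])
y = tf.cast(tf.random.uniform([100]) <
            tf.sigmoid(tf.linalg.matvec(x, true_coeffs)), tf.float32)

coeffs, lin_resp = None, None
for _ in range(5):  # a few Fisher-scoring / IRLS iterations
  coeffs, lin_resp = fit_one_step(
      model_matrix=x,
      response=y,
      model=tfp.glm.Bernoulli(),
      model_coefficients_start=coeffs,
      predicted_linear_response_start=lin_resp)

with tf.Session() as sess:
  print(sess.run(coeffs))  # estimated coefficients after five updates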