def prefer_static_cond(predicate, true_fn, false_fn, name=None):
  """Evaluates a conditional in Python when the predicate is statically known.

  The predicate is passed to `distributions_util.maybe_get_static_value`. When
  that yields a value (anything other than `None`), exactly one of the two
  callables is invoked directly and its result returned. Otherwise the choice
  is deferred to `tf.cond`.

  Args:
    predicate: Condition selecting which callable to run.
    true_fn: Zero-argument callable used when the predicate is truthy.
    false_fn: Zero-argument callable used when the predicate is falsy.
    name: Optional name scope; defaults to `'prefer_static_cond'`.

  Returns:
    The result of `true_fn()` or `false_fn()` when the predicate is static,
    otherwise the result of `tf.cond(predicate, true_fn, false_fn)`.
  """
  with tf.name_scope(name, 'prefer_static_cond', [predicate]):
    static_predicate = distributions_util.maybe_get_static_value(predicate)
    if static_predicate is not None:
      # Statically resolvable: only the selected branch is ever called.
      chosen_fn = true_fn if static_predicate else false_fn
      return chosen_fn()
    return tf.cond(predicate, true_fn, false_fn)
def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) if event_ndims_ is not None: return event_ndims_ return event_ndims
def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) if event_ndims_ is not None: return event_ndims_ return event_ndims
def _maybe_get_event_ndims_statically(self, event_ndims):
  """Tries to resolve `event_ndims` to a plain Python integer.

  The argument is passed through `distribution_util.maybe_get_static_value`.
  If the result is an `np.ndarray` it must be a zero-rank array of dtype
  `int32` or `int64`, and is then converted with `tolist()`. Any other result
  is returned unchanged.

  Raises:
    ValueError: If the static value is an `np.ndarray` that is not a scalar
      or is not of dtype `int32`/`int64`.
  """
  static_value = distribution_util.maybe_get_static_value(event_ndims)
  if not isinstance(static_value, np.ndarray):
    return static_value

  has_integer_dtype = static_value.dtype in (np.int32, np.int64)
  is_scalar = not len(static_value.shape)
  if not (has_integer_dtype and is_scalar):
    raise ValueError(
        "Expected a scalar integer, got {}".format(static_value))
  return static_value.tolist()
def prefer_static_reduce_all(preds, name=None):
  """Identical to `tf.reduce_all` but operates statically if possible.

  Each predicate is passed through `distributions_util.maybe_get_static_value`
  with a boolean dtype. The result is then decided as follows:

    * If any predicate is statically known and not entirely true, the answer
      is `False` regardless of the other predicates.
    * Otherwise, if any predicate has no static value, the decision is
      deferred to `tf.reduce_all(preds)`.
    * Otherwise every predicate is statically true and the answer is `True`.

  Args:
    preds: Sequence of predicates (Python bools, NumPy values or `Tensor`s).
    name: Optional name scope; defaults to `'prefer_static_reduce_all'`.

  Returns:
    A Python `bool` when the result can be decided statically, otherwise the
    result of `tf.reduce_all(preds)`.
  """
  with tf.name_scope(name, 'prefer_static_reduce_all', [preds]):
    # `np.bool_` rather than the `np.bool` alias: the alias was deprecated in
    # NumPy 1.20 and removed in NumPy 1.24.
    preds_ = [
        distributions_util.maybe_get_static_value(p, np.bool_) for p in preds
    ]
    # Compare by value, not identity: a NumPy boolean is never `is False`, so
    # an identity test would skip this short circuit for NumPy-valued results.
    # `np.all` also covers non-scalar static values, matching the reduction
    # that `tf.reduce_all` performs.
    if any(p is not None and not np.all(p) for p in preds_):
      return False
    if any(p is None for p in preds_):
      return tf.reduce_all(preds)
    # Every predicate is statically known and none of them is false.
    return True
def _maybe_get_static_event_ndims(self, event_ndims):
  """Tries to resolve `event_ndims` to a plain Python `int`.

  The argument is passed through `distribution_util.maybe_get_static_value`.
  A NumPy scalar or array result is validated (integer dtype first, then
  scalar shape) and converted with `int()`. Any other result is returned
  unchanged.

  Raises:
    ValueError: If the static value is a NumPy value whose dtype is neither
      `int32` nor `int64`, or an `np.ndarray` with one or more dimensions.
  """
  static_value = distribution_util.maybe_get_static_value(event_ndims)
  if not isinstance(static_value, (np.generic, np.ndarray)):
    return static_value

  if static_value.dtype not in (np.int32, np.int64):
    raise ValueError("Expected integer dtype, got dtype {}".format(
        static_value.dtype))
  if isinstance(static_value, np.ndarray) and len(static_value.shape):
    raise ValueError("Expected a scalar integer, got {}".format(
        static_value))
  return int(static_value)
def _maybe_get_static_event_ndims(self, event_ndims):
  """Returns `event_ndims` as a Python `int` when it is statically known.

  `distribution_util.maybe_get_static_value` is consulted first. If it yields
  a NumPy scalar or array, the dtype is checked, then the shape, and the value
  is converted with `int()`. Non-NumPy results are returned as they are.

  Raises:
    ValueError: If the NumPy value has a dtype other than `int32`/`int64`, or
      is an `np.ndarray` with a non-empty shape.
  """
  candidate = distribution_util.maybe_get_static_value(event_ndims)
  if isinstance(candidate, (np.generic, np.ndarray)):
    candidate_dtype = candidate.dtype
    if candidate_dtype not in (np.int32, np.int64):
      raise ValueError("Expected integer dtype, got dtype {}".format(
          candidate_dtype))
    has_dimensions = (
        isinstance(candidate, np.ndarray) and len(candidate.shape))
    if has_dimensions:
      raise ValueError("Expected a scalar integer, got {}".format(
          candidate))
    candidate = int(candidate)
  return candidate
def fit_one_step(
    model_matrix,
    response,
    model,
    model_coefficients_start=None,
    predicted_linear_response_start=None,
    l2_regularizer=None,
    dispersion=None,
    offset=None,
    learning_rate=None,
    fast_unsafe_numerics=True,
    name=None):
  """Runs one step of Fisher scoring.

  The step is carried out as one iteration of iteratively reweighted
  least-squares: an adjusted response `z` and per-sample weights `w` are formed
  from the model's mean, variance and mean-gradient, and the weighted
  least-squares problem is solved with `tf.matrix_solve_ls`.

  Args:
    model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row
      represents a sample's features.
    response: (Batch of) vector-shaped `Tensor` where each element represents a
      sample's observed response (to the corresponding row of features). Must
      have same `dtype` as `model_matrix`.
    model: `tfp.glm.ExponentialFamily`-like instance used to construct the
      negative log-likelihood loss, gradient, and expected Hessian (i.e., the
      Fisher information matrix). It is called with the predicted linear
      response and must return the triple `(mean, variance, grad_mean)`.
    model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
      representing the initial model coefficients, one for each column in
      `model_matrix`. Must have same `dtype` as `model_matrix`.
      Default value: Zeros.
    predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype`
      matching `response`; represents `offset` shifted initial linear
      predictions based on `model_coefficients_start`.
      Default value: `offset` if `model_coefficients_start is None`, and
      `tfp.math.matvecmul(model_matrix, model_coefficients_start) + offset`
      otherwise.
    l2_regularizer: Optional scalar `Tensor` representing L2 regularization
      penalty, i.e.,
      `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
      Default value: `None` (i.e., no L2 regularization).
    dispersion: Optional (batch of) `Tensor` representing `response` dispersion,
      i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
      Must broadcast with rows of `model_matrix`.
      Default value: `None` (i.e., "no dispersion").
    offset: Optional `Tensor` representing constant shift applied to
      `predicted_linear_response`.  Must broadcast to `response`.
      Default value: `None` (i.e., `tf.zeros_like(response)`).
    learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative
      progress. Typically only needed if optimization diverges, should be no
      larger than `1` and typically very close to `1`.
      Default value: `None` (i.e., `1`).
    fast_unsafe_numerics: Optional Python `bool` indicating if solve should be
      based on Cholesky or QR decomposition.
      Default value: `True` (i.e., "prefer speed via Cholesky decomposition").
    name: Python `str` used as name prefix to ops created by this function.
      Default value: `"fit_one_step"`.

  Returns:
    model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
      next estimate of the model coefficients, one for each column in
      `model_matrix`.
    predicted_linear_response: `response`-shaped `Tensor` representing linear
      predictions based on new `model_coefficients`, i.e.,
      `tfp.math.matvecmul(model_matrix, model_coefficients_next) + offset`.
  """
  graph_deps = [model_matrix, response, model_coefficients_start,
                predicted_linear_response_start, dispersion, learning_rate]
  with tf.name_scope(name, 'fit_one_step', graph_deps):

    # Normalize the inputs and fill in defaults for the optional ones.
    # NOTE(review): `prepare_args` is defined elsewhere; the defaults described
    # in the docstring are presumably applied there -- confirm.
    [
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset,
    ] = prepare_args(
        model_matrix,
        response,
        model_coefficients_start,
        predicted_linear_response_start,
        offset)

    # Compute: mean, grad(mean, predicted_linear_response_start), and variance.
    mean, variance, grad_mean = model(predicted_linear_response_start)

    # If either `grad_mean` or `variance` is non-finite or zero, then we'll
    # replace it with a value such that the row is zeroed out. Although this
    # procedure may seem circuitous, it is necessary to ensure this algorithm is
    # itself differentiable.
    is_valid = (tf.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.) &
                tf.is_finite(variance) & (variance > 0.))

    # Returns `x` where `is_valid` holds and the constant `mask` (cast to the
    # dtype of `x`) everywhere else.
    def mask_if_invalid(x, mask):
      mask = tf.fill(tf.shape(x), value=np.array(mask, x.dtype.as_numpy_dtype))
      return tf.where(is_valid, x, mask)

    # Run one step of iteratively reweighted least-squares.
    # Compute "`z`", the adjusted predicted linear response.
    # z = predicted_linear_response_start
    #     + learning_rate * (response - mean) / grad_mean
    # Invalid entries of `grad_mean` are replaced by 1 so the division stays
    # finite; those rows get zero weight through `w` below.
    z = (response - mean) / mask_if_invalid(grad_mean, 1.)
    # TODO(jvdillon): Rather than use learning rate, we should consider using
    # backtracking line search.
    if learning_rate is not None:
      z *= learning_rate[..., tf.newaxis]
    z += predicted_linear_response_start

    # Compute "`w`", the per-sample weight.
    if dispersion is not None:
      # For convenience, we'll now scale the variance by the dispersion factor.
      variance *= dispersion
    # Invalid rows get weight 0: the masked `grad_mean` is 0 and the masked
    # variance is `inf`, whose reciprocal square root is 0.
    w = (mask_if_invalid(grad_mean, 0.) *
         tf.rsqrt(mask_if_invalid(variance, np.inf)))

    # Scale each row of the model matrix and each entry of `z` by its weight.
    a = model_matrix * w[..., tf.newaxis]
    b = z * w
    # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }`
    # where `@` denotes `matmul`.

    # Prefer a static (NumPy) regularizer when one is available so that the
    # `l2_regularizer > 0.` test below can be evaluated in Python.
    if l2_regularizer is None:
      l2_regularizer = np.array(0, a.dtype.as_numpy_dtype)
    else:
      l2_regularizer_ = distributions_util.maybe_get_static_value(
          l2_regularizer, a.dtype.as_numpy_dtype)
      if l2_regularizer_ is not None:
        l2_regularizer = l2_regularizer_

    def _embed_l2_regularization():
      """Adds synthetic observations to implement L2 regularization."""
      # `tf.matrix_solve_ls` does not respect the `l2_regularizer` argument
      # when `fast_unsafe_numerics` is `False`. This function adds synthetic
      # observations to the data to implement the regularization instead.
      # Adding observations `sqrt(l2_regularizer) * I` is mathematically
      # equivalent to adding the term
      # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood.
      num_model_coefficients = num_cols(model_matrix)
      batch_shape = tf.shape(model_matrix)[:-2]
      eye = tf.eye(
          num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype)
      # Append the scaled identity as extra rows of `a` and extend `b` by the
      # same number of trailing entries.
      # NOTE(review): `distributions_util.pad` presumably pads with zeros --
      # confirm against its definition.
      a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2)
      b_ = distributions_util.pad(
          b, count=num_model_coefficients, axis=-1, back=True)
      # Return l2_regularizer=0 since it's now embedded.
      l2_regularizer_ = np.array(0, a.dtype.as_numpy_dtype)
      return a_, b_, l2_regularizer_

    # Embed the regularization only when the QR path is requested and the
    # regularizer is positive; otherwise `a`, `b` and `l2_regularizer` pass
    # through unchanged. Both callables read the values bound before this
    # assignment rebinds the names.
    a, b, l2_regularizer = smart_cond.smart_cond(
        smart_reduce_all([not(fast_unsafe_numerics),
                          l2_regularizer > 0.]),
        _embed_l2_regularization,
        lambda: (a, b, l2_regularizer))

    # `b` is given a trailing axis to form a single right-hand-side column;
    # that axis is removed again from the solution.
    model_coefficients_next = tf.matrix_solve_ls(
        a, b[..., tf.newaxis],
        fast=fast_unsafe_numerics,
        l2_regularizer=l2_regularizer,
        name='model_coefficients_next')
    model_coefficients_next = model_coefficients_next[..., 0]

    # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made
    # faster by avoiding explicitly forming Q and instead keeping the
    # factorization in 'implicit' form with stacked (rescaled) Householder
    # vectors underneath the 'R' and then applying the (accumulated)
    # reflectors in the appropriate order to apply Q'.  However, we don't
    # presently do this because we lack core TF functionality. For reference,
    # the vanilla QR approach is:
    #   q, r = tf.linalg.qr(a)
    #   c = tf.matmul(q, b, adjoint_a=True)
    #   model_coefficients_next = tf.matrix_triangular_solve(
    #       r, c, lower=False, name='model_coefficients_next')

    # Linear predictions for the updated coefficients.
    predicted_linear_response_next = calculate_linear_predictor(
        model_matrix,
        model_coefficients_next,
        offset,
        name='predicted_linear_response_next')

    return model_coefficients_next, predicted_linear_response_next