def _body(i, state, num_steps_traced, trace_arrays): elem = elems_array.read(i) state = loop_fn(state, elem) trace_arrays, num_steps_traced = ps.cond( trace_criterion_fn(state) if trace_criterion_fn else True, lambda: (trace_one_step(num_steps_traced, trace_arrays, state), # pylint: disable=g-long-lambda num_steps_traced + 1), lambda: (trace_arrays, num_steps_traced)) return i + 1, state, num_steps_traced, trace_arrays
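# The snippets below lean heavily on `ps.cond` / `prefer_static.cond`. As a
# rough sketch (not the TFP implementation), its semantics are: if the
# predicate is a Python bool or has a statically known value, pick a branch
# eagerly and trace only that branch; otherwise fall back to `tf.cond`.
import tensorflow as tf


def _prefer_static_cond_sketch(pred, true_fn, false_fn):
  """Hypothetical stand-in illustrating the prefer_static.cond dispatch."""
  if isinstance(pred, bool):  # Plain Python predicate: branch in Python.
    return true_fn() if pred else false_fn()
  static_pred = tf.get_static_value(pred)
  if static_pred is not None:  # Value known at graph-build time.
    return true_fn() if static_pred else false_fn()
  return tf.cond(pred, true_fn, false_fn)  # Genuinely dynamic predicate.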
def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding): """Zero-padding for inputs dilated by strides.""" tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1) if padding == 'VALID': tot_pad = tot_filter_dim + stride - 2 + ps.maximum( tot_filter_dim - stride, 0) elif padding == 'SAME': tot_pad = tot_filter_dim + stride - 2 return ps.cond(filter_dim >= stride, lambda: (tot_pad - tot_pad // 2 - stride + 1, tot_pad // 2), lambda: (filter_dim - stride, tot_pad - filter_dim + 1))
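# Worked example (hypothetical values) of the 'SAME' branch above:
# filter_dim=3, stride=2, dilation=1 gives tot_filter_dim=3 and tot_pad=3.
# Since filter_dim >= stride, the returned padding is
# (tot_pad - tot_pad // 2 - stride + 1, tot_pad // 2) = (1, 1).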
def _secant2_inner(value_and_gradients_function, initial_args, val_0, val_c, f_lim, sufficient_decrease_param, curvature_param): """Helper function for secant square.""" # Apply the `update` function on active branch members to squeeze their # bracketing interval. update_result = update(value_and_gradients_function, initial_args.left, initial_args.right, val_c, f_lim, active=initial_args.active) # Update active and failed flags, update left/right on non-failed entries. active = initial_args.active & ~update_result.failed failed = initial_args.failed | update_result.failed val_left = _val_where(active, update_result.left, initial_args.left) val_right = _val_where(active, update_result.right, initial_args.right) # Check if new `c` points should be generated. updated_left = active & tf.equal(val_left.x, val_c.x) updated_right = active & tf.equal(val_right.x, val_c.x) is_new = updated_left | updated_right next_c = tf.where(updated_left, _secant(initial_args.left, val_left), val_c.x) next_c = tf.where(updated_right, _secant(initial_args.right, val_right), next_c) in_range = (val_left.x <= next_c) & (next_c <= val_right.x) # Figure out if an extra function evaluation is needed for new `c` points. needs_extra_eval = tf.reduce_any(input_tensor=in_range & is_new) num_evals = initial_args.num_evals + update_result.num_evals num_evals = num_evals + tf.cast(needs_extra_eval, num_evals.dtype) next_args = _Secant2Result( active=active & in_range, # No longer active if `c` is out of range. converged=initial_args.converged, failed=failed, num_evals=num_evals, left=val_left, right=val_right) def _apply_inner_update(): next_val_c = prefer_static.cond( needs_extra_eval, (lambda: _apply(value_and_gradients_function, next_c)), (lambda: val_c)) return _secant2_inner_update(value_and_gradients_function, next_args, val_0, next_val_c, f_lim, sufficient_decrease_param, curvature_param) return prefer_static.cond(tf.reduce_any(input_tensor=next_args.active), _apply_inner_update, lambda: next_args)
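# `_secant` above is assumed to take the secant step on the derivative, i.e.
# the root of the line through (a, f'(a)) and (b, f'(b)); a sketch:
def _secant_sketch(val_a, val_b):
  """Hypothetical helper mirroring what `_secant(val_a, val_b)` computes."""
  return (val_a.x * val_b.df - val_b.x * val_a.df) / (val_b.df - val_a.df)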
def pad_batch_dimension_for_multiple_chains(
    observed_time_series, model, chain_batch_shape):
  """Expand the observed time series with extra batch dimension(s)."""
  # Running with multiple chains introduces an extra batch dimension. In
  # general we also need to pad the observed time series with a matching batch
  # dimension.
  #
  # For example, suppose our model has batch shape [3, 4] and
  # the observed time series has shape `concat([[5], [3, 4], [100]])`,
  # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
  # respectively. The model will produce distributions with batch shape
  # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
  # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
  # between the sample and batch shapes will broadcast to `chain_batch_shape`.

  observed_time_series = maybe_expand_trailing_dim(
      observed_time_series)  # Guarantee `event_ndims=2`

  event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

  model_batch_ndims = (
      model.batch_shape.ndims if model.batch_shape.ndims is not None else
      tf.shape(input=model.batch_shape_tensor())[0])

  # Compute ndims from chain_batch_shape.
  chain_batch_shape = tf.convert_to_tensor(
      value=chain_batch_shape, name='chain_batch_shape', dtype=tf.int32)
  if not chain_batch_shape.shape.is_fully_defined():
    raise ValueError('Batch shape must have static rank. (given: {})'.format(
        chain_batch_shape))
  if chain_batch_shape.shape.ndims == 0:  # expand int `k` to `[k]`.
    chain_batch_shape = chain_batch_shape[tf.newaxis]
  chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

  def do_padding(observed_time_series_tensor):
    current_sample_shape = tf.shape(
        input=observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
    current_batch_and_event_shape = tf.shape(
        input=observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
    return tf.reshape(
        tensor=observed_time_series_tensor,
        shape=tf.concat([
            current_sample_shape,
            tf.ones([chain_batch_ndims], dtype=tf.int32),
            current_batch_and_event_shape], axis=0))

  # Padding is only needed if the observed time series has sample shape.
  observed_time_series = prefer_static.cond(
      (dist_util.prefer_static_rank(observed_time_series) >
       model_batch_ndims + event_ndims),
      lambda: do_padding(observed_time_series),
      lambda: observed_time_series)

  return observed_time_series
def _loop_body(curr_interval):
  """The loop body."""
  secant2_raw_result = hzl.secant2(
      value_and_gradients_function, val_0, curr_interval, f_lim,
      sufficient_decrease_param, curvature_param)
  secant2_result = HagerZhangLineSearchResult(
      converged=secant2_raw_result.converged,
      failed=secant2_raw_result.failed,
      iterations=curr_interval.iterations + 1,
      func_evals=secant2_raw_result.num_evals,
      left=secant2_raw_result.left,
      right=secant2_raw_result.right)

  should_check_shrinkage = ~(secant2_result.converged | secant2_result.failed)

  def _do_check_shrinkage():
    """Check if the interval has shrunk enough."""
    old_width = curr_interval.right.x - curr_interval.left.x
    new_width = secant2_result.right.x - secant2_result.left.x
    sufficient_shrinkage = new_width < old_width * shrinkage_param
    func_is_flat = (
        _very_close(curr_interval.left.f, curr_interval.right.f) &
        _very_close(secant2_result.left.f, secant2_result.right.f))

    new_converged = (
        should_check_shrinkage & sufficient_shrinkage & func_is_flat)
    needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

    inner_bisect_args = secant2_result._replace(
        converged=secant2_result.converged | new_converged)

    def _apply_inner_bisect():
      return _line_search_inner_bisection(
          value_and_gradients_function, inner_bisect_args,
          needs_inner_bisect, f_lim)

    return prefer_static.cond(
        tf.reduce_any(input_tensor=needs_inner_bisect),
        _apply_inner_bisect,
        lambda: inner_bisect_args)

  next_args = prefer_static.cond(
      tf.reduce_any(input_tensor=should_check_shrinkage),
      _do_check_shrinkage,
      lambda: secant2_result)

  interval_shrunk = (
      ~next_args.failed & _very_close(next_args.left.x, next_args.right.x))
  return [next_args._replace(converged=next_args.converged | interval_shrunk)]
def _interleave(a, b): """Interleaves two `Tensor`s along their first axis.""" # [a b c ...] [d e f ...] -> [a d b e c f ...] num_elems_a = prefer_static.shape(a)[0] num_elems_b = prefer_static.shape(b)[0] def _interleave_with_b(a): return tf.reshape( tf.stack([a, b], axis=1), prefer_static.concat([[2 * num_elems_b], prefer_static.shape(a)[1:]], axis=0)) return prefer_static.cond( prefer_static.equal(num_elems_a, num_elems_b + 1), lambda: tf.concat([_interleave_with_b(a[:-1]), a[-1:]], axis=0), lambda: _interleave_with_b(a))
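# Usage illustration for `_interleave` (values are hypothetical): with
# a = [1, 3, 5] and b = [2, 4], `a` has one more element than `b`, so the
# first branch drops a[-1:], interleaves the rest with `b`, and re-appends it:
#   _interleave(tf.constant([1, 3, 5]), tf.constant([2, 4]))  # ==> [1, 2, 3, 4, 5]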
def _expand_and_maybe_replace(): """Performs the expansion step.""" expanded = face_centroid + expansion * (reflected - face_centroid) expanded_objective_value = objective_function(expanded) expanded_is_better = (expanded_objective_value < objective_at_reflected) accept_expanded_fn = lambda: (expanded, expanded_objective_value) accept_reflected_fn = lambda: (reflected, objective_at_reflected) next_pt, next_objective_value = prefer_static.cond( expanded_is_better, accept_expanded_fn, accept_reflected_fn) next_simplex = _replace_at_index(simplex, worst_index, next_pt) next_objective_at_simplex = _replace_at_index(objective_values, worst_index, next_objective_value) return False, next_simplex, next_objective_at_simplex, 1
def _secant2_inner_update(value_and_gradients_function, initial_args, val_0, val_c, f_lim, sufficient_decrease_param, curvature_param): """Helper function for secant-square step.""" # Fail if `val_c` is no longer finite. new_failed = initial_args.active & ~is_finite(val_c) active = initial_args.active & ~new_failed failed = initial_args.failed | new_failed # We converge when we find a point satisfying the Wolfe conditions, in those # cases we set `val_left = val_right = val_c`. found_wolfe = active & _satisfies_wolfe( val_0, val_c, f_lim, sufficient_decrease_param, curvature_param) val_left = val_where(found_wolfe, val_c, initial_args.left) val_right = val_where(found_wolfe, val_c, initial_args.right) converged = initial_args.converged | found_wolfe active = active & ~found_wolfe # If any active batch members remain, we apply the `update` function to # squeeze further their corresponding left/right bracketing interval. def _apply_update(): update_result = update( value_and_gradients_function, val_left, val_right, val_c, f_lim, active=active) return _Secant2Result( active=tf.zeros_like(active), # End of secant2, no actives anymore. converged=converged, failed=failed | update_result.failed, num_evals=initial_args.num_evals + update_result.num_evals, left=update_result.left, right=update_result.right) # Otherwise just return the current results. def _default(): return _Secant2Result( active=active, converged=converged, failed=failed, num_evals=initial_args.num_evals, left=val_left, right=val_right) return prefer_static.cond( tf.reduce_any(active), _apply_update, _default)
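# `_satisfies_wolfe` is assumed to test the (weak or approximate) Wolfe
# conditions. A sketch of the weak form only, with
# delta = sufficient_decrease_param and sigma = curvature_param:
def _satisfies_weak_wolfe_sketch(val_0, val_c, sufficient_decrease_param,
                                 curvature_param):
  """Hypothetical check; ignores the approximate-Wolfe branch and f_lim."""
  sufficient_decrease = (
      val_c.f <= val_0.f + val_c.x * sufficient_decrease_param * val_0.df)
  curvature = val_c.df >= curvature_param * val_0.df
  return sufficient_decrease & curvature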
def _default_fn(): """Default action.""" new_width = secant2_result.right.x - secant2_result.left.x old_width = val_right.x - val_left.x sufficient_shrinkage = new_width < shrinkage_param * old_width def _sufficient_shrinkage_fn(): """Action to perform if secant2 shrank the interval sufficiently.""" func_is_flat = ( (tf.math.nextafter(val_left.f, val_right.f) >= val_right.f) & (tf.math.nextafter(secant2_result.left.f, secant2_result.right.f) >= secant2_result.right.f)) is_flat_retval = _LineSearchInnerResult( iteration=iteration, found_wolfe=True, failed=False, num_evals=evals, left=secant2_result.left, right=secant2_result.left) not_is_flat_retval = _LineSearchInnerResult( iteration=iteration, found_wolfe=False, failed=False, num_evals=evals, left=secant2_result.left, right=secant2_result.right) return prefer_static.cond(func_is_flat, true_fn=lambda: is_flat_retval, false_fn=lambda: not_is_flat_retval) def _insufficient_shrinkage_fn(): """Action to perform if secant2 didn't shrink the interval enough.""" update_result = _line_search_inner_bisection( value_and_gradients_function, secant2_result.left, secant2_result.right, f_lim) return _LineSearchInnerResult(iteration=iteration, found_wolfe=False, failed=update_result.failed, num_evals=evals + update_result.num_evals, left=update_result.left, right=update_result.right) return prefer_static.cond(sufficient_shrinkage, true_fn=_sufficient_shrinkage_fn, false_fn=_insufficient_shrinkage_fn)
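# The flatness test above uses `tf.math.nextafter` directly; the `_very_close`
# helper used in the shrinkage checks elsewhere in this file presumably
# amounts to the same one-ULP comparison (a sketch):
def _very_close_sketch(x, y):
  return tf.math.nextafter(x, y) >= y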
def _marginal_hidden_probs(self): """Compute marginal pdf for each individual observable.""" num_states = self.transition_distribution.batch_shape_tensor()[-1] log_init = _extract_log_probs(num_states, self.initial_distribution) initial_log_probs = tf.broadcast_to(log_init, ps.concat([self.batch_shape_tensor(), [num_states]], axis=0)) # initial_log_probs :: batch_shape num_states no_transition_result = initial_log_probs[tf.newaxis, ...] def _scan_multiple_steps(): """Perform `scan` operation when `num_steps` > 1.""" transition_log_probs = _extract_log_probs(num_states, self.transition_distribution) def forward_step(log_probs, _): result = _log_vector_matrix(log_probs, transition_log_probs) # We know that `forward_step` must preserve the shape of the # tensor of probabilities of each state. This is because # the transition matrix must be square. But TensorFlow might # not know this so we explicitly tell it that the result has the # same shape. tensorshape_util.set_shape(result, log_probs.shape) return result dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) forward_log_probs = tf.scan(forward_step, dummy_index, initializer=initial_log_probs, name='forward_log_probs') result = tf.concat([[initial_log_probs], forward_log_probs], axis=0) return result forward_log_probs = ps.cond( self._num_steps > 1, _scan_multiple_steps, lambda: no_transition_result) return tf.exp(forward_log_probs)
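# `_log_vector_matrix` is assumed to be a log-space vector/matrix product,
# i.e. a logsumexp over the shared state axis (a sketch, not the TFP helper):
def _log_vector_matrix_sketch(vs, ms):
  """log(exp(vs) @ exp(ms)) computed stably, batching over leading dims."""
  return tf.reduce_logsumexp(vs[..., :, tf.newaxis] + ms, axis=-2)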
def _contraction(): """Performs a contraction.""" contracted = face_centroid + contraction * (reflected - face_centroid) objective_at_contracted = objective_function(contracted) is_contracted_acceptable = objective_at_contracted <= objective_at_reflected def _accept_contraction(): next_simplex = _replace_at_index(simplex, worst_index, contracted) objective_at_next_simplex = _replace_at_index( objective_values, worst_index, objective_at_contracted) return (False, next_simplex, objective_at_next_simplex, 1) def _reject_contraction(): return _shrink_towards_best(objective_function, simplex, best_index, shrinkage, batch_evaluate_objective) return prefer_static.cond(is_contracted_acceptable, _accept_contraction, _reject_contraction)
def _contraction(): """Performs a contraction.""" contracted = face_centroid - contraction * (face_centroid - simplex[worst_index]) objective_at_contracted = objective_function(contracted) is_contracted_acceptable = objective_at_contracted <= worst_objective_value def _accept_contraction(): next_simplex = _replace_at_index(simplex, worst_index, contracted) objective_at_next_simplex = _replace_at_index( objective_values, worst_index, objective_at_contracted) return (False, next_simplex, objective_at_next_simplex, np.int32(1)) def _reject_contraction(): return _shrink_towards_best(objective_function, simplex, best_index, shrinkage, batch_evaluate_objective) return ps.cond(is_contracted_acceptable, _accept_contraction, _reject_contraction)
def recursive_case(): """Evaluate the next step of the recursion.""" odd_elems = _scan(level - 1, reduced_elems) def even_length_case(): return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems], [elem[2::2] for elem in elems]) def odd_length_case(): return lowered_fn([odd_elem for odd_elem in odd_elems], [elem[2::2] for elem in elems]) results = prefer_static.cond( prefer_static.equal(elem_length % 2, 0), even_length_case, odd_length_case) # The first element of a scan is the same as the first element # of the original `elems`. even_elems = [tf.concat([elem[0:1], result], axis=0) for (elem, result) in zip(elems, results)] return list(map(_interleave, even_elems, odd_elems))
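# Worked example (ordinary addition) of the even/odd split performed by
# `recursive_case`, for elems = [x0, x1, x2, x3, x4, x5]:
#   reduced_elems = [x0+x1, x2+x3, x4+x5]
#   odd_elems     = _scan(reduced_elems) = [x0+x1, x0+..+x3, x0+..+x5]
#   results       = fn(odd_elems[:-1], elems[2::2]) = [x0+x1+x2, x0+..+x4]
#   even_elems    = [x0, x0+x1+x2, x0+..+x4]
#   interleaved   = [x0, x0+x1, x0+x1+x2, x0+..+x3, x0+..+x4, x0+..+x5]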
def _maybe_warn_increased_dof(self, component_name, component_ldj, increased_dof): """Warns or raises when `increased_dof` is True.""" # Short-circuit when the component LDJ is statically zero. if (tf.get_static_value(tf.rank(component_ldj)) == 0 and tf.get_static_value(component_ldj) == 0): return # Short-circuit when increased_dof is statically False. increased_dof_ = tf.get_static_value(increased_dof) if increased_dof_ is False: # pylint: disable=g-bool-id-comparison return error_message = ( 'Nested component "{}" in composition "{}" operates on inputs ' 'with increased degrees of freedom. This may result in an ' 'incorrect log_det_jacobian.' ).format(component_name, self.name) # When validate_args is True, we raise on increased DoF. if self._validate_args: if increased_dof_: raise ValueError(error_message) return assert_util.assert_equal(False, increased_dof, error_message) if (not tf.executing_eagerly() and control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())): return # No StringFormat or Print ops in XLA. # Otherwise, we print a warning and continue. return ps.cond( pred=increased_dof, false_fn=tf.no_op, true_fn=lambda: tf.print( # pylint: disable=g-long-lambda 'WARNING: ' + error_message, output_stream=sys.stderr))
def _update_inv_hessian(prev_state, next_state):
  """Update the BFGS state by computing the next inverse hessian estimate."""
  # Only update the inverse Hessian if not already failed or converged.
  should_update = ~next_state.converged & ~next_state.failed

  # Compute the normalization term (y^T . s); do not update if it is singular.
  gradient_delta = next_state.objective_gradient - prev_state.objective_gradient
  position_delta = next_state.position - prev_state.position
  normalization_factor = tf.reduce_sum(gradient_delta * position_delta, axis=-1)
  should_update = should_update & ~tf.equal(normalization_factor, 0)

  def _do_update_inv_hessian():
    next_inv_hessian = _bfgs_inv_hessian_update(
        gradient_delta, position_delta, normalization_factor,
        prev_state.inverse_hessian_estimate)
    return bfgs_utils.update_fields(
        next_state,
        inverse_hessian_estimate=tf.where(
            should_update[..., tf.newaxis, tf.newaxis],
            next_inv_hessian, prev_state.inverse_hessian_estimate))

  return ps.cond(ps.reduce_any(should_update),
                 _do_update_inv_hessian,
                 lambda: next_state)
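# For reference, `_bfgs_inv_hessian_update` is expected to apply the standard
# BFGS inverse-Hessian recursion. A single-batch sketch with
# s = position_delta, y = gradient_delta and rho = 1 / normalization_factor:
def _bfgs_inv_hessian_update_sketch(gradient_delta, position_delta,
                                    normalization_factor, inv_hessian):
  """Hypothetical dense form: (I - rho s y^T) H (I - rho y s^T) + rho s s^T."""
  rho = 1. / normalization_factor
  identity = tf.eye(tf.shape(position_delta)[-1], dtype=inv_hessian.dtype)
  v = identity - rho * position_delta[:, tf.newaxis] * gradient_delta[tf.newaxis, :]
  return (tf.matmul(v, tf.matmul(inv_hessian, v, transpose_b=True)) +
          rho * position_delta[:, tf.newaxis] * position_delta[tf.newaxis, :])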
def _do_check_shrinkage():
  """Check if the interval has shrunk enough."""
  old_width = curr_interval.right.x - curr_interval.left.x
  new_width = secant2_result.right.x - secant2_result.left.x
  sufficient_shrinkage = new_width < old_width * shrinkage_param
  func_is_flat = (
      _very_close(curr_interval.left.f, curr_interval.right.f) &
      _very_close(secant2_result.left.f, secant2_result.right.f))

  new_converged = (
      should_check_shrinkage & sufficient_shrinkage & func_is_flat)
  needs_inner_bisect = should_check_shrinkage & ~sufficient_shrinkage

  inner_bisect_args = secant2_result._replace(
      converged=secant2_result.converged | new_converged)

  def _apply_inner_bisect():
    return _line_search_inner_bisection(
        value_and_gradients_function, inner_bisect_args,
        needs_inner_bisect, f_lim)

  return prefer_static.cond(
      tf.reduce_any(input_tensor=needs_inner_bisect),
      _apply_inner_bisect,
      lambda: inner_bisect_args)
def _loop_body( # pylint: disable=missing-docstring iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update): # Inner loop of the minimizer. # # This loop updates a single coordinate of x_update. Ideally, an # iteration of this loop would set # # x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R } # # where # # LocalLoss(x_update') # = LocalLossSmoothComponent(x_update') # + l1_regularizer * (||x_start + x_update'||_1 - # ||x_start + x_update||_1) # := (UnregularizedLoss(x_start + x_update') - # UnregularizedLoss(x_start + x_update) # + l2_regularizer * (||x_start + x_update'||_2**2 - # ||x_start + x_update||_2**2) # + l1_regularizer * (||x_start + x_update'||_1 - # ||x_start + x_update||_1) # # In this algorithm approximate the above argmin using (univariate) # proximal gradient descent: # # (*) x_update[j] = prox_{t * l1_regularizer * L1}( # x_update[j] - # t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z)) # # where # # UnivariateLocalLossSmoothComponent(z) # := LocalLossSmoothComponent(x_update + z*e_j) # # and we approximate # # d/dz UnivariateLocalLossSmoothComponent(z) # = grad LocalLossSmoothComponent(x_update))[j] # ~= (grad LossSmoothComponent(x_start) # + x_update matmul HessianOfLossSmoothComponent(x_start))[j]. # # To choose the parameter t, we squint and pretend that the inner term of # (*) is a Newton update as if we were using Newton's method to minimize # UnivariateLocalLossSmoothComponent. That is, we choose t such that # # -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC) # # at z=0. Hence # # t = learning_rate / (d^2/dz^2|z=0 ULLSC) # = learning_rate / HessianOfLossSmoothComponent( # x_start + x_update)[j,j] # ~= learning_rate / HessianOfLossSmoothComponent( # x_start)[j,j] # # The above approximation is equivalent to assuming that # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order # effects. # # Note that because LossSmoothComponent is (assumed to be) convex, t is # positive. # In above notation, coord = j. coord = iter_ % dims # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2, # computed incrementally, where x_update_end and x_update_start are as # defined in the convergence criteria. Accordingly, we reset # x_update_diff_norm_sq to zero at the beginning of each sweep. x_update_diff_norm_sq = tf.where( tf.equal(coord, 0), tf.zeros_like(x_update_diff_norm_sq), x_update_diff_norm_sq) # Recall that x_update and hess_matmul_x_update has the rightmost # dimension transposed to the leftmost dimension. w_old = x_start[..., coord] + x_update[coord, ...] # This is the coordinatewise Newton update if no L1 regularization. # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC). second_deriv = _hessian_diag_elt_with_l2(coord) newton_step = -_mul_ignoring_nones( # pylint: disable=invalid-unary-operand-type learning_rate, grad_loss_with_l2[..., coord] + hess_matmul_x_update[coord, ...]) / second_deriv # Applying the soft-threshold operator accounts for L1 regularization. # In above notation, delta = # prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old. 
delta = (soft_threshold( w_old + newton_step, _mul_ignoring_nones(learning_rate, l1_regularizer) / second_deriv) - w_old) def _do_update(x_update_diff_norm_sq, x_update, hess_matmul_x_update): # pylint: disable=missing-docstring hessian_column_with_l2 = sparse_or_dense_matvecmul( hessian_unregularized_loss_outer, hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot( hessian_unregularized_loss_outer, coord), adjoint_a=True) if l2_regularizer is not None: hessian_column_with_l2 += _one_hot_like( hessian_column_with_l2, coord, on_value=2. * l2_regularizer) # Move the batch dimensions of `hessian_column_with_l2` to rightmost in # order to conform to `hess_matmul_x_update`. n = tf.rank(hessian_column_with_l2) perm = tf.roll(tf.range(n), shift=1, axis=0) hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2, perm=perm) # Update the entire batch at `coord` even if `delta` may be 0 at some # batch coordinates. In those cases, adding `delta` is a no-op. x_update = tf.tensor_scatter_add(x_update, [[coord]], [delta]) with tf.control_dependencies([x_update]): x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2 hess_matmul_x_update_ = (hess_matmul_x_update + delta * hessian_column_with_l2) # Hint that loop vars retain the same shape. x_update_diff_norm_sq_.set_shape( x_update_diff_norm_sq_.shape.merge_with( x_update_diff_norm_sq.shape)) hess_matmul_x_update_.set_shape( hess_matmul_x_update_.shape.merge_with( hess_matmul_x_update.shape)) return [ x_update_diff_norm_sq_, x_update, hess_matmul_x_update_ ] inputs_to_update = [ x_update_diff_norm_sq, x_update, hess_matmul_x_update ] return [iter_ + 1] + prefer_static.cond( # Note on why checking delta (a difference of floats) for equality to # zero is ok: # # First of all, x - x == 0 in floating point -- see # https://stackoverflow.com/a/2686671 # # Delta will conceptually equal zero when one of the following holds: # (i) |w_old + newton_step| <= threshold and w_old == 0 # (ii) |w_old + newton_step| > threshold and # w_old + newton_step - sign(w_old + newton_step) * threshold # == w_old # # In case (i) comparing delta to zero is fine. # # In case (ii), newton_step conceptually equals # sign(w_old + newton_step) * threshold. # Also remember # threshold = -newton_step / (approximation of d/dz|z=0 ULLSC). # So (i) happens when # (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step). # If we did not require LossSmoothComponent to be strictly convex, # then this could actually happen a non-negligible amount of the time, # e.g. if the loss function is piecewise linear and one of the pieces # has slope 1. But since LossSmoothComponent is strictly convex, (i) # should not systematically happen. tf.reduce_all(input_tensor=tf.equal(delta, 0.)), lambda: inputs_to_update, lambda: _do_update(*inputs_to_update))
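# `soft_threshold` above is assumed to be the proximal operator of the L1
# norm (soft-thresholding); a sketch:
def soft_threshold_sketch(x, threshold):
  return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)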
def _sample_n(self, n, seed=None): init_seed, scan_seed, observation_seed = samplers.split_seed( seed, n=3, salt='HiddenMarkovModel') transition_batch_shape = self.transition_distribution.batch_shape_tensor() num_states = transition_batch_shape[-1] batch_shape = self.batch_shape_tensor() batch_size = ps.reduce_prod(batch_shape) # The batch sizes of the underlying initial distributions and # transition distributions might not match the batch size of # the HMM distribution. # As a result we need to ask for more samples from the # underlying distributions and then reshape the results into # the correct batch size for the HMM. init_repeat = ( ps.reduce_prod(batch_shape) // ps.reduce_prod(self._initial_distribution.batch_shape_tensor())) init_state = self._initial_distribution.sample(n * init_repeat, seed=init_seed) init_state = tf.reshape(init_state, [n, batch_size]) # init_state :: n batch_size transition_repeat = ( ps.reduce_prod(batch_shape) // ps.reduce_prod( transition_batch_shape[:-1])) init_shape = init_state.shape def generate_step(state_and_seed, _): """Take a single step in Markov chain.""" state, seed = state_and_seed sample_seed, next_seed = samplers.split_seed(seed) gen = self._transition_distribution.sample(n * transition_repeat, seed=sample_seed) # gen :: (n * transition_repeat) transition_batch new_states = tf.reshape(gen, [n, batch_size, num_states]) # new_states :: n batch_size num_states old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32) # old_states :: n batch_size num_states result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1) # We know that `generate_step` must preserve the shape of the # tensor of states of each state. This is because # the transition matrix must be square. But TensorFlow might # not know this so we explicitly tell it that the result has the # same shape. tensorshape_util.set_shape(result, init_shape) return result, next_seed def _scan_multiple_steps(): """Take multiple steps with tf.scan.""" dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) hidden_states, _ = tf.scan(generate_step, dummy_index, initializer=(init_state, scan_seed)) # TODO(b/115618503): add/use prepend_initializer to tf.scan return tf.concat([[init_state], hidden_states], axis=0) hidden_states = ps.cond( self._num_steps > 1, _scan_multiple_steps, lambda: init_state[tf.newaxis, ...]) hidden_one_hot = tf.one_hot(hidden_states, num_states, dtype=self._observation_distribution.dtype) # hidden_one_hot :: num_steps n batch_size num_states # The observation distribution batch size might not match # the required batch size so as with the initial and # transition distributions we generate more samples and # reshape. observation_repeat = tf.maximum( batch_size // ps.reduce_prod( self._observation_distribution.batch_shape_tensor()[:-1]), 1) if self._time_varying_observation_distribution: possible_observations = self._observation_distribution.sample( [observation_repeat * n], seed=observation_seed) # possible observations needs to have num_steps moved to the beginning. 
possible_observations = distribution_util.move_dimension( possible_observations, -(tf.size(self._observation_distribution.event_shape_tensor()) + 2), 0) else: possible_observations = self._observation_distribution.sample( [self._num_steps, observation_repeat * n], seed=observation_seed) inner_shape = self._observation_distribution.event_shape_tensor() # possible_observations :: num_steps (observation_repeat * n) # observation_batch[:-1] num_states inner_shape possible_observations = tf.reshape( possible_observations, ps.concat([[self._num_steps, n], batch_shape, [num_states], inner_shape], axis=0)) # possible_observations :: steps n batch_size num_states inner_shape hidden_one_hot = tf.reshape(hidden_one_hot, ps.concat([[self._num_steps, n], batch_shape, [num_states], ps.ones_like(inner_shape)], axis=0)) # hidden_one_hot :: steps n batch_size num_states "inner_shape" observations = tf.reduce_sum( hidden_one_hot * possible_observations, axis=-1 - ps.size(inner_shape)) # observations :: steps n batch_size inner_shape observations = distribution_util.move_dimension(observations, 0, 1 + ps.size(batch_shape)) # returned :: n batch_shape steps inner_shape return observations
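# Minimal, standalone illustration of the one-hot selection trick used by
# `generate_step` above: picking entry `state` out of each row of the
# candidate next states without tf.gather (values are hypothetical).
def _one_hot_select_demo():
  state = tf.constant([2, 0])                        # Current hidden states.
  candidates = tf.constant([[5, 6, 7], [8, 9, 10]])  # Candidate next states.
  one_hot = tf.one_hot(state, 3, dtype=candidates.dtype)
  return tf.reduce_sum(one_hot * candidates, axis=-1)  # ==> [7, 8]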
def posterior_mode(self, observations, mask=None, name=None): """Compute maximum likelihood sequence of hidden states. When this function is provided with a sequence of observations `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden states `z[0], ..., z[num_steps - 1]`, drawn from the underlying Markov chain, that is most likely to yield those observations. It uses the [Viterbi algorithm]( https://en.wikipedia.org/wiki/Viterbi_algorithm). Note: the behavior of this function is undefined if the `observations` argument represents impossible observations from the model. Note: if there isn't a unique most likely sequence then one of the equally most likely sequences is chosen. Args: observations: A tensor representing a batch of observations made on the hidden Markov model. The rightmost dimensions of this tensor correspond to the dimensions of the observation distributions of the underlying Markov chain. The next dimension from the right indexes the steps in a sequence of observations from a single sample from the hidden Markov model. The size of this dimension should match the `num_steps` parameter of the hidden Markov model object. The other dimensions are the dimensions of the batch and these are broadcast with the hidden Markov model's parameters. mask: optional bool-type `tensor` with rightmost dimension matching `num_steps` indicating which observations the result of this function should be conditioned on. When the mask has value `True` the corresponding observations aren't used. if `mask` is `None` then all of the observations are used. the `mask` dimensions left of the last are broadcast with the hmm batch as well as with the observations. name: Python `str` name prefixed to Ops created by this class. Default value: "HiddenMarkovModel". Returns: posterior_mode: A `Tensor` representing the most likely sequence of hidden states. The rightmost dimension of this tensor will equal the `num_steps` parameter providing one hidden state for each step. The other dimensions are those of the batch. Raises: ValueError: if the `observations` tensor does not consist of sequences of `num_steps` observations. #### Examples ```python tfd = tfp.distributions # A simple weather model. # Represent a cold day with 0 and a hot day with 1. # Suppose the first day of a sequence has a 0.8 chance of being cold. initial_distribution = tfd.Categorical(probs=[0.8, 0.2]) # Suppose a cold day has a 30% chance of being followed by a hot day # and a hot day has a 20% chance of being followed by a cold day. transition_distribution = tfd.Categorical(probs=[[0.7, 0.3], [0.2, 0.8]]) # Suppose additionally that on each day the temperature is # normally distributed with mean and standard deviation 0 and 5 on # a cold day and mean and standard deviation 15 and 10 on a hot day. observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.]) # This gives the hidden Markov model: model = tfd.HiddenMarkovModel( initial_distribution=initial_distribution, transition_distribution=transition_distribution, observation_distribution=observation_distribution, num_steps=7) # Suppose we observe gradually rising temperatures over a week: temps = [-2., 0., 2., 4., 6., 8., 10.] # We can now compute the most probable sequence of hidden states: model.posterior_mode(temps) # The result is [0 0 0 0 0 1 1] telling us that the transition # from "cold" to "hot" most likely happened between the # 5th and 6th days. 
``` """ with tf.name_scope(name or "posterior_mode"): observations = tf.convert_to_tensor(observations, name="observations") if mask is not None: mask = tf.convert_to_tensor(mask, name="mask", dtype_hint=tf.bool) with tf.control_dependencies(self._runtime_assertions): observation_tensor_shape = tf.shape(observations) mask_tensor_shape = tf.shape( mask) if mask is not None else None with self._observation_mask_shape_preconditions( observation_tensor_shape, mask_tensor_shape): observation_log_probs = self._observation_log_probs( observations, mask) log_prob = self._log_init + observation_log_probs[0] def _reduce_multiple_steps(): """Perform `reduce_max` operation when `num_steps` > 1.""" def forward_step(previous_step_pair, log_prob_observation): log_prob_previous = previous_step_pair[0] log_prob = ( log_prob_previous[..., tf.newaxis] + self._log_trans + log_prob_observation[..., tf.newaxis, :]) most_likely_given_successor = tf.argmax(log_prob, axis=-2) max_log_p_given_successor = tf.reduce_max(log_prob, axis=-2) return (max_log_p_given_successor, most_likely_given_successor) forward_log_probs, all_most_likely_given_successor = tf.scan( forward_step, observation_log_probs[1:], initializer=(log_prob, tf.zeros(tf.shape(log_prob), dtype=tf.int64)), name="forward_log_probs") most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1) # We require the operation that gives C from A and B where # C[i...j] = A[i...j, B[i...j]] # and A = most_likely_given_successor # B = most_likely_successor. # tf.gather requires indices of known shape so instead we use # reduction with tf.one_hot(B) to pick out elements from B def backward_step(most_likely_successor, most_likely_given_successor): return tf.reduce_sum( (most_likely_given_successor * tf.one_hot(most_likely_successor, self._num_states, dtype=tf.int64)), axis=-1) backward_scan = tf.scan( backward_step, all_most_likely_given_successor, most_likely_end, reverse=True) most_likely_sequences = tf.concat( [backward_scan, [most_likely_end]], axis=0) return distribution_util.move_dimension( most_likely_sequences, 0, -1) return prefer_static.cond( self.num_steps > 1, _reduce_multiple_steps, lambda: tf.argmax(log_prob, axis=-1)[..., tf.newaxis])
def _sample_n(self, n, seed=None): with tf.control_dependencies(self._runtime_assertions): strm = SeedStream(seed, salt="HiddenMarkovModel") num_states = self._num_states batch_shape = self.batch_shape_tensor() batch_size = tf.reduce_prod(batch_shape) # The batch sizes of the underlying initial distributions and # transition distributions might not match the batch size of # the HMM distribution. # As a result we need to ask for more samples from the # underlying distributions and then reshape the results into # the correct batch size for the HMM. init_repeat = ( tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod( self._initial_distribution.batch_shape_tensor())) init_state = self._initial_distribution.sample(n * init_repeat, seed=strm()) init_state = tf.reshape(init_state, [n, batch_size]) # init_state :: n batch_size transition_repeat = ( tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod( self._transition_distribution.batch_shape_tensor()[:-1])) def generate_step(state, _): """Take a single step in Markov chain.""" gen = self._transition_distribution.sample(n * transition_repeat, seed=strm()) # gen :: (n * transition_repeat) transition_batch new_states = tf.reshape(gen, [n, batch_size, num_states]) # new_states :: n batch_size num_states old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32) # old_states :: n batch_size num_states return tf.reduce_sum(old_states_one_hot * new_states, axis=-1) def _scan_multiple_steps(): """Take multiple steps with tf.scan.""" dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) if seed is not None: # Force parallel_iterations to 1 to ensure reproducibility # b/139210489 hidden_states = tf.scan(generate_step, dummy_index, initializer=init_state, parallel_iterations=1) else: # Invoke default parallel_iterations behavior hidden_states = tf.scan(generate_step, dummy_index, initializer=init_state) # TODO(b/115618503): add/use prepend_initializer to tf.scan return tf.concat([[init_state], hidden_states], axis=0) hidden_states = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps, lambda: init_state[tf.newaxis, ...]) hidden_one_hot = tf.one_hot( hidden_states, num_states, dtype=self._observation_distribution.dtype) # hidden_one_hot :: num_steps n batch_size num_states # The observation distribution batch size might not match # the required batch size so as with the initial and # transition distributions we generate more samples and # reshape. 
observation_repeat = (batch_size // tf.reduce_prod( self._observation_distribution.batch_shape_tensor()[:-1])) possible_observations = self._observation_distribution.sample( [self._num_steps, observation_repeat * n], seed=strm()) inner_shape = self._observation_distribution.event_shape # possible_observations :: num_steps (observation_repeat * n) # observation_batch[:-1] num_states inner_shape possible_observations = tf.reshape( possible_observations, tf.concat([[self._num_steps, n], batch_shape, [num_states], inner_shape], axis=0)) # possible_observations :: steps n batch_size num_states inner_shape hidden_one_hot = tf.reshape( hidden_one_hot, tf.concat([[self._num_steps, n], batch_shape, [num_states], tf.ones_like(inner_shape)], axis=0)) # hidden_one_hot :: steps n batch_size num_states "inner_shape" observations = tf.reduce_sum(hidden_one_hot * possible_observations, axis=-1 - tf.size(inner_shape)) # observations :: steps n batch_size inner_shape observations = distribution_util.move_dimension( observations, 0, 1 + tf.size(batch_shape)) # returned :: n batch_shape steps inner_shape return observations
def _bracket_and_search(value_and_gradients_function, val_0, val_c, f_lim, max_iterations, shrinkage_param=None, expansion_param=None, sufficient_decrease_param=None, curvature_param=None): """Brackets the minimum and performs a line search. Args: value_and_gradients_function: A Python callable that accepts a real scalar tensor and returns an object that can be converted to a namedtuple. The namedtuple should have fields 'f' and 'df' that correspond to scalar tensors of real dtype containing the value of the function and its derivative at that point. The other namedtuple fields, if present, should be tensors or sequences (possibly nested) of tensors. In usual optimization application, this function would be generated by projecting the multivariate objective function along some specific direction. The direction is determined by some other procedure but should be a descent direction (i.e. the derivative of the projected univariate function must be negative at 0.). val_0: Instance of `_FnDFn` containing the value and gradient of the objective at 0. The gradient must be negative (i.e. must be a descent direction). val_c: Instance of `_FnDFn` containing the initial step size and the value and gradient of the objective at the initial step size. The step size must be positive and finite. f_lim: Scalar `Tensor` of float dtype. max_iterations: Positive scalar `Tensor` of integral dtype. The maximum number of iterations to perform in the line search. The number of iterations used to bracket the minimum are also counted against this parameter. shrinkage_param: Scalar positive Tensor of real dtype. Must be less than `1.`. Corresponds to the parameter `gamma` in [Hager and Zhang (2006)][2]. expansion_param: Scalar positive `Tensor` of real dtype. Must be greater than `1.`. Used to expand the initial interval in case it does not bracket a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2]. sufficient_decrease_param: Positive scalar `Tensor` of real dtype. Bounded above by the curvature param. Corresponds to `delta` in the terminology of [Hager and Zhang (2006)][2]. curvature_param: Positive scalar `Tensor` of real dtype. Bounded above by `1.`. Corresponds to 'sigma' in the terminology of [Hager and Zhang (2006)][2]. Returns: A namedtuple containing the following fields. iteration: A scalar int32 `Tensor`. The number of iterations consumed. found_wolfe: A scalar boolean `Tensor`. Indicates whether a point satisfying the Wolfe conditions has been found. If this is True, the interval will be degenerate (i.e. left and right below will be identical). failed: A scalar boolean `Tensor`. Indicates if invalid function or gradient values were encountered (i.e. infinity or NaNs). num_evals: A scalar int32 `Tensor`. The total number of function evaluations made. left: Instance of _FnDFn. The position and the associated value and derivative at the updated left end point of the interval. right: Instance of _FnDFn. The position and the associated value and derivative at the updated right end point of the interval. """ bracket_result = hzl.bracket(value_and_gradients_function, val_0, val_c, f_lim, max_iterations, expansion_param=expansion_param) # If the bracketing failed, or we have already exhausted all the allowed # iterations, we return an error. 
failed = (bracket_result.failed | tf.greater_equal(bracket_result.iteration, max_iterations)) def _bracketing_failed_fn(): return _LineSearchInnerResult(iteration=bracket_result.iteration, found_wolfe=False, failed=True, num_evals=bracket_result.num_evals, left=val_0, right=val_c) def _bracketing_success_fn(): """Performs line search.""" result = _line_search_after_bracketing( value_and_gradients_function, val_0, bracket_result.left, bracket_result.right, f_lim, bracket_result.iteration, max_iterations, sufficient_decrease_param=sufficient_decrease_param, curvature_param=curvature_param, shrinkage_param=shrinkage_param) return _LineSearchInnerResult(iteration=result.iteration, found_wolfe=result.found_wolfe, failed=result.failed, num_evals=bracket_result.num_evals + result.num_evals, left=result.left, right=result.right) return prefer_static.cond(failed, true_fn=_bracketing_failed_fn, false_fn=_bracketing_success_fn)
def _valid_inputs_fn(): """Performs bracketing and line search if inputs are valid.""" # If the value or the gradient at the supplied step is not finite, # we attempt to repair it. step_size_too_large = ~(tf.math.is_finite(val_c_input.df) & tf.math.is_finite(val_c_input.f)) def _is_too_large_fn(): return _fix_step_size(value_and_gradients_function, val_c_input, step_size_shrink_param) val_c, fix_evals = prefer_static.cond(step_size_too_large, _is_too_large_fn, lambda: (val_c_input, 0)) # Check if c is fixed now. valid_at_c = hzl.is_finite(val_c) & (val_c.x > 0) def _failure_fn(): # If c is still not good, just return 0. return HagerZhangLineSearchResult( converged=tf.convert_to_tensor(value=True, name='converged'), failed=tf.convert_to_tensor(value=False, name='failed'), func_evals=prepare_evals + fix_evals, iterations=tf.convert_to_tensor(value=0), left_pt=val_0.x, objective_at_left_pt=val_0.f, grad_objective_at_left_pt=val_0.df, right_pt=val_0.x, objective_at_right_pt=val_0.f, grad_objective_at_right_pt=val_0.df, full_result=val_0.full_result) def success_fn(): """Bracketing and searching to do if all inputs are valid.""" result = _bracket_and_search( value_and_gradients_function, val_0, val_c, f_lim, max_iterations, shrinkage_param=shrinkage_param, expansion_param=expansion_param, sufficient_decrease_param=sufficient_decrease_param, curvature_param=curvature_param) converged = tf.convert_to_tensor(value=result.found_wolfe, name='converged') return HagerZhangLineSearchResult( converged=converged, failed=tf.convert_to_tensor(value=result.failed, name='failed'), func_evals=result.num_evals + prepare_evals + fix_evals, iterations=result.iteration, left_pt=result.left.x, objective_at_left_pt=result.left.f, grad_objective_at_left_pt=result.left.df, right_pt=result.right.x, objective_at_right_pt=result.right.f, grad_objective_at_right_pt=result.right.df, full_result=result.left.full_result) return prefer_static.cond(valid_at_c, true_fn=success_fn, false_fn=_failure_fn)
def scan_associative(fn, elems, max_num_levels=48, validate_args=False, name=None): """Perform a scan with an associative binary operation, in parallel. The associative scan operation computes the cumulative sum, or [all-prefix sum](https://en.wikipedia.org/wiki/Prefix_sum), of a set of elements under an associative binary operation [1]. For example, using the ordinary addition operator `fn = lambda a, b: a + b`, this is equivalent to the ordinary cumulative sum `tf.math.cumsum` along axis 0. This method supports the general case of arbitrary associative binary operations operating on `Tensor`s or structures of `Tensor`s: ```python associative_scan(fn, elems) = tf.stack([ elems[0], fn(elems[0], elems[1]), fn(elems[0], fn(elems[1], elems[2])), ... fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1]))), ], axis=0) ``` The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation would loop over all `N` elements, this method requires only a logarithmic number (`2 * ceil(log_2 N)`) of sequential steps, and can thus yield substantial performance speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is `2 * (N / 2 + N / 4 + ... + 1) = 2N - 2` --- i.e., approximately twice as many as a naive approach. [1] Blelloch, Guy E. [Prefix sums and their applications]( https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf) Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990. Args: fn: Python callable implementing an associative binary operation with signature `r = fn(a, b)`. This must satisfy associativity: `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are (possibly nested structures of) `Tensor`(s), matching `elems`. Each `Tensor` has a leading batch dimension in place of `elem_length`; the `fn` is expected to map over this dimension. The result `r` has the same shape (and structure) as the two inputs `a` and `b`. elems: A (possibly nested structure of) `Tensor`(s), each with leading dimension `elem_length`. Note that `elem_length` determines the number of recursive steps required to perform the scan: if, in graph mode, this is not statically available, then ops will be created to handle any `elem_length` up to the maximum dimension of a `Tensor`. max_num_levels: Python `int`. The size of the first dimension of the tensors in `elems` must be less than `2**(max_num_levels + 1)`. The default value is sufficiently large for most needs. Lowering this value can reduce graph-building time when `scan_associative` is used with inputs of unknown shape. Default value: `48`. validate_args: Python `bool`. When `True`, runtime checks for invalid inputs are performed. This may carry a performance cost. Default value: `False`. name: Python `str` name prefixed to ops created by this function. Returns: result: A (possibly nested structure of) `Tensor`(s) of the same shape and structure as `elems`, in which the `k`th element is the result of recursively applying `fn` to combine the first `k` elements of `elems`. For example, given `elems = [a, b, c, ...]`, the result would be `[a, fn(a, b), fn(fn(a, b), c), ...]`. #### Examples ```python import tensorflow as tf import tensorflow_probability as tfp import operator # Example 1: Partials sums of numbers. tfp.math.associative_scan(operator.add, tf.range(0, 4)) # ==> [ 0, 1, 3, 6] # Example 2: Partial products of random matrices. 
dist = tfp.distributions.Normal(loc=0., scale=1.) matrices = dist.sample(sample_shape=[100, 2, 2]) tfp.math.associative_scan(tf.matmul, matrices) ``` """ def lowered_fn(a, b): # Lower `fn` to operate on flattened sequences of elems. with tf.name_scope('fn'): return tf.nest.flatten( fn(tf.nest.pack_sequence_as(elems, a), tf.nest.pack_sequence_as(elems, b))) elems_flat = [ tf.convert_to_tensor(elem) for elem in tf.nest.flatten(elems) ] # Summary of algorithm: # # Consider elements of `_scan(elems)` at odd indices. That's the same as first # summing successive pairs of elements of `elems` and performing a scan on # that half sized tensor. We perform the latter scan by recursion. # # Now consider the even elements of `_scan(elems)`. These can be computed # from the odd elements of `_scan(elems)` by adding each odd element of # `_scan(elems)` to the matching even element in the original `elems`. # # We return the odd and even elements interleaved. # # For the base case of the recursion we return the first element # of `elems` followed by the sum of the first two elements computed as # a (small two-down-to-one) reduction step. # The following is a pictorial representation of the algorithm using the # variables in the code below. The operator '+' is used to represent # the binary operation. # Note how the recursive call to `_scan` operates on a reduced form of the # input array in which successive pairs have already been summed. # elems x0 x1 x2 x3 x4 x5 ... # |\ / | \ / | \ / # | \ / | \ / | \ / # | \ / | \ / | \ / # | \ / | \ / | \ / # reduced | x0+x1 | x2+x3 | x4+x5 ... # _elems | | | | | | # | | | | | | # | | | | | | # _scan(..) | | | | | | # +--|----+----------|-----+----------|-------+---- ... # | | | | # | | | | # +--|----+----------|-----+----------|-------+---- ... # | | | | | | # odd | x0+x1 | x0+...+x3 | x0+..+x5 ... # _elems | | \ | | \ | | # | | \ | | \ | | # even | | \ | | \ | | # _elems x0 | x0+...+x2 | x0+...+x4 | ... # | | | | | | # inter | | | | | | # leave(..) | | | | | | # x0 x0+x1 x0+...+x2 x0+...+x3 x0+...+x4 x0+...+x5 ... # TODO(b/150374456): if the sizes of all of the tensors can be determined # statically then we don't need a `level` parameter. def _scan(level, elems): """Perform scan on `elems`.""" elem_length = prefer_static.shape(elems[0])[0] # Apply `fn` to reduce adjacent pairs to a single entry. a = [elem[0:-1:2] for elem in elems] b = [elem[1::2] for elem in elems] reduced_elems = lowered_fn(a, b) def handle_base_case_elem_length_two(): return [ tf.concat([elem[0:1], reduced_elem], axis=0) for (reduced_elem, elem) in zip(reduced_elems, elems) ] def handle_base_case_elem_length_three(): reduced_reduced_elems = lowered_fn(reduced_elems, [elem[2:3] for elem in elems]) return [ tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem], axis=0) for (reduced_reduced_elem, reduced_elem, elem) in zip(reduced_reduced_elems, reduced_elems, elems) ] # Base case of recursion: assumes `elem_length` is 2 or 3. 
at_base_case = prefer_static.logical_or( prefer_static.equal(elem_length, 2), prefer_static.equal(elem_length, 3)) base_value = lambda: prefer_static.cond( # pylint: disable=g-long-lambda prefer_static.equal(elem_length, 2 ), handle_base_case_elem_length_two, handle_base_case_elem_length_three) if level <= 0: return base_value() def recursive_case(): """Evaluate the next step of the recursion.""" odd_elems = _scan(level - 1, reduced_elems) def even_length_case(): return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems], [elem[2::2] for elem in elems]) def odd_length_case(): return lowered_fn([odd_elem for odd_elem in odd_elems], [elem[2::2] for elem in elems]) results = prefer_static.cond( prefer_static.equal(elem_length % 2, 0), even_length_case, odd_length_case) # The first element of a scan is the same as the first element # of the original `elems`. even_elems = [ tf.concat([elem[0:1], result], axis=0) for (elem, result) in zip(elems, results) ] return list(map(_interleave, even_elems, odd_elems)) return prefer_static.cond(at_base_case, base_value, recursive_case) with tf.name_scope(name if name else 'scan_associative'): elem_length, assertions = _validate_elem_length( max_num_levels, elems_flat) with tf.control_dependencies(assertions if validate_args else []): return prefer_static.cond( elem_length < 2, lambda: elems, lambda: ( tf.nest.pack_sequence_as( # pylint: disable=g-long-lambda elems, _scan(max_num_levels - 1, elems_flat))))
def _scan(level, elems): """Perform scan on `elems`.""" elem_length = prefer_static.shape(elems[0])[0] # Apply `fn` to reduce adjacent pairs to a single entry. a = [elem[0:-1:2] for elem in elems] b = [elem[1::2] for elem in elems] reduced_elems = lowered_fn(a, b) def handle_base_case_elem_length_two(): return [ tf.concat([elem[0:1], reduced_elem], axis=0) for (reduced_elem, elem) in zip(reduced_elems, elems) ] def handle_base_case_elem_length_three(): reduced_reduced_elems = lowered_fn(reduced_elems, [elem[2:3] for elem in elems]) return [ tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem], axis=0) for (reduced_reduced_elem, reduced_elem, elem) in zip(reduced_reduced_elems, reduced_elems, elems) ] # Base case of recursion: assumes `elem_length` is 2 or 3. at_base_case = prefer_static.logical_or( prefer_static.equal(elem_length, 2), prefer_static.equal(elem_length, 3)) base_value = lambda: prefer_static.cond( # pylint: disable=g-long-lambda prefer_static.equal(elem_length, 2 ), handle_base_case_elem_length_two, handle_base_case_elem_length_three) if level <= 0: return base_value() def recursive_case(): """Evaluate the next step of the recursion.""" odd_elems = _scan(level - 1, reduced_elems) def even_length_case(): return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems], [elem[2::2] for elem in elems]) def odd_length_case(): return lowered_fn([odd_elem for odd_elem in odd_elems], [elem[2::2] for elem in elems]) results = prefer_static.cond( prefer_static.equal(elem_length % 2, 0), even_length_case, odd_length_case) # The first element of a scan is the same as the first element # of the original `elems`. even_elems = [ tf.concat([elem[0:1], result], axis=0) for (elem, result) in zip(elems, results) ] return list(map(_interleave, even_elems, odd_elems)) return prefer_static.cond(at_base_case, base_value, recursive_case)
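# A quick numerical check of the scan defined above against an ordinary
# cumulative sum, assuming the public export `tfp.math.scan_associative`
# wraps this implementation:
import operator
import tensorflow_probability as tfp

_x = tf.constant([1, 2, 3, 4, 5, 6])
_scanned = tfp.math.scan_associative(operator.add, _x)  # [1, 3, 6, 10, 15, 21]
_expected = tf.math.cumsum(_x)                          # Same values.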
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
  """Advances the particle filter by a single time step."""
  with tf.name_scope('filter_one_step'):
    seed = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[0]

    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed)
    log_weights = tf.nn.log_softmax(proposal_log_weights + log_weights, axis=-1)

    # If this step has an observation, compute its weights and marginal
    # likelihood (and otherwise, leave weights unchanged).
    observation_log_weights = prefer_static.cond(
        has_observation,
        lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
            _compute_observation_log_weights(
                step, proposed_particles, observation, observation_fn),
            prefer_static.shape(log_weights)),
        lambda: tf.zeros_like(log_weights))

    unnormalized_log_weights = log_weights + observation_log_weights
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=0)
    log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

    # Adaptive resampling: resample particles iff the specified criterion
    # is met.
    do_resample = resample_criterion_fn(unnormalized_log_weights)

    # Some batch elements may require resampling and others not, so
    # we first do the resampling for all elements, then select whether to use
    # the resampled values for each batch element according to
    # `do_resample`. If there were no batching, we might prefer to use
    # `tf.cond` to avoid the resampling computation on steps where it's not
    # needed---but we're ultimately interested in adaptive resampling
    # for statistical (not computational) purposes, so this isn't a
    # dealbreaker.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, resample_independent, seed=seed)

    uniform_weights = (prefer_static.zeros_like(log_weights) -
                       prefer_static.log(num_particles))
    (resampled_particles,
     resample_indices,
     log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
         (proposed_particles, _dummy_indices_like(resample_indices),
          log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
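# A common choice for `resample_criterion_fn` (a sketch; not necessarily the
# default used here): resample whenever the effective sample size
# 1 / sum(weights**2) drops below half the number of particles.
def _ess_below_half_sketch(unnormalized_log_weights):
  log_weights = tf.nn.log_softmax(unnormalized_log_weights, axis=0)
  log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
  num_particles = tf.cast(
      tf.shape(unnormalized_log_weights)[0], log_weights.dtype)
  return log_ess < tf.math.log(num_particles / 2.)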
def hager_zhang(value_and_gradients_function, initial_step_size=None, value_at_initial_step=None, value_at_zero=None, threshold_use_approximate_wolfe_condition=1e-6, shrinkage_param=0.66, expansion_param=5.0, sufficient_decrease_param=0.1, curvature_param=0.9, step_size_shrink_param=0.1, max_iterations=50, name=None): """The Hager Zhang line search algorithm. Performs an inexact line search based on the algorithm of [Hager and Zhang (2006)][2]. The univariate objective function `value_and_gradients_function` is typically generated by projecting a multivariate objective function along a search direction. Suppose the multivariate function to be minimized is `g(x1,x2, .. xn)`. Let (d1, d2, ..., dn) be the direction along which we wish to perform a line search. Then the projected univariate function to be used for line search is ```None f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a) ``` The directional derivative along (d1, d2, ..., dn) is needed for this procedure. This also corresponds to the derivative of the projected function `f(a)` with respect to `a`. Note that this derivative must be negative for `a = 0` if the direction is a descent direction. The usual stopping criteria for the line search is the satisfaction of the (weak) Wolfe conditions. For details of the Wolfe conditions, see ref. [3]. On a finite precision machine, the exact Wolfe conditions can be difficult to satisfy when one is very close to the minimum and as argued by [Hager and Zhang (2005)][1], one can only expect the minimum to be determined within square root of machine precision. To improve the situation, they propose to replace the Wolfe conditions with an approximate version depending on the derivative of the function which is applied only when one is very close to the minimum. The following algorithm implements this enhanced scheme. ### Usage: Primary use of line search methods is as an internal component of a class of optimization algorithms (called line search based methods as opposed to trust region methods). Hence, the end user will typically not want to access line search directly. In particular, inexact line search should not be confused with a univariate minimization method. The stopping criteria of line search is the satisfaction of Wolfe conditions and not the discovery of the minimum of the function. With this caveat in mind, the following example illustrates the standalone usage of the line search. ```python # Define value and gradient namedtuple ValueAndGradient = namedtuple('ValueAndGradient', ['x', 'f', 'df']) # Define a quadratic target with minimum at 1.3. def value_and_gradients_function(x): return ValueAndGradient(x=x, f=(x - 1.3) ** 2, df=2 * (x-1.3)) # Set initial step size. step_size = tf.constant(0.1) ls_result = tfp.optimizer.linesearch.hager_zhang( value_and_gradients_function, initial_step_size=step_size) # Evaluate the results. with tf.Session() as session: results = session.run(ls_result) # Ensure convergence. assert results.converged # If the line search converged, the left and the right ends of the # bracketing interval are identical. assert results.left.x == result.right.x # Print the number of evaluations and the final step size. print ("Final Step Size: %f, Evaluations: %d" % (results.left.x, results.func_evals)) ``` ### References: [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1, pp. 170-172. 2005. 
https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate gradient method with guaranteed descent. ACM Transactions on Mathematical Software, Vol 32., 1, pp. 113-137. 2006. http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series in Operations Research. pp 33-36. 2006 Args: value_and_gradients_function: A Python callable that accepts a real scalar tensor and returns a namedtuple with the fields 'x', 'f', and 'df' that correspond to scalar tensors of real dtype containing the point at which the function was evaluated, the value of the function, and its derivative at that point. The other namedtuple fields, if present, should be tensors or sequences (possibly nested) of tensors. In usual optimization application, this function would be generated by projecting the multivariate objective function along some specific direction. The direction is determined by some other procedure but should be a descent direction (i.e. the derivative of the projected univariate function must be negative at 0.). Alternatively, the function may represent the batching of `n` such line functions (e.g. projecting a single multivariate objective function along `n` distinct directions at once) accepting n points as input, i.e. a tensor of shape [n], and the fields 'x', 'f' and 'df' in the returned namedtuple should each be a tensor of shape [n], with the corresponding input points, function values, and derivatives at those input points. initial_step_size: (Optional) Scalar positive `Tensor` of real dtype, or a tensor of shape [n] in batching mode. The initial value (or values) to try to bracket the minimum. Default is `1.` as a float32. Note that this point need not necessarily bracket the minimum for the line search to work correctly but the supplied value must be greater than 0. A good initial value will make the search converge faster. value_at_initial_step: (Optional) The full return value of evaluating value_and_gradients_function at initial_step_size, i.e. a namedtuple with 'x', 'f', 'df', if already known by the caller. If supplied the value of `initial_step_size` will be ignored, otherwise the tuple will be computed by evaluating value_and_gradients_function. value_at_zero: (Optional) The full return value of value_and_gradients_function at `0.`, i.e. a namedtuple with 'x', 'f', 'df', if already known by the caller. If not supplied the tuple will be computed by evaluating value_and_gradients_function. threshold_use_approximate_wolfe_condition: Scalar positive `Tensor` of real dtype. Corresponds to the parameter 'epsilon' in [Hager and Zhang (2006)][2]. Used to estimate the threshold at which the line search switches to approximate Wolfe conditions. shrinkage_param: Scalar positive Tensor of real dtype. Must be less than `1.`. Corresponds to the parameter `gamma` in [Hager and Zhang (2006)][2]. If the secant**2 step does not shrink the bracketing interval by this proportion, a bisection step is performed to reduce the interval width. expansion_param: Scalar positive `Tensor` of real dtype. Must be greater than `1.`. Used to expand the initial interval in case it does not bracket a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2]. sufficient_decrease_param: Positive scalar `Tensor` of real dtype. Bounded above by the curvature param. Corresponds to `delta` in the terminology of [Hager and Zhang (2006)][2]. 
curvature_param: Positive scalar `Tensor` of real dtype. Bounded above by `1.`. Corresponds to 'sigma' in the terminology of [Hager and Zhang (2006)][2]. step_size_shrink_param: Positive scalar `Tensor` of real dtype. Bounded above by `1`. If the supplied step size is too big (i.e. either the objective value or the gradient at that point is infinite), this factor is used to shrink the step size until it is finite. max_iterations: Positive scalar `Tensor` of integral dtype or None. The maximum number of iterations to perform in the line search. The number of iterations used to bracket the minimum are also counted against this parameter. name: (Optional) Python str. The name prefixed to the ops created by this function. If not supplied, the default name 'hager_zhang' is used. Returns: results: A namedtuple containing the following attributes. converged: Boolean `Tensor` of shape [n]. Whether a point satisfying Wolfe/Approx wolfe was found. failed: Boolean `Tensor` of shape [n]. Whether line search failed e.g. if either the objective function or the gradient are not finite at an evaluation point. iterations: Scalar int32 `Tensor`. Number of line search iterations made. func_evals: Scalar int32 `Tensor`. Number of function evaluations made. left: A namedtuple, as returned by value_and_gradients_function, of the left end point of the final bracketing interval. Values are equal to those of `right` on batch members where converged is True. Otherwise, it corresponds to the last interval computed. right: A namedtuple, as returned by value_and_gradients_function, of the right end point of the final bracketing interval. Values are equal to those of `left` on batch members where converged is True. Otherwise, it corresponds to the last interval computed. """ with tf.compat.v1.name_scope(name, 'hager_zhang', [ initial_step_size, value_at_initial_step, value_at_zero, threshold_use_approximate_wolfe_condition, shrinkage_param, expansion_param, sufficient_decrease_param, curvature_param ]): val_0, val_initial, f_lim, prepare_evals = _prepare_args( value_and_gradients_function, initial_step_size, value_at_initial_step, value_at_zero, threshold_use_approximate_wolfe_condition) valid_inputs = (hzl.is_finite(val_0) & (val_0.df < 0) & tf.math.is_finite(val_initial.x) & (val_initial.x > 0)) # Note: _fix_step_size returns immediately if either all inputs are invalid # or none need fixing. fix_step_evals, val_c, fix_failed = _fix_step_size( value_and_gradients_function, val_initial, valid_inputs, step_size_shrink_param) failed = ~valid_inputs | fix_failed init_interval = HagerZhangLineSearchResult( converged=tf.zeros_like(failed), # i.e. all False. failed=failed, func_evals=prepare_evals + fix_step_evals, iterations=tf.convert_to_tensor(value=0), left=val_0, right=val_c) def _apply_bracket_and_search(): """Bracketing and searching to do for valid inputs.""" return _bracket_and_search(value_and_gradients_function, init_interval, f_lim, max_iterations, shrinkage_param, expansion_param, sufficient_decrease_param, curvature_param) return prefer_static.cond(tf.reduce_any(input_tensor=~failed), _apply_bracket_and_search, lambda: init_interval)
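As a complement to the docstring's example above, the sketch below shows one way to build the projected univariate `value_and_gradients_function` from a multivariate objective. It assumes TF2-style eager execution with `tf.GradientTape`; the helper name `restrict_along_direction` and the quadratic objective are invented for illustration.

```python
import collections
import tensorflow as tf
import tensorflow_probability as tfp

ValueAndGradient = collections.namedtuple('ValueAndGradient', ['x', 'f', 'df'])

def restrict_along_direction(objective_fn, position, direction):
  """Restricts `objective_fn` to the line `position + a * direction`."""
  def line_fn(a):
    point = position + a * direction
    with tf.GradientTape() as tape:
      tape.watch(point)
      value = objective_fn(point)
    # The directional derivative df/da is the gradient dotted with `direction`.
    df = tf.reduce_sum(tape.gradient(value, point) * direction, axis=-1)
    return ValueAndGradient(x=a, f=value, df=df)
  return line_fn

# A simple quadratic, searched along a descent direction from the origin.
objective = lambda x: tf.reduce_sum((x - 2.) ** 2, axis=-1)
position = tf.constant([0., 0.])
direction = tf.constant([1., 1.])  # Descent direction at `position`.
result = tfp.optimizer.linesearch.hager_zhang(
    restrict_along_direction(objective, position, direction),
    initial_step_size=tf.constant(1.))
```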
def posterior_marginals(self, observations, mask=None, name=None):
  """Compute marginal posterior distribution for each state.

  This function computes, for each time step, the marginal conditional
  probability that the hidden Markov model was in each possible state given
  the observations that were made at each time step. So if the hidden states
  are `z[0],...,z[num_steps - 1]` and the observations are
  `x[0], ..., x[num_steps - 1]`, then this function computes
  `P(z[i] | x[0], ..., x[num_steps - 1])` for all `i` from `0` to
  `num_steps - 1`.

  This operation is sometimes called smoothing. It uses a form of the
  forward-backward algorithm.

  Note: the behavior of this function is undefined if the `observations`
  argument represents impossible observations from the model.

  Args:
    observations: A tensor representing a batch of observations made on the
      hidden Markov model. The rightmost dimension of this tensor gives the
      steps in a sequence of observations from a single sample from the
      hidden Markov model. The size of this dimension should match the
      `num_steps` parameter of the hidden Markov model object. The other
      dimensions are the dimensions of the batch and these are broadcast with
      the hidden Markov model's parameters.
    mask: optional bool-type `Tensor` with rightmost dimension matching
      `num_steps` indicating which observations the result of this function
      should be conditioned on. When the mask has value `True` the
      corresponding observations aren't used. If `mask` is `None` then all of
      the observations are used. The `mask` dimensions left of the last are
      broadcast with the hmm batch as well as with the observations.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "posterior_marginals".

  Returns:
    posterior_marginal: A `Categorical` distribution object representing the
      marginal probability of the hidden Markov model being in each state at
      each step. The rightmost dimension of the `Categorical` distribution's
      batch will equal the `num_steps` parameter providing one marginal
      distribution for each step. The other dimensions are the dimensions
      corresponding to the batch of observations.

  Raises:
    ValueError: if rightmost dimension of `observations` does not have size
      `num_steps`.
""" with tf.name_scope(name or "posterior_marginals"): with tf.control_dependencies(self._runtime_assertions): observation_tensor_shape = tf.shape(observations) mask_tensor_shape = tf.shape( mask) if mask is not None else None with self._observation_mask_shape_preconditions( observation_tensor_shape, mask_tensor_shape): observation_log_probs = self._observation_log_probs( observations, mask) log_prob = self._log_init + observation_log_probs[0] log_transition = self._log_trans log_adjoint_prob = tf.zeros_like(log_prob) def _scan_multiple_steps_forwards(): def forward_step(log_previous_step, log_prob_observation): return _log_vector_matrix( log_previous_step, log_transition) + log_prob_observation forward_log_probs = tf.scan(forward_step, observation_log_probs[1:], initializer=log_prob, name="forward_log_probs") return tf.concat([[log_prob], forward_log_probs], axis=0) forward_log_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_forwards, lambda: tf.convert_to_tensor([log_prob])) total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1], axis=-1) def _scan_multiple_steps_backwards(): """Perform `scan` operation when `num_steps` > 1.""" def backward_step(log_previous_step, log_prob_observation): return _log_matrix_vector( log_transition, log_prob_observation + log_previous_step) backward_log_adjoint_probs = tf.scan( backward_step, observation_log_probs[1:], initializer=log_adjoint_prob, reverse=True, name="backward_log_adjoint_probs") return tf.concat( [backward_log_adjoint_probs, [log_adjoint_prob]], axis=0) backward_log_adjoint_probs = prefer_static.cond( self._num_steps > 1, _scan_multiple_steps_backwards, lambda: tf.convert_to_tensor([log_adjoint_prob])) log_likelihoods = forward_log_probs + backward_log_adjoint_probs marginal_log_probs = distribution_util.move_dimension( log_likelihoods - total_log_prob[..., tf.newaxis], 0, -2) return categorical.Categorical(logits=marginal_log_probs)
def _get_search_direction(state):
  """Computes the search direction to follow at the current state.

  On the `k`-th iteration of the main L-BFGS algorithm, the state has collected
  the most recent `m` correction pairs in position_deltas and gradient_deltas,
  where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`.

  Assuming these, the code below is an implementation of the L-BFGS two-loop
  recursion algorithm given by [Nocedal and Wright (2006)][1]:

  ```None
  q_direction = objective_gradient
  for i in reversed(range(m)):  # First loop.
    inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
    alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
    q_direction = q_direction - alpha[i] * gradient_deltas[i]

  kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
                            (gradient_deltas[-1]^T * gradient_deltas[-1]))
  r_direction = kth_inv_hessian_factor * I * q_direction

  for i in range(m):  # Second loop.
    beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
    r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)

  return -r_direction  # Approximates - H_k * objective_gradient.
  ```

  Args:
    state: A `LBfgsOptimizerResults` tuple with the current state of the
      search procedure.

  Returns:
    A real `Tensor` of the same shape as the `state.position`. The direction
    along which to perform line search.
  """
  # The number of correction pairs that have been collected so far.
  num_elements = tf.minimum(
      state.num_iterations,
      distribution_util.prefer_static_shape(state.position_deltas)[0])

  def _two_loop_algorithm():
    """L-BFGS two-loop algorithm."""
    # Correction pairs are always appended to the end, so only the latest
    # `num_elements` vectors have valid position/gradient deltas.
    position_deltas = state.position_deltas[-num_elements:]
    gradient_deltas = state.gradient_deltas[-num_elements:]

    # Pre-compute all `inv_rho[i]`s.
    inv_rhos = tf.reduce_sum(gradient_deltas * position_deltas, axis=-1)

    def first_loop(acc, args):
      _, q_direction = acc
      position_delta, gradient_delta, inv_rho = args
      alpha = tf.reduce_sum(position_delta * q_direction, axis=-1) / inv_rho
      direction_delta = alpha[..., tf.newaxis] * gradient_delta
      return (alpha, q_direction - direction_delta)

    # Run first loop body computing and collecting `alpha[i]`s, while also
    # computing the updated `q_direction` at each step.
    zero = tf.zeros_like(inv_rhos[0])
    alphas, q_directions = tf.scan(
        first_loop, [position_deltas, gradient_deltas, inv_rhos],
        initializer=(zero, state.objective_gradient), reverse=True)

    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
    gamma_k = inv_rhos[-1] / tf.reduce_sum(
        gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
    r_direction = gamma_k[..., tf.newaxis] * q_directions[0]

    def second_loop(r_direction, args):
      alpha, position_delta, gradient_delta, inv_rho = args
      beta = tf.reduce_sum(gradient_delta * r_direction, axis=-1) / inv_rho
      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
      return r_direction + direction_delta

    # Finally, run second loop body computing the updated `r_direction` at
    # each step.
    r_directions = tf.scan(
        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
        initializer=r_direction)
    return -r_directions[-1]

  return prefer_static.cond(tf.equal(num_elements, 0),
                            (lambda: -state.objective_gradient),
                            _two_loop_algorithm)
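The unbatched NumPy sketch below mirrors the two-loop pseudocode in the docstring above. It assumes at least one correction pair and omits the batching, masking, and `tf.scan` machinery the TensorFlow implementation handles.

```python
import numpy as np

def two_loop_search_direction(gradient, position_deltas, gradient_deltas):
  """Plain two-loop recursion; correction pairs are ordered oldest-first."""
  s = [np.asarray(v, dtype=float) for v in position_deltas]   # s_i
  y = [np.asarray(v, dtype=float) for v in gradient_deltas]   # y_i
  m = len(s)
  inv_rhos = [np.dot(y[i], s[i]) for i in range(m)]
  alphas = np.zeros(m)
  q = np.asarray(gradient, dtype=float).copy()
  for i in reversed(range(m)):  # First loop.
    alphas[i] = np.dot(s[i], q) / inv_rhos[i]
    q -= alphas[i] * y[i]
  # Initial inverse-Hessian estimate H^0_k = gamma_k * I.
  gamma_k = inv_rhos[-1] / np.dot(y[-1], y[-1])
  r = gamma_k * q
  for i in range(m):  # Second loop.
    beta = np.dot(y[i], r) / inv_rhos[i]
    r += (alphas[i] - beta) * s[i]
  return -r  # Approximates -H_k @ gradient.
```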
def fit_one_step( model_matrix, response, model, model_coefficients_start=None, predicted_linear_response_start=None, l2_regularizer=None, dispersion=None, offset=None, learning_rate=None, fast_unsafe_numerics=True, l2_regularization_penalty_factor=None, name=None): """Runs one step of Fisher scoring. Args: model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each row represents a sample's features. response: (Batch of) vector-shaped `Tensor` where each element represents a sample's observed response (to the corresponding row of features). Must have same `dtype` as `model_matrix`. model: `tfp.glm.ExponentialFamily`-like instance used to construct the negative log-likelihood loss, gradient, and expected Hessian (i.e., the Fisher information matrix). model_coefficients_start: Optional (batch of) vector-shaped `Tensor` representing the initial model coefficients, one for each column in `model_matrix`. Must have same `dtype` as `model_matrix`. Default value: Zeros. predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype` matching `response`; represents `offset` shifted initial linear predictions based on `model_coefficients_start`. Default value: `offset` if `model_coefficients is None`, and `tf.linalg.matvec(model_matrix, model_coefficients_start) + offset` otherwise. l2_regularizer: Optional scalar `Tensor` representing L2 regularization penalty, i.e., `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`. Default value: `None` (i.e., no L2 regularization). dispersion: Optional (batch of) `Tensor` representing `response` dispersion, i.e., as in, `p(y|theta) := exp((y theta - A(theta)) / dispersion)`. Must broadcast with rows of `model_matrix`. Default value: `None` (i.e., "no dispersion"). offset: Optional `Tensor` representing constant shift applied to `predicted_linear_response`. Must broadcast to `response`. Default value: `None` (i.e., `tf.zeros_like(response)`). learning_rate: Optional (batch of) scalar `Tensor` used to dampen iterative progress. Typically only needed if optimization diverges, should be no larger than `1` and typically very close to `1`. Default value: `None` (i.e., `1`). fast_unsafe_numerics: Optional Python `bool` indicating if solve should be based on Cholesky or QR decomposition. Default value: `True` (i.e., "prefer speed via Cholesky decomposition"). l2_regularization_penalty_factor: Optional (batch of) vector-shaped `Tensor`, representing a separate penalty factor to apply to each model coefficient, length equal to columns in `model_matrix`. Each penalty factor multiplies l2_regularizer to allow differential regularization. Can be 0 for some coefficients, which implies no regularization. Default is 1 for all coefficients. `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w * l2_regularization_penalty_factor||_2^2` name: Python `str` used as name prefix to ops created by this function. Default value: `"fit_one_step"`. Returns: model_coefficients: (Batch of) vector-shaped `Tensor`; represents the next estimate of the model coefficients, one for each column in `model_matrix`. predicted_linear_response: `response`-shaped `Tensor` representing linear predictions based on new `model_coefficients`, i.e., `tf.linalg.matvec(model_matrix, model_coefficients_next) + offset`. 
""" with tf.name_scope(name or 'fit_one_step'): [ model_matrix, response, model_coefficients_start, predicted_linear_response_start, offset, ] = prepare_args( model_matrix, response, model_coefficients_start, predicted_linear_response_start, offset) # Compute: mean, grad(mean, predicted_linear_response_start), and variance. mean, variance, grad_mean = model(predicted_linear_response_start) # If either `grad_mean` or `variance is non-finite or zero, then we'll # replace it with a value such that the row is zeroed out. Although this # procedure may seem circuitous, it is necessary to ensure this algorithm is # itself differentiable. is_valid = ( tf.math.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.) & tf.math.is_finite(variance) & (variance > 0.)) def mask_if_invalid(x, mask): return tf.where( is_valid, x, np.array(mask, dtype_util.as_numpy_dtype(x.dtype))) # Run one step of iteratively reweighted least-squares. # Compute "`z`", the adjusted predicted linear response. # z = predicted_linear_response_start # + learning_rate * (response - mean) / grad_mean z = (response - mean) / mask_if_invalid(grad_mean, 1.) # TODO(jvdillon): Rather than use learning rate, we should consider using # backtracking line search. if learning_rate is not None: z *= learning_rate[..., tf.newaxis] z += predicted_linear_response_start if offset is not None: z -= offset # Compute "`w`", the per-sample weight. if dispersion is not None: # For convenience, we'll now scale the variance by the dispersion factor. variance *= dispersion w = ( mask_if_invalid(grad_mean, 0.) * tf.math.rsqrt(mask_if_invalid(variance, np.inf))) a = model_matrix * w[..., tf.newaxis] b = z * w # Solve `min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }` # where `@` denotes `matmul`. if l2_regularizer is None: l2_regularizer = np.array(0, dtype_util.as_numpy_dtype(a.dtype)) else: l2_regularizer_ = distribution_util.maybe_get_static_value( l2_regularizer, dtype_util.as_numpy_dtype(a.dtype)) if l2_regularizer_ is not None: l2_regularizer = l2_regularizer_ def _embed_l2_regularization(): """Adds synthetic observations to implement L2 regularization.""" # `tf.matrix_solve_ls` does not respect the `l2_regularization` argument # when `fast_unsafe_numerics` is `False`. This function adds synthetic # observations to the data to implement the regularization instead. # Adding observations `sqrt(l2_regularizer) * I` is mathematically # equivalent to adding the term # `-l2_regularizer ||coefficients||_2**2` to the log-likelihood. num_model_coefficients = num_cols(model_matrix) batch_shape = tf.shape(model_matrix)[:-2] if l2_regularization_penalty_factor is None: eye = tf.eye( num_model_coefficients, batch_shape=batch_shape, dtype=a.dtype) else: eye = tf.linalg.tensor_diag( tf.cast(l2_regularization_penalty_factor, dtype=a.dtype)) broadcasted_shape = prefer_static.concat( [batch_shape, [num_model_coefficients, num_model_coefficients]], axis=0) eye = tf.broadcast_to(eye, broadcasted_shape) a_ = tf.concat([a, tf.sqrt(l2_regularizer) * eye], axis=-2) b_ = distribution_util.pad( b, count=num_model_coefficients, axis=-1, back=True) # Return l2_regularizer=0 since its now embedded. l2_regularizer_ = np.array(0, dtype_util.as_numpy_dtype(a.dtype)) return a_, b_, l2_regularizer_ a, b, l2_regularizer = prefer_static.cond( prefer_static.reduce_all([ prefer_static.logical_or( not(fast_unsafe_numerics), l2_regularization_penalty_factor is not None), l2_regularizer > 0. 
]), _embed_l2_regularization, lambda: (a, b, l2_regularizer)) model_coefficients_next = tf.linalg.lstsq( a, b[..., tf.newaxis], fast=fast_unsafe_numerics, l2_regularizer=l2_regularizer, name='model_coefficients_next') model_coefficients_next = model_coefficients_next[..., 0] # TODO(b/79122261): The approach used in `matrix_solve_ls` could be made # faster by avoiding explicitly forming Q and instead keeping the # factorization in 'implicit' form with stacked (rescaled) Householder # vectors underneath the 'R' and then applying the (accumulated) # reflectors in the appropriate order to apply Q'. However, we don't # presently do this because we lack core TF functionality. For reference, # the vanilla QR approach is: # q, r = tf.linalg.qr(a) # c = tf.matmul(q, b, adjoint_a=True) # model_coefficients_next = tf.matrix_triangular_solve( # r, c, lower=False, name='model_coefficients_next') predicted_linear_response_next = compute_predicted_linear_response( model_matrix, model_coefficients_next, offset, name='predicted_linear_response_next') return model_coefficients_next, predicted_linear_response_next
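A minimal end-to-end usage sketch follows, assuming eager execution and a synthetic Bernoulli (logistic-regression) problem; the data values and coefficients are made up for illustration.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Synthetic design matrix and 0/1 responses for a logistic regression.
rng = np.random.RandomState(seed=42)
x = tf.constant(rng.randn(100, 3), dtype=tf.float32)
true_coefficients = tf.constant([1., -2., 0.5])
probs = tf.math.sigmoid(tf.linalg.matvec(x, true_coefficients))
y = tf.cast(rng.rand(100) < probs.numpy(), tf.float32)

# One Fisher-scoring (IRLS) update, starting from all-zero coefficients.
coefficients, linear_response = tfp.glm.fit_one_step(
    model_matrix=x, response=y, model=tfp.glm.Bernoulli())
# Iterating this update (or calling tfp.glm.fit) drives the coefficients
# toward the maximum-likelihood estimate.
```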
def test_missing_arg2(self): x = tf.constant(1) with self.assertRaises(TypeError): ps.cond(True, lambda: x)
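For contrast, here is a hypothetical companion test sketching the intended two-branch usage, where a statically known Python `True`/`False` predicate simply selects and calls the corresponding branch.

```python
def test_static_pred_selects_branch(self):
  x = tf.constant(2)
  y = tf.constant(5)
  # With both branches supplied, a static predicate picks one eagerly.
  self.assertAllEqual(self.evaluate(ps.cond(True, lambda: x, lambda: y)), 2)
  self.assertAllEqual(self.evaluate(ps.cond(False, lambda: x, lambda: y)), 5)
```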