def model():
  # Shared birthplace
  x_start = yield Root(
      tfd.Categorical(probs=tf.ones(n) / n, dtype=tf.int32))
  y_start = yield tfd.Categorical(probs=tf.ones(n) / n, dtype=tf.int32)
  x = m * [x_start]
  y = m * [y_start]
  for t in range(n_steps):
    for i in range(m):
      # Construct PDF for next step in walk
      # Start with PDF for all mass on current point.
      ox = tf.one_hot(x[i], n)
      oy = tf.one_hot(y[i], n)
      o = ox[..., :, None] * oy[..., None, :]
      # Deliberate choice of non-centered distribution as
      # reduced symmetry lends itself to better testing.
      p = (0.1 * tf.roll(o, shift=[0, -1], axis=[-2, -1]) +
           0.2 * tf.roll(o, shift=[0, 1], axis=[-2, -1]) +
           0.3 * tf.roll(o, shift=[-1, 0], axis=[-2, -1]) +
           0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1]))
      # Reshape just last two dimensions.
      p = tf.reshape(p, _cat(p.shape[:-2], [-1]))
      xy = yield tfd.Categorical(probs=p, dtype=tf.int32)
      x[i] = xy // n
      y[i] = xy % n
  # 2 * m noisy 2D observations at end
  for i in range(m):
    yield tfd.Normal(tf.cast(x[i], dtype=tf.float32), scale=2.0)
    yield tfd.Normal(tf.cast(y[i], dtype=tf.float32), scale=2.0)

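# A minimal usage sketch (not part of the original source): `model` is written
# as a coroutine for `tfd.JointDistributionCoroutine`, with `Root` marking the
# parentless start-point distributions. The constants and the `_cat` shape
# helper below are illustrative stand-ins for whatever the enclosing test
# defines; they only need to be visible from `model`'s scope when sampling.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root

n, m, n_steps = 8, 2, 3  # Grid size, number of walkers, number of steps.
_cat = lambda a, b: list(a) + list(b)  # Concatenate shape fragments.

walk = tfd.JointDistributionCoroutine(model)
draw = walk.sample(seed=42)  # One joint draw: starts, steps, noisy end obs.
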
def overlap_add_conv1d(inputs: tf.Tensor,
                       filters: tf.Tensor,
                       fft_length: int = 4096) -> tf.Tensor:
  """FFT based convolution in 1D, using the overlap-add method.

  Args:
    inputs: a tf.Tensor<float>[batch_size, seq_length, 1] of input sequences.
    filters: a tf.Tensor<float>[filter_length, 1, channels].
    fft_length: an int, the length of the Fourier transform.

  Returns:
    A tf.Tensor<float>[batch_size, seq_length, channels] containing the
    response to the 1D convolutions with the filters.
  """
  seq_len = tf.shape(inputs)[1]
  filter_len = tf.shape(filters)[0]
  overlap = filter_len - 1
  seg_size = fft_length - overlap
  f_filters = tf.expand_dims(
      tf.signal.rfft(tf.transpose(filters, (1, 2, 0)), fft_length=[fft_length]),
      axis=2)
  framed = tf.signal.frame(
      tf.transpose(inputs, (0, 2, 1)),
      frame_length=seg_size,
      frame_step=seg_size,
      pad_end=True)
  paddings = [[0, 0], [0, 0], [0, 0], [int(overlap / 2), int(overlap / 2)]]
  framed = tf.pad(framed, paddings)
  f_inputs = tf.signal.rfft(framed, fft_length=[fft_length])
  result = tf.signal.irfft(f_inputs * tf.math.conj(f_filters))
  result = tf.roll(result, filter_len // 2 - 1, axis=-1)
  result = tf.signal.overlap_and_add(result, frame_step=seg_size)
  result = tf.transpose(result, (0, 2, 1))[..., :seq_len, :]
  output = tf.roll(result, 1 - filter_len // 2, axis=1)
  shape = tf.concat([tf.shape(inputs)[:-1], tf.shape(filters)[-1:]], axis=0)
  return tf.reshape(output, shape)

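# A minimal shape-check sketch (not part of the original source), run eagerly:
# filters of length 255 with 4 output channels applied to a batch of two
# 8000-sample signals. The overlap-add segmentation should preserve the
# sequence length; the exact values depend on the random inputs.
import tensorflow as tf

signals = tf.random.normal([2, 8000, 1])
kernels = tf.random.normal([255, 1, 4])
response = overlap_add_conv1d(signals, kernels, fft_length=1024)
print(tf.shape(response).numpy())  # -> [2 8000 4]
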
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a).data

  if axis is not None:
    return utils.tensor_to_ndarray(tf.roll(a, shift, axis))

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = tf.shape(a)
  a = tf.roll(tf.reshape(a, [-1]), shift, 0)
  return utils.tensor_to_ndarray(tf.reshape(a, original_shape))

def fft_conv1d(inputs: tf.Tensor,
               filters: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """FFT based convolution in 1D.

  We round the input length to the closest upper power of 2 before FFT.

  Args:
    inputs: a tf.Tensor<float>[batch_size, seq_length, 1] of input sequences.
    filters: a tf.Tensor<float>[filter_length, 1, channels].

  Returns:
    A tf.Tensor<float>[batch_size, seq_length, channels] containing the
    response to the 1D convolutions with the filters, and a
    tf.Tensor<float>[1,] with the L1 norm of the filters FFT.
  """
  seq_length = tf.shape(inputs)[1]
  fft_length = upper_power_of_2(seq_length)
  filter_length = tf.shape(filters)[0]
  f_inputs = tf.signal.rfft(
      tf.transpose(inputs, (0, 2, 1)), fft_length=[fft_length])
  f_filters = tf.signal.rfft(
      tf.transpose(filters, (1, 2, 0)), fft_length=[fft_length])
  f_filters_l1 = tf.reduce_sum(tf.math.abs(f_filters))
  result = tf.transpose(
      tf.signal.irfft(f_inputs * tf.math.conj(f_filters)), (0, 2, 1))
  output = tf.roll(result, filter_length // 2 - 1, axis=1)
  output = output[:, :seq_length, :]
  shape = tf.concat([tf.shape(inputs)[:-1], tf.shape(filters)[-1:]], axis=0)
  return tf.reshape(output, shape), f_filters_l1

def nearby_difference(x):
  """Computes the L2 norm of differences between adjacent entries in a batch."""
  # This is a very rough measure of diversity.
  with tf.device('cpu'):
    x1 = tf.reshape(x, shape=[int(x.shape[0]), -1])
    x2 = tf.roll(x1, shift=1, axis=0)
    return tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x1, x2))))

def distribution_fn(sample):
  num_frames = sample.shape[-1]
  mask = tf.one_hot(0, num_frames)[:, tf.newaxis]
  probs = tf.roll(tf.one_hot(sample, 3), shift=1, axis=-2)
  probs = probs * (1.0 - mask) + tf.convert_to_tensor([0.5, 0.5, 0]) * mask
  return tfd.Independent(tfd.Categorical(probs=probs),
                         reinterpreted_batch_ndims=1)

def _sample_bates(total_count, low, high, n, seed=None):
  """Vectorized production of `Bates` samples.

  Args:
    total_count: (Batches of) counts of `Uniform`s to take means of. Should
      have integer dtype and already be broadcasted to the batch shape.
    low: (Batches of) lower bounds of the `Uniform` variables to sample. Should
      be the same floating dtype as `high` and broadcastable to the batch
      shape.
    high: (Batches of) upper bounds of the `Uniform` variables to sample.
      Should be the same floating dtype as `low` and broadcastable to the batch
      shape.
    n: `int32` number of samples to generate.
    seed: Random seed to pass to `Uniform` sampler.

  Returns:
    samples: Samples of (batches of) the `Bates` variable. Will have same dtype
      as `low` and `high`. If the batch shape is `[B1,..., Bn]`, `samples` has
      shape `[n, B1,..., Bn]`.
  """
  # 1. Sample Uniform(0, 1)s, flattening the batch dimension into axis 0.
  uniform_sample_shape = tf.concat([[tf.reduce_sum(total_count)], [n]], axis=0)
  uniform_samples = samplers.uniform(
      uniform_sample_shape, minval=0., maxval=1., dtype=low.dtype, seed=seed)
  # 2. Produce segment means.
  segment_lengths = tf.reshape(total_count, [-1])
  segment_ids = tf.repeat(tf.range(tf.size(segment_lengths)), segment_lengths)
  flatmeans = tf.math.segment_mean(uniform_samples, segment_ids)
  # 3. Reshape and transpose segment means back to the original shape.
  outshape = tf.concat([tf.shape(total_count), [n]], axis=0)
  tmeans = tf.reshape(flatmeans, outshape)
  axes = tf.range(tf.rank(tmeans))
  means = tf.transpose(tmeans, tf.roll(axes, shift=1, axis=0))
  # 4. Shift/scale from (0, 1) to (low, high).
  return low + (high - low) * means

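# For reference (not part of the original source): a related public entry
# point, assuming TFP's `Bates` distribution is available, produces the same
# kind of batched samples with the sample dimension leading the batch shape.
import tensorflow_probability as tfp

bates = tfp.distributions.Bates(total_count=[2, 5, 10], low=0., high=1.)
draws = bates.sample(1000, seed=7)  # Shape [1000, 3].
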
def one_hot_minus(inputs, shift):
  """Performs (inputs - shift) % vocab_size in the one-hot space.

  Args:
    inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor.
    shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor specifying how much to shift the corresponding one-hot vector in
      inputs. Soft values perform a "weighted shift": for example,
      shift=[0.2, 0.3, 0.5] performs a linear combination of 0.2 * shifting by
      zero; 0.3 * shifting by one; and 0.5 * shifting by two.

  Returns:
    Tensor of same shape and dtype as inputs.
  """
  # TODO(trandustin): Implement with circular conv1d.
  inputs = tf.convert_to_tensor(inputs)
  shift = tf.cast(shift, inputs.dtype)
  vocab_size = inputs.shape[-1]
  if isinstance(vocab_size, tf1.Dimension):
    vocab_size = vocab_size.value
  # Form a [..., vocab_size, vocab_size] matrix. Each batch element of
  # inputs will vector-matrix multiply the vocab_size x vocab_size matrix. This
  # "shifts" the inputs batch element by the corresponding shift batch element.
  shift_matrix = tf.stack(
      [tf.roll(shift, i, axis=-1) for i in range(vocab_size)], axis=-2)
  outputs = tf.einsum('...v,...uv->...u', inputs, shift_matrix)
  return outputs

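# A small usage sketch (not from the original source), assuming the snippet's
# own module imports (`tf`, `tf1`) are in scope: with hard one-hot inputs,
# `one_hot_minus` reduces to modular subtraction of the encoded indices,
# e.g. (3 - 1) % 4 == 2.
inputs = tf.one_hot(3, depth=4)  # Encodes 3.
shift = tf.one_hot(1, depth=4)   # Shift by 1.
print(one_hot_minus(inputs, shift))  # One-hot encoding of 2: [0., 0., 1., 0.]
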
def my_fn(ex):
  # Copy once, outside the loop, so the roll applied to "inputs" is not
  # discarded when "targets" is processed.
  res = ex.copy()
  for feat in ["inputs", "targets"]:
    tokens = ex[feat]
    n_tokens = tf.size(tokens)
    random_shift = tf.random.uniform([], maxval=n_tokens, dtype=tf.int32)
    res[feat] = tf.roll(tokens, shift=random_shift, axis=0)
  return res

def grad(dy):
  dshp = tf.shape(x) - tf.shape(dy)
  z = tf.zeros(tf.where(tf.equal(0, dshp), tf.shape(x), dshp), dtype=x.dtype)
  dx = tf.roll(tf.concat([dy, z], axis=axis), shift=begin[axis], axis=axis)
  return dx

def random_token_preprocessor(ex, seed):
  """Selects a random shift to roll the tokens by for each feature."""
  # Copy once, outside the loop, so the roll applied to "inputs" is not
  # discarded when "targets" is processed.
  res = ex.copy()
  for feat in ["inputs", "targets"]:
    tokens = ex[feat]
    n_tokens = tf.size(tokens)
    random_shift = tf.random.stateless_uniform(
        [], maxval=n_tokens, dtype=tf.int32, seed=seed)
    res[feat] = tf.roll(tokens, shift=random_shift, axis=0)
  return res

def prep_fn(image, label):
  """Image preprocessing function."""
  if config.roll_pixels:
    image = tf.roll(image, config.roll_pixels, -2)
  if is_training:
    image = tf.image.random_flip_left_right(image)
    image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    image = tf.image.random_crop(image, CIFAR_SHAPE)
  image = tf.image.convert_image_dtype(image, tf.float32)
  return image, label

def _week_day_mappers(weekend_mask):
  """Creates functions to map from ordinals to week days and inverse.

  Creates functions to map from ordinal space (i.e. days since 31 Dec 0) to
  week days. The function assigns the value of 0 to the first non weekend day
  in the week starting on Sunday, 31 Dec 1 through to Saturday, 6 Jan 1 and the
  value assigned to each successive work day is incremented by 1. For a day
  that is not a week day, this count is not incremented from the previous week
  day (hence, multiple ordinal days may have the same week day value).

  Args:
    weekend_mask: A bool `Tensor` of length 7 or None. The weekend mask.

  Returns:
    A tuple of callables.
      `forward`: Takes one `Tensor` argument containing ordinals and returns a
        tuple of two `Tensor`s of the same shape as the input. The first
        `Tensor` is of type `int32` and contains the week day value. The second
        is a bool `Tensor` indicating whether the supplied ordinal was a
        weekend day (i.e. True where the day is a weekend day and False
        otherwise).
      `backward`: Takes one int32 `Tensor` argument containing week day values
        and returns an int32 `Tensor` containing ordinals for those week days.
  """
  if weekend_mask is None:
    default_forward = lambda x: (x, tf.zeros_like(x, dtype=tf.bool))
    identity = lambda x: x
    return default_forward, identity
  weekend_mask = tf.convert_to_tensor(weekend_mask, dtype=tf.bool)
  weekend_mask = tf.roll(weekend_mask, -_DAYOFWEEK_0, axis=0)
  weekday_mask = tf.logical_not(weekend_mask)
  weekday_offsets = tf.cumsum(tf.cast(weekday_mask, dtype=tf.int32))
  num_workdays = weekday_offsets[-1]
  weekday_offsets -= 1  # Adjust the first workday to index 0.
  ordinal_offsets = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32)
  ordinal_offsets = ordinal_offsets[weekday_mask]

  def forward(ordinals):
    """Adjusts the ordinals by removing the number of weekend days so far."""
    mod, remainder = ordinals // 7, ordinals % 7
    weekday_values = mod * num_workdays + tf.gather(weekday_offsets, remainder)
    is_weekday = tf.gather(weekday_mask, remainder)
    return weekday_values, is_weekday

  def backward(weekday_values):
    """Converts from weekend adjusted values to ordinals."""
    return ((weekday_values // num_workdays) * 7 +
            tf.gather(ordinal_offsets, weekday_values % num_workdays))

  return forward, backward

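# A hedged usage sketch (not from the original source): with a Saturday/Sunday
# weekend, `forward` maps ordinals to business-day counts plus a weekend flag,
# and `backward` maps business-day counts back to ordinals. The Monday-first
# ordering of the mask and the module-level `_DAYOFWEEK_0` constant are
# assumptions about the surrounding library.
import tensorflow as tf

sat_sun_weekend = tf.constant([False] * 5 + [True] * 2)  # Mon..Fri work days.
forward, backward = _week_day_mappers(sat_sun_weekend)
weekday_values, is_weekday = forward(tf.constant([730120, 730121, 730122]))
roundtrip = backward(weekday_values)  # Ordinals of the matching work days.
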
def _shift_right_by_one(tensor: tf.Tensor, axis: int = -1) -> tf.Tensor:
  """Shift the 1d input tensor to the right by one position without wrapping."""
  if not tensor.dtype.is_integer:
    raise ValueError("Only integer types are supported.")
  # tf.roll wraps around the axis.
  rolled = tf.roll(tensor, shift=1, axis=axis)
  # Zero out the first position by multiplying with [0, 1, 1, ..., 1].
  reverse_onehot = tf.one_hot(
      0, depth=tensor.shape[axis], on_value=0, off_value=1, dtype=tensor.dtype)
  return rolled * reverse_onehot

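# A quick usage sketch (not from the original source): unlike a bare tf.roll,
# the element that would wrap around is zeroed out instead of reappearing at
# the front.
import tensorflow as tf

tokens = tf.constant([7, 8, 9, 10])
print(_shift_right_by_one(tokens))       # [0, 7, 8, 9]
print(tf.roll(tokens, shift=1, axis=0))  # [10, 7, 8, 9] -- wraps around.
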
def _do_update(x_update_diff_norm_sq, x_update, hess_matmul_x_update):  # pylint: disable=missing-docstring
  hessian_column_with_l2 = sparse_or_dense_matvecmul(
      hessian_unregularized_loss_outer,
      hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot(
          hessian_unregularized_loss_outer, coord),
      adjoint_a=True)

  if l2_regularizer is not None:
    hessian_column_with_l2 += _one_hot_like(
        hessian_column_with_l2, coord, on_value=2. * l2_regularizer)

  # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
  # order to conform to `hess_matmul_x_update`.
  n = tf.rank(hessian_column_with_l2)
  perm = tf.roll(tf.range(n), shift=1, axis=0)
  hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2, perm=perm)

  # Update the entire batch at `coord` even if `delta` may be 0 at some
  # batch coordinates. In those cases, adding `delta` is a no-op.
  x_update = tf.tensor_scatter_nd_add(x_update, [[coord]], [delta])

  with tf.control_dependencies([x_update]):
    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
    hess_matmul_x_update_ = (hess_matmul_x_update +
                             delta * hessian_column_with_l2)

    # Hint that loop vars retain the same shape.
    x_update_diff_norm_sq_.set_shape(
        x_update_diff_norm_sq_.shape.merge_with(x_update_diff_norm_sq.shape))
    hess_matmul_x_update_.set_shape(
        hess_matmul_x_update_.shape.merge_with(hess_matmul_x_update.shape))

  return [x_update_diff_norm_sq_, x_update, hess_matmul_x_update_]

def crps_score(labels=None, predictive_samples=None):
  r"""Computes the Continuous Ranked Probability Score (CRPS).

  The Continuous Ranked Probability Score is a [proper scoring rule][1] for
  assessing the probabilistic predictions of a model against a realized value.
  The CRPS is

  \\(\textrm{CRPS}(F,y) = \int_{-\infty}^{\infty} (F(z) - 1_{z \geq y})^2 dz.\\)

  Here \\(F\\) is the cumulative distribution function of the model predictive
  distribution and \\(y\\) is the realized ground truth value.

  The CRPS can be used as a loss function for training an implicit model for
  probabilistic regression. It can also be used to assess the predictive
  performance of a probabilistic regression model.

  In this implementation we use an equivalent representation of the CRPS,

  \\(\textrm{CRPS}(F,y) = E_{z~F}[|z-y|] - (1/2) E_{z,z'~F}[|z-z'|].\\)

  This equivalent representation has an unbiased sample estimate and our
  implementation of the CRPS has a complexity of O(n m).

  #### References
  [1]: Tilmann Gneiting, Adrian E. Raftery.
       Strictly Proper Scoring Rules, Prediction, and Estimation.
       Journal of the American Statistical Association, Vol. 102, 2007.
       https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

  Args:
    labels: Tensor, (n,), with tf.float32 or tf.float64 real-valued targets.
    predictive_samples: Tensor, (n,m), with tf.float32 or tf.float64 values.
      Each row at [i,:] contains m samples of the model predictive
      distribution, p(y|x_i).

  Returns:
    crps: (n,) Tensor, the CRPS score for each instance; a lower score
      indicates a better fit.
  """
  if labels is None:
    raise ValueError("target labels must be provided")
  if labels.shape.ndims != 1:
    raise ValueError("target labels must be of rank 1")

  if predictive_samples is None:
    raise ValueError("predictive samples must be provided")
  if predictive_samples.shape.ndims != 2:
    raise ValueError("predictive samples must be of rank 2")
  if predictive_samples.shape[0] != labels.shape[0]:
    raise ValueError("first dimension of predictive samples shape "
                     "must match target labels shape")

  pairwise_diff = tf.roll(predictive_samples, 1, axis=1) - predictive_samples
  predictive_diff = tf.abs(pairwise_diff)
  estimated_dist_pairwise = tf.reduce_mean(input_tensor=predictive_diff, axis=1)

  labels = tf.expand_dims(labels, 1)
  dist_realization = tf.reduce_mean(tf.abs(predictive_samples - labels), axis=1)

  crps = dist_realization - 0.5 * estimated_dist_pairwise
  return crps

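# A minimal usage sketch (not from the original source): score two test points
# against 100 Monte Carlo draws from each predictive distribution. Lower is
# better; the exact values depend on the random draws.
import tensorflow as tf

labels = tf.constant([0.2, -1.5])
predictive_samples = tf.random.normal([2, 100])  # 100 draws per test point.
per_example_crps = crps_score(labels=labels,
                              predictive_samples=predictive_samples)
print(per_example_crps.shape)  # (2,)
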
def rollaxis(x, shift):
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))

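# A small usage sketch (not from the original source): rolling the permutation
# [0, 1, 2] by 1 gives [2, 0, 1], so the last axis moves to the front. This is
# the same permutation-rolling trick used in several snippets above to move
# sample or batch dimensions around.
import tensorflow as tf

x = tf.zeros([2, 3, 4])
print(rollaxis(x, shift=1).shape)   # (4, 2, 3)
print(rollaxis(x, shift=-1).shape)  # (3, 4, 2)
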
def _parse_fn(record):
  """Parses a record into a feature_dict."""
  feature_values = tf.io.parse_single_example(
      serialized=record,
      features={
          'i/o':
              tf.io.FixedLenFeature([], tf.string, default_value=''),
          'program_encoding':
              tf.io.FixedLenFeature([], tf.string, default_value=''),
      })

  ios = tf.strings.split(
      tf.strings.split(feature_values['i/o'], sep='>'), sep='<')
  inputs, outputs = ios.merge_dims(0, 1)[::2], ios.merge_dims(0, 1)[1::2]

  # Parse inputs into tokens.
  inputs = tf.strings.unicode_split(inputs, 'UTF-8').to_tensor()
  inputs = spec_vocab_table.lookup(inputs)  # Map characters to tokens.

  # Parse outputs into tokens.
  outputs_with_separators = (
      tf.strings.unicode_split(outputs, 'UTF-8').to_tensor())
  outputs_with_separators = spec_vocab_table.lookup(outputs_with_separators)
  split_outputs = tf.strings.unicode_split(
      tf.strings.split(outputs, sep='|'), 'UTF-8')
  outputs = split_outputs.merge_dims(1, 2).to_tensor()
  outputs = spec_vocab_table.lookup(outputs)

  # Compute indices for the start of each part of the spec, w.r.t. the
  # original spec.
  separator_indices = tf.where(
      tf.equal(outputs_with_separators, separator_id))[:, 1]
  separator_indices = tf.reshape(
      separator_indices, (tf.shape(outputs_with_separators)[0], -1))
  start_indices = separator_indices - tf.expand_dims(
      tf.range(tf.shape(separator_indices)[1], dtype=tf.int64), 0)
  start_indices = tf.concat(
      (tf.zeros((tf.shape(start_indices)[0], 1), dtype=tf.int64),
       start_indices),
      axis=1)

  num_examples = tf.shape(start_indices)[0]
  num_parts = tf.shape(start_indices)[1]

  # Construct the shifted spec suffixes.
  flat_start_indices = tf.reshape(start_indices, (-1,))
  prefix_mask = (
      1 - tf.sequence_mask(
          flat_start_indices, maxlen=tf.shape(outputs)[-1], dtype=tf.int64))
  masked_outputs = tf.repeat(outputs, num_parts, axis=0) * prefix_mask
  output_suffixes = tf.vectorized_map(
      fn=lambda x: tf.roll(x[0], x[1], axis=0),
      elems=(masked_outputs, -flat_start_indices))

  # Compute indices for the start/end of spec parts, w.r.t. the shifted spec
  # suffixes.
  ground_truth_start_indices = tf.zeros(
      (num_examples * num_parts,), dtype=tf.int64)
  cumulative_end_indices = tf.concat(
      (start_indices, tf.math.count_nonzero(outputs, axis=-1, keepdims=True)),
      axis=1)
  ground_truth_end_indices = tf.reshape(
      cumulative_end_indices[:, 1:] - cumulative_end_indices[:, :-1], (-1,))

  # Construct the actual spec parts to predict.
  range_indices = tf.expand_dims(
      tf.range(tf.shape(output_suffixes)[-1], dtype=tf.int64), axis=0)
  part_mask = tf.where(
      tf.logical_and(
          range_indices >= tf.expand_dims(ground_truth_start_indices, axis=1),
          range_indices < tf.expand_dims(ground_truth_end_indices, axis=1)),
      1, 0)
  output_parts = output_suffixes * tf.cast(part_mask, tf.int64)
  output_parts = tf.pad(output_parts, [[0, 0], [0, 1]])  # Make room for sep.
  # TODO(kshi): roll output_parts leftward by start_indices for SCAN.
  first_zero_index = tf.math.count_nonzero(output_parts, axis=-1)
  output_parts += tf.one_hot(
      first_zero_index,
      depth=tf.shape(output_parts)[-1],
      dtype=tf.int64) * separator_id

  # Reshape everything so that different spec suffixes become different
  # dataset elements.
  output_suffixes_reshaped = tf.transpose(
      tf.reshape(output_suffixes, (num_examples, num_parts, -1)), (1, 0, 2))
  output_parts_reshaped = tf.transpose(
      tf.reshape(output_parts, (num_examples, num_parts, -1)), (1, 0, 2))
  inputs_reshaped = tf.reshape(
      tf.tile(inputs, (num_parts, 1)), (num_parts, num_examples, -1))
  ground_truth_start_indices_reshaped = tf.transpose(
      tf.reshape(ground_truth_start_indices, (num_examples, num_parts)))
  ground_truth_end_indices_reshaped = tf.transpose(
      tf.reshape(ground_truth_end_indices, (num_examples, num_parts)))

  # Combine spec parts from all examples into one sequence with separator
  # tokens between examples and ending in EOS.
  shifts = tf.cumsum(
      tf.concat(
          (tf.zeros((num_parts, 1), dtype=tf.int64),
           ground_truth_end_indices_reshaped[:, :-1] + 1),
          1),
      axis=-1)
  flat_shifts = tf.reshape(shifts, (-1,))
  output_len = tf.shape(output_parts_reshaped)[-1]
  flat_spec_parts = tf.reshape(output_parts_reshaped, (-1, output_len))
  flat_spec_parts = tf.pad(flat_spec_parts,
                           [[0, 0], [0, max_target_length - output_len]])
  combined_spec_parts = tf.vectorized_map(
      fn=lambda x: tf.roll(x[0], x[1], axis=0),
      elems=(flat_spec_parts, flat_shifts))
  combined_spec_parts = tf.reshape(combined_spec_parts,
                                   (num_parts, num_examples, -1))
  combined_spec_parts = tf.reduce_sum(combined_spec_parts, axis=1)
  first_zero_index = tf.math.count_nonzero(combined_spec_parts, axis=-1)
  combined_spec_parts += tf.one_hot(
      first_zero_index,
      depth=tf.shape(combined_spec_parts)[-1],
      dtype=tf.int64) * eos_id

  # Create a dataset containing data for all spec suffixes.
  dataset = tf.data.Dataset.from_tensor_slices({
      'inputs': inputs_reshaped,
      'outputs': output_suffixes_reshaped,
      'spec_parts': combined_spec_parts,
      'start_index': ground_truth_start_indices_reshaped,
      'end_index': ground_truth_end_indices_reshaped
  })
  return dataset

def roll_fn(image, label):
  """Function to roll pixels."""
  image = tf.roll(image, config.roll_pixels, -2)
  return image, label

def minimize_one_step(gradient_unregularized_loss,
                      hessian_unregularized_loss_outer,
                      hessian_unregularized_loss_middle,
                      x_start,
                      tolerance,
                      l1_regularizer,
                      l2_regularizer=None,
                      maximum_full_sweeps=1,
                      learning_rate=None,
                      name=None):
  """One step of (the outer loop of) the minimization algorithm.

  This function returns a new value of `x`, equal to `x_start + x_update`. The
  increment `x_update in R^n` is computed by a coordinate descent method, that
  is, by a loop in which each iteration updates exactly one coordinate of
  `x_update`. (Some updates may leave the value of the coordinate unchanged.)

  The particular update method used is to apply an L1-based proximity operator,
  "soft threshold", whose fixed point `x_update_fix` is the desired minimum

  ```none
  x_update_fix = argmin{
      Loss(x_start + x_update')
        + l1_regularizer * ||x_start + x_update'||_1
        + l2_regularizer * ||x_start + x_update'||_2**2
      : x_update' }
  ```

  where in each iteration `x_update'` is constrained to have at most one
  nonzero coordinate.

  This update method preserves sparsity, i.e., tends to find sparse solutions
  if `x_start` is sparse. Additionally, the choice of step size is based on
  curvature (Hessian), which significantly speeds up convergence.

  This algorithm assumes that `Loss` is convex, at least in a region
  surrounding the optimum. (If `l2_regularizer > 0`, then only weak convexity
  is needed.)

  Args:
    gradient_unregularized_loss: (Batch of) `Tensor` with the same shape and
      dtype as `x_start` representing the gradient, evaluated at `x_start`, of
      the unregularized loss function (denoted `Loss` above). (In all current
      use cases, `Loss` is the negative log likelihood.)
    hessian_unregularized_loss_outer: (Batch of) `Tensor` or `SparseTensor`
      having the same dtype as `x_start`, and shape `[N, n]` where `x_start`
      has shape `[n]`, satisfying the property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_inner
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    hessian_unregularized_loss_middle: (Batch of) vector-shaped `Tensor` having
      the same dtype as `x_start`, and shape `[N]` where
      `hessian_unregularized_loss_outer` has shape `[N, n]`, satisfying the
      property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_inner
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    x_start: (Batch of) vector-shaped, `float` `Tensor` representing the
      current value of the argument to the Loss function.
    tolerance: scalar, `float` `Tensor` representing the convergence threshold.
      The optimization step will terminate early, returning its current value
      of `x_start + x_update`, once the following condition is met:
      `||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
       < sqrt(tolerance)`,
      where `x_update_end` is the value of `x_update` at the end of a sweep and
      `x_update_start` is the value of `x_update` at the beginning of that
      sweep.
    l1_regularizer: scalar, `float` `Tensor` representing the weight of the L1
      regularization term (see equation above). If L1 regularization is not
      required, then `tfp.glm.fit_one_step` is preferable.
    l2_regularizer: scalar, `float` `Tensor` representing the weight of the L2
      regularization term (see equation above).
      Default value: `None` (i.e., no L2 regularization).
    maximum_full_sweeps: Python integer specifying maximum number of sweeps to
      run. A "sweep" consists of an iteration of coordinate descent on each
      coordinate. After this many sweeps, the algorithm will terminate even if
      convergence has not been reached.
      Default value: `1`.
    learning_rate: scalar, `float` `Tensor` representing a multiplicative
      factor used to dampen the proximal gradient descent steps.
      Default value: `None` (i.e., factor is conceptually `1`).
    name: Python string representing the name of the TensorFlow operation.
      The default name is `"minimize_one_step"`.

  Returns:
    x: (Batch of) `Tensor` having the same shape and dtype as `x_start`,
      representing the updated value of `x`, that is, `x_start + x_update`.
    is_converged: scalar, `bool` `Tensor` indicating whether convergence
      occurred across all batches within the specified number of sweeps.
    iter: scalar, `int` `Tensor` representing the actual number of coordinate
      updates made (before achieving convergence). Since each sweep consists of
      `tf.size(x_start)` iterations, the maximum number of updates is
      `maximum_full_sweeps * tf.size(x_start)`.

  #### References

  [1]: Jerome Friedman, Trevor Hastie and Rob Tibshirani. Regularization Paths
       for Generalized Linear Models via Coordinate Descent. _Journal of
       Statistical Software_, 33(1), 2010.
       https://www.jstatsoft.org/article/view/v033i01/v33i01.pdf

  [2]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for
       L1-regularized Logistic Regression. _Journal of Machine Learning
       Research_, 13, 2012.
       http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
  """
  with tf.name_scope(name or 'minimize_one_step'):
    x_shape = _get_shape(x_start)
    batch_shape = x_shape[:-1]
    dims = x_shape[-1]

    def _hessian_diag_elt_with_l2(coord):  # pylint: disable=missing-docstring
      # Returns the (coord, coord) entry of
      #
      #   Hessian(UnregularizedLoss(x) + l2_regularizer * ||x||_2**2)
      #
      # evaluated at x = x_start.
      inner_square = tf.reduce_sum(
          _sparse_or_dense_matmul_onehot(
              hessian_unregularized_loss_outer, coord)**2,
          axis=-1)
      unregularized_component = inner_square * tf.gather(
          hessian_unregularized_loss_middle, coord, axis=-1)
      l2_component = _mul_or_none(2., l2_regularizer)
      return _add_ignoring_nones(unregularized_component, l2_component)

    grad_loss_with_l2 = _add_ignoring_nones(
        gradient_unregularized_loss, _mul_or_none(2., l2_regularizer, x_start))

    # We define `x_update_diff_norm_sq_convergence_threshold` such that the
    # convergence condition
    #     ||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
    #     < sqrt(tolerance)
    # is equivalent to
    #     ||x_update_end - x_update_start||_2**2
    #     < x_update_diff_norm_sq_convergence_threshold.
    x_update_diff_norm_sq_convergence_threshold = (
        tolerance * (1. + tf.norm(tensor=x_start, ord=2, axis=-1))**2)

    # Reshape update vectors so that the coordinate sweeps happen along the
    # first dimension. This is so that we can use tensor_scatter_update to make
    # sparse updates along the first axis without copying the Tensor.
    # TODO(b/118789120): Switch to something like tf.tensor_scatter_nd_add if
    # or when it exists.
    update_shape = tf.concat([[dims], batch_shape], axis=-1)

    def _loop_cond(iter_, x_update_diff_norm_sq, x_update,
                   hess_matmul_x_update):
      del x_update
      del hess_matmul_x_update
      sweep_complete = (iter_ > 0) & tf.equal(iter_ % dims, 0)
      small_delta = (
          x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
      converged = sweep_complete & small_delta
      allowed_more_iterations = iter_ < maximum_full_sweeps * dims
      return allowed_more_iterations & tf.reduce_any(~converged)

    def _loop_body(  # pylint: disable=missing-docstring
        iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
      # Inner loop of the minimizer.
      #
      # This loop updates a single coordinate of x_update. Ideally, an
      # iteration of this loop would set
      #
      #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
      #
      # where
      #
      #   LocalLoss(x_update')
      #     = LocalLossSmoothComponent(x_update')
      #       + l1_regularizer * (||x_start + x_update'||_1 -
      #                           ||x_start + x_update||_1)
      #    := (UnregularizedLoss(x_start + x_update') -
      #        UnregularizedLoss(x_start + x_update)
      #         + l2_regularizer * (||x_start + x_update'||_2**2 -
      #                             ||x_start + x_update||_2**2)
      #         + l1_regularizer * (||x_start + x_update'||_1 -
      #                             ||x_start + x_update||_1)
      #
      # In this algorithm approximate the above argmin using (univariate)
      # proximal gradient descent:
      #
      # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
      #          x_update[j] -
      #          t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
      #
      # where
      #
      #   UnivariateLocalLossSmoothComponent(z)
      #       := LocalLossSmoothComponent(x_update + z*e_j)
      #
      # and we approximate
      #
      #   d/dz UnivariateLocalLossSmoothComponent(z)
      #     = grad LocalLossSmoothComponent(x_update))[j]
      #    ~= (grad LossSmoothComponent(x_start)
      #        + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
      #
      # To choose the parameter t, we squint and pretend that the inner term of
      # (*) is a Newton update as if we were using Newton's method to minimize
      # UnivariateLocalLossSmoothComponent. That is, we choose t such that
      #
      #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
      #
      # at z=0. Hence
      #
      #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
      #     = learning_rate / HessianOfLossSmoothComponent(
      #                           x_start + x_update)[j,j]
      #    ~= learning_rate / HessianOfLossSmoothComponent(
      #                           x_start)[j,j]
      #
      # The above approximation is equivalent to assuming that
      # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
      # effects.
      #
      # Note that because LossSmoothComponent is (assumed to be) convex, t is
      # positive.
      # In above notation, coord = j.
      coord = iter_ % dims
      # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
      # computed incrementally, where x_update_end and x_update_start are as
      # defined in the convergence criteria. Accordingly, we reset
      # x_update_diff_norm_sq to zero at the beginning of each sweep.
      x_update_diff_norm_sq = tf.where(
          tf.equal(coord, 0),
          dtype_util.as_numpy_dtype(x_update_diff_norm_sq.dtype)(0.),
          x_update_diff_norm_sq)

      # Recall that x_update and hess_matmul_x_update has the rightmost
      # dimension transposed to the leftmost dimension.
      w_old = (tf.gather(x_start, coord, axis=-1) +
               tf.gather(x_update, coord, axis=0))
      # This is the coordinatewise Newton update if no L1 regularization.
      # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
      second_deriv = _hessian_diag_elt_with_l2(coord)
      newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
          learning_rate, (tf.gather(grad_loss_with_l2, coord, axis=-1) +
                          tf.gather(hess_matmul_x_update, coord,
                                    axis=0))) / second_deriv

      # Applying the soft-threshold operator accounts for L1 regularization.
      # In above notation, delta =
      #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
      delta = (
          soft_threshold(
              w_old + newton_step,
              _mul_ignoring_nones(learning_rate, l1_regularizer) /
              second_deriv) - w_old)

      def _do_update(x_update_diff_norm_sq, x_update, hess_matmul_x_update):  # pylint: disable=missing-docstring
        hessian_column_with_l2 = sparse_or_dense_matvecmul(
            hessian_unregularized_loss_outer,
            hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot(
                hessian_unregularized_loss_outer, coord),
            adjoint_a=True)

        if l2_regularizer is not None:
          hessian_column_with_l2 += _one_hot_like(
              hessian_column_with_l2, coord, on_value=2. * l2_regularizer)

        # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
        # order to conform to `hess_matmul_x_update`.
        n = tf.rank(hessian_column_with_l2)
        perm = tf.roll(tf.range(n), shift=1, axis=0)
        hessian_column_with_l2 = tf.transpose(
            a=hessian_column_with_l2, perm=perm)

        # Update the entire batch at `coord` even if `delta` may be 0 at some
        # batch coordinates. In those cases, adding `delta` is a no-op.
        x_update = tf.tensor_scatter_nd_add(x_update, [[coord]], [delta])

        with tf.control_dependencies([x_update]):
          x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
          hess_matmul_x_update_ = (hess_matmul_x_update +
                                   delta * hessian_column_with_l2)

          # Hint that loop vars retain the same shape.
          x_update_diff_norm_sq_.set_shape(
              x_update_diff_norm_sq_.shape.merge_with(
                  x_update_diff_norm_sq.shape))
          hess_matmul_x_update_.set_shape(
              hess_matmul_x_update_.shape.merge_with(
                  hess_matmul_x_update.shape))

        return [x_update_diff_norm_sq_, x_update, hess_matmul_x_update_]

      inputs_to_update = [x_update_diff_norm_sq, x_update, hess_matmul_x_update]
      return [iter_ + 1] + prefer_static.cond(
          # Note on why checking delta (a difference of floats) for equality to
          # zero is ok:
          #
          # First of all, x - x == 0 in floating point -- see
          # https://stackoverflow.com/a/2686671
          #
          # Delta will conceptually equal zero when one of the following holds:
          # (i)   |w_old + newton_step| <= threshold and w_old == 0
          # (ii)  |w_old + newton_step| > threshold and
          #       w_old + newton_step - sign(w_old + newton_step) * threshold
          #          == w_old
          #
          # In case (i) comparing delta to zero is fine.
          #
          # In case (ii), newton_step conceptually equals
          # sign(w_old + newton_step) * threshold.
          # Also remember
          #   threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
          # So (i) happens when
          #   (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
          # If we did not require LossSmoothComponent to be strictly convex,
          # then this could actually happen a non-negligible amount of the
          # time, e.g. if the loss function is piecewise linear and one of the
          # pieces has slope 1. But since LossSmoothComponent is strictly
          # convex, (i) should not systematically happen.
          tf.reduce_all(tf.equal(delta, 0.)),
          lambda: inputs_to_update,
          lambda: _do_update(*inputs_to_update))

    base_dtype = x_start.dtype.base_dtype
    iter_, x_update_diff_norm_sq, x_update, _ = tf.while_loop(
        cond=_loop_cond,
        body=_loop_body,
        loop_vars=[
            tf.zeros([], dtype=np.int32, name='iter'),
            tf.zeros(
                batch_shape, dtype=base_dtype, name='x_update_diff_norm_sq'),
            tf.zeros(update_shape, dtype=base_dtype, name='x_update'),
            tf.zeros(
                update_shape, dtype=base_dtype, name='hess_matmul_x_update'),
        ])

    # Convert back x_update to the shape of x_start by transposing the leftmost
    # dimension to the rightmost.
    n = tf.rank(x_update)
    perm = tf.roll(tf.range(n), shift=-1, axis=0)
    x_update = tf.transpose(a=x_update, perm=perm)

    converged = tf.reduce_all(
        x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
    return x_start + x_update, converged, iter_ / dims