def forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len):
    """Compute the forward variable by scanning over diagonal transitions.

    The recursion runs in log space: ``LOG_0`` marks cells that cannot be
    reached and competing paths are merged with ``reduce_logsumexp``.

    NOTE(review): ``bp_diags`` / ``tp_diags`` look like blank and truth
    transition log-probabilities with one diagonal per leading index (e.g.
    as produced by ``extract_diagonals``) -- confirm against the caller.

    :param bp_diags: blank transition terms; ``bp_diags[:-1, :, :-1]`` is
        scanned along its first axis.
    :param tp_diags: truth transition terms, scanned in lockstep with the
        sliced ``bp_diags``.
    :param batch_size: number of rows of every per-step state.
    :param input_max_len: number of columns of every per-step state.
    :param target_max_len: diagonals ``0`` to ``target_max_len - 1`` are read
        back out of the stacked states.
    :return: forward variable alpha with shape
        batch_size x input_max_len x target_max_len
    """

    def advance(prev_alpha, trans_probs):
        blank_lp, truth_lp = trans_probs

        # Blank move: the state shifts one column to the right, so the first
        # column has no predecessor and is filled with LOG_0.
        via_blank = tf.concat(
            [LOG_0 * tf.ones(shape=[batch_size, 1]),
             prev_alpha[:, :-1] + blank_lp],
            axis=1,
        )
        # Truth move: stays in the same column.
        via_truth = prev_alpha + truth_lp

        candidates = tf.stack([via_blank, via_truth], axis=0)
        return tf.math.reduce_logsumexp(candidates, axis=0)

    # Only the first column starts with log-probability 0; the rest is LOG_0.
    start = tf.concat(
        [
            tf.zeros(shape=[batch_size, 1]),
            tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0,
        ],
        axis=1,
    )

    scan_inputs = (bp_diags[:-1, :, :-1], tp_diags)
    later_steps = tf.scan(advance, scan_inputs, initializer=start)

    # Prepend the initial state, then move the step axis last so the
    # diagonals can be extracted per batch element.
    all_steps = tf.concat(
        [tf.expand_dims(start, axis=0), later_steps], axis=0)
    per_batch = tf.transpose(all_steps, perm=[1, 2, 0])

    diag_band = matrix_diag_part_v2(
        per_batch, k=(0, target_max_len - 1), padding_value=LOG_0)
    flipped = tf.reverse(diag_band, axis=[1])
    return tf.transpose(flipped, perm=[0, 2, 1])
def backward_dp(
    bp_diags,
    tp_diags,
    batch_size,
    input_max_len,
    target_max_len,
    label_length,
    logit_length,
    blank_sl,
):
    """Compute the backward variable with a reversed, masked scan.

    The recursion runs in log space (``LOG_0`` marks unreachable cells). A
    per-step mask built from ``logit_length + label_length - 1`` decides, per
    batch element, whether a step updates the state or carries the previous
    state through unchanged.

    NOTE(review): ``bp_diags`` / ``tp_diags`` look like blank and truth
    transition log-probabilities with one diagonal per leading index, and
    ``blank_sl`` like the final blank log-probability per batch element --
    confirm against the caller.

    NOTE(review): this function calls the bare name ``reduce_logsumexp``
    whereas ``forward_dp`` uses ``tf.math.reduce_logsumexp`` -- confirm the
    bare name is imported in this module.

    :param bp_diags: blank transition terms; ``bp_diags[:-1, :, :]`` is
        scanned along its first axis.
    :param tp_diags: truth transition terms, scanned in lockstep.
    :param batch_size: number of rows of every per-step state.
    :param input_max_len: the per-step state has ``input_max_len + 1``
        columns; the extra column is dropped before diagonal extraction.
    :param target_max_len: diagonals ``0`` to ``target_max_len - 1`` are read
        back out of the stacked states.
    :param label_length: per-example label lengths used for the step mask.
    :param logit_length: per-example logit lengths used for the step mask and
        for placing the initial value.
    :param blank_sl: per-example value placed at column ``logit_length - 1``
        of the initial state.
    :return: backward variable beta with shape
        batch_size x input_max_len x target_max_len
    """

    def step_back(prev_beta, mask_and_trans_probs):
        step_mask, blank_lp, truth_lp = mask_and_trans_probs
        log0_column = LOG_0 * tf.ones(shape=[batch_size, 1])

        via_blank = tf.concat(
            [prev_beta[:, 1:] + blank_lp, log0_column], axis=1)
        via_truth = tf.concat(
            [prev_beta[:, :-1] + truth_lp, log0_column], axis=1)
        merged = reduce_logsumexp(
            tf.stack([via_blank, via_truth], axis=0), axis=0)

        # Where the mask is 1 take the new value, otherwise keep the old
        # state; nan_to_zero cleans up whatever NaNs the products create.
        updated_part = nan_to_zero(
            merged * tf.expand_dims(step_mask, axis=1))
        carried_part = nan_to_zero(
            prev_beta * tf.expand_dims((1.0 - step_mask), axis=1))
        return tf.reshape(
            updated_part + carried_part, shape=tf.shape(prev_beta))

    # Initial beta for batches.
    start_mask = tf.one_hot(logit_length - 1, depth=input_max_len + 1)
    start_values = tf.expand_dims(blank_sl, axis=1) * start_mask
    start_fill = nan_to_zero(LOG_0 * (1.0 - start_mask))
    start = start_values + start_fill

    # Mask for scan iterations, transposed so the step axis comes first.
    step_masks = tf.sequence_mask(
        logit_length + label_length - 1,
        input_max_len + target_max_len - 2,
        dtype=tf.dtypes.float32,
    )
    step_masks = tf.transpose(step_masks, perm=[1, 0])

    earlier_steps = tf.scan(
        step_back,
        (step_masks, bp_diags[:-1, :, :], tp_diags),
        initializer=start,
        reverse=True,
    )

    # Append the initial state, move the step axis last and drop the extra
    # final column of the state.
    all_steps = tf.concat(
        [earlier_steps, tf.expand_dims(start, axis=0)], axis=0)
    per_batch = tf.transpose(all_steps, perm=[1, 2, 0])[:, :-1, :]

    diag_band = matrix_diag_part_v2(
        per_batch, k=(0, target_max_len - 1), padding_value=LOG_0)
    flipped = tf.reverse(diag_band, axis=[1])
    return tf.transpose(flipped, perm=[0, 2, 1])
def extract_diagonals(log_probs):
    """Collect the diagonals of the last two axes of ``log_probs``.

    The last axis is reversed and left-padded with ``LOG_0`` by ``T - 1``
    entries, after which diagonals ``0`` to ``T + U - 1`` are extracted and
    the first two axes of the result are swapped.

    :param log_probs: tensor whose axis 1 has length T and whose axis 2 has
        length U + 1.
    :return: the extracted diagonals, transposed with ``perm=[1, 0, 2]``.
    """
    shape = tf.shape(log_probs)
    num_time = shape[1]  # T
    num_out = shape[2]  # U + 1

    flipped = tf.reverse(log_probs, axis=[-1])
    # Pad only the front of the last axis.
    left_pad = [[0, 0], [0, 0], [num_time - 1, 0]]
    padded = tf.pad(flipped, left_pad, 'CONSTANT', constant_values=LOG_0)

    last_diag = num_time + num_out - 2
    diag_band = matrix_diag_part_v2(
        padded, k=(0, last_diag), padding_value=LOG_0)
    return tf.transpose(diag_band, perm=[1, 0, 2])
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
    r"""Multiplies tridiagonal matrix by matrix.

    `diagonals` is a representation of a 3-diagonal MxM matrix, which depends
    on `diagonals_format`.

    In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`,
    with two inner-most dimensions representing the square tridiagonal
    matrices. Elements outside of the three diagonals will be ignored.

    In `sequence` format, `diagonals` is a list or tuple of three tensors:
    `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
    element of `superdiag` and the first element of `subdiag` are ignored.

    In `compact` format the three diagonals are brought together into one
    tensor of shape `[..., 3, M]`, with last two dimensions containing
    superdiagonals, diagonals, and subdiagonals, in order. Similarly to
    `sequence` format, elements `diagonals[..., 0, M-1]` and
    `diagonals[..., 2, 0]` are ignored.

    The `sequence` format is recommended as the one with the best
    performance.

    `rhs` is the matrix to the right of the multiplication. It has shape
    `[..., M, N]`.

    Example:

    ```python
    superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
    maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
    subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
    diagonals = [superdiag, maindiag, subdiag]
    rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
    x = tf.linalg.tridiagonal_matmul(diagonals, rhs,
                                     diagonals_format='sequence')
    ```

    Args:
      diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides.
        The shape depends on `diagonals_format`, see description above. Must
        be `float32`, `float64`, `complex64`, or `complex128`.
      rhs: A `Tensor` of shape [..., M, N] and with the same dtype as
        `diagonals`.
      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
        `compact`.
      name: A name to give this `Op` (optional).

    Returns:
      A `Tensor` of shape [..., M, N] containing the result of
      multiplication.

    Raises:
      ValueError: An unsupported type is provided as input, or when the input
        tensors have incorrect shapes. In particular, raised when
        `diagonals_format` is not recognized, or when in `matrix` format the
        last two dimensions are statically known and differ.
    """
    if diagonals_format == 'compact':
        superdiag = diagonals[..., 0, :]
        maindiag = diagonals[..., 1, :]
        subdiag = diagonals[..., 2, :]
    elif diagonals_format == 'sequence':
        superdiag, maindiag, subdiag = diagonals
    elif diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        # Only reject when both sizes are statically known and disagree.
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected last two dimensions of diagonals to be same, got {} '
                'and {}'.format(m1, m2))

        maindiag = array_ops.matrix_diag_part(diagonals)

        # The off-diagonals have M - 1 entries; pad each to length M with a
        # zero in the slot that is ignored (end of the superdiagonal, start
        # of the subdiagonal).
        superdiag = gen_array_ops.matrix_diag_part_v2(
            diagonals, k=1, padding_value=0.)
        superdiag = array_ops.concat(
            [
                superdiag,
                array_ops.zeros_like(superdiag[..., 0])[..., array_ops.newaxis],
            ],
            axis=-1)

        subdiag = gen_array_ops.matrix_diag_part_v2(
            diagonals, k=-1, padding_value=0.)
        subdiag = array_ops.concat(
            [
                array_ops.zeros_like(subdiag[..., 0])[..., array_ops.newaxis],
                subdiag,
            ],
            axis=-1)
    else:
        # str.format (rather than %-formatting) so that a tuple-valued
        # argument still yields a ValueError instead of a TypeError.
        raise ValueError(
            'Unrecognized diagonals_format: {}'.format(diagonals_format))

    # C++ backend requires matrices.
    # Converting 1-dimensional vectors to matrices with 1 row.
    superdiag = array_ops.expand_dims(superdiag, -2)
    maindiag = array_ops.expand_dims(maindiag, -2)
    subdiag = array_ops.expand_dims(subdiag, -2)

    return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs,
                                          name)
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
    r"""Solves tridiagonal systems of equations.

    The input can be supplied in various formats: `matrix`, `sequence` and
    `compact`, specified by the `diagonals_format` arg.

    In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`,
    with two inner-most dimensions representing the square tridiagonal
    matrices. Elements outside of the three diagonals will be ignored.

    In `sequence` format, `diagonals` are supplied as a tuple or list of three
    tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
    superdiagonals, diagonals, and subdiagonals, respectively. `N` can be
    either `M-1` or `M`; in the latter case, the last element of superdiagonal
    and the first element of subdiagonal will be ignored.

    In `compact` format the three diagonals are brought together into one
    tensor of shape `[..., 3, M]`, with last two dimensions containing
    superdiagonals, diagonals, and subdiagonals, in order. Similarly to
    `sequence` format, elements `diagonals[..., 0, M-1]` and
    `diagonals[..., 2, 0]` are ignored.

    The `compact` format is recommended as the one with best performance. In
    case you need to cast a tensor into a compact format manually, use
    `tf.gather_nd`. An example for a tensor of shape [m, m]:

    ```python
    rhs = tf.constant([...])
    matrix = tf.constant([[...]])
    m = matrix.shape[0]
    dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
    indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
               [[i, i] for i in range(m)],                        # Diagonal
               [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
    diagonals=tf.gather_nd(matrix, indices)
    x = tf.linalg.tridiagonal_solve(diagonals, rhs)
    ```

    Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]`
    or `[..., M, K]`. The latter allows simultaneously solving K systems with
    the same left-hand sides and K different right-hand sides. If
    `transpose_rhs` is set to `True` the expected shape is `[..., M]` or
    `[..., K, M]`.

    The batch dimensions, denoted as `...`, must be the same in `diagonals`
    and `rhs`.

    The output is a tensor of the same shape as `rhs`: either `[..., M]` or
    `[..., M, K]`.

    The op isn't guaranteed to raise an error if the input matrix is not
    invertible. `tf.debugging.check_numerics` can be applied to the output to
    detect invertibility problems.

    **Note**: with large batch sizes, the computation on the GPU may be slow,
    if either `partial_pivoting=True` or there are multiple right-hand sides
    (`K > 1`). If this issue arises, consider if it's possible to disable
    pivoting and have `K = 1`, or, alternatively, consider using CPU.

    On CPU, solution is computed via Gaussian elimination with or without
    partial pivoting, depending on `partial_pivoting` parameter. On GPU,
    Nvidia's cuSPARSE library is used:
    https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

    Args:
      diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides.
        The shape depends on `diagonals_format`, see description above. Must
        be `float32`, `float64`, `complex64`, or `complex128`.
      rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype
        as `diagonals`. Note that if the shape of `rhs` and/or `diags` isn't
        known statically, `rhs` will be treated as a matrix rather than a
        vector.
      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
        `compact`.
      transpose_rhs: If `True`, `rhs` is transposed before solving (has no
        effect if the shape of rhs is [..., M]).
      conjugate_rhs: If `True`, `rhs` is conjugated before solving.
      name: A name to give this `Op` (optional).
      partial_pivoting: whether to perform partial pivoting. `True` by
        default. Partial pivoting makes the procedure more stable, but slower.
        Partial pivoting is unnecessary in some cases, including diagonally
        dominant and symmetric positive definite matrices (see e.g. theorem
        9.12 in [1]).

    Returns:
      A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

    Raises:
      ValueError: An unsupported type is provided as input, or when the input
        tensors have incorrect shapes.

    [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical
    Algorithms: Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
    """
    if diagonals_format == 'compact':
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'sequence':
        if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
            raise ValueError(
                'Expected diagonals to be a sequence of length 3.')

        superdiag, maindiag, subdiag = diagonals
        batch_shape = maindiag.shape[:-1]
        if (not subdiag.shape[:-1].is_compatible_with(batch_shape) or
                not superdiag.shape[:-1].is_compatible_with(batch_shape)):
            raise ValueError(
                'Tensors representing the three diagonals must have the same '
                'shape, except for the last dimension, got {}, {}, {}'.format(
                    subdiag.shape, maindiag.shape, superdiag.shape))

        m = tensor_shape.dimension_value(maindiag.shape[-1])

        def pad_if_necessary(t, label, last_dim_padding):
            """Pads an off-diagonal of static length M - 1 up to length M."""
            n = tensor_shape.dimension_value(t.shape[-1])
            # When either length is not statically known nothing can be
            # validated here (and `m - 1` would fail on None), so the tensor
            # is passed through unchanged.
            if not n or m is None or n == m:
                return t
            if n == m - 1:
                paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                            [last_dim_padding])
                return array_ops.pad(t, paddings)
            raise ValueError(
                'Expected {} to have length {} or {}, got {}.'.format(
                    label, m, m - 1, n))

        # The ignored slot is the first element of the subdiagonal and the
        # last element of the superdiagonal, so pad on those sides.
        subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
        superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

        diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        # Only reject when both sizes are statically known and disagree.
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected last two dimensions of diagonals to be same, got {} '
                'and {}'.format(m1, m2))

        diagonals = gen_array_ops.matrix_diag_part_v2(
            diagonals, k=(-1, 1), padding_value=0.)
        # matrix_diag_part pads at the end. Because the subdiagonal has the
        # convention of having the padding in the front, we need to rotate the
        # last Tensor.
        superdiag, d, subdiag = array_ops.unstack(diagonals, num=3, axis=-2)
        subdiag = manip_ops.roll(subdiag, shift=1, axis=-1)
        diagonals = array_ops.stack((superdiag, d, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    raise ValueError(
        'Unrecognized diagonals_format: {}'.format(diagonals_format))