def forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len):
    """
    :return: forward variable alpha with shape batch_size x input_max_len x target_max_len
    """
    def next_state(x, trans_probs):
        blank_probs = trans_probs[0]
        truth_probs = trans_probs[1]

        # Paths arriving on the next diagonal via a blank emission
        # (shifted by one position within the diagonal).
        x_b = tf.concat(
            [LOG_0 * tf.ones(shape=[batch_size, 1]), x[:, :-1] + blank_probs],
            axis=1)
        # Paths arriving via emission of the correct target label.
        x_t = x + truth_probs

        # Combine the two incoming path sets in log space.
        x = tf.math.reduce_logsumexp(tf.stack([x_b, x_t], axis=0), axis=0)
        return x

    initial_alpha = tf.concat([
        tf.zeros(shape=[batch_size, 1]),
        tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0
    ],
                              axis=1)

    fwd = tf.scan(next_state, (bp_diags[:-1, :, :-1], tp_diags),
                  initializer=initial_alpha)

    alpha = tf.transpose(tf.concat(
        [tf.expand_dims(initial_alpha, axis=0), fwd], axis=0),
                         perm=[1, 2, 0])
    alpha = matrix_diag_part_v2(alpha,
                                k=(0, target_max_len - 1),
                                padding_value=LOG_0)
    alpha = tf.transpose(tf.reverse(alpha, axis=[1]), perm=[0, 2, 1])

    return alpha
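
`LOG_0`, `matrix_diag_part_v2`, and `nan_to_zero` (used further down) are module-level helpers that this listing never defines. A minimal sketch of plausible definitions so the DP snippets can run standalone; these are assumptions, not the original module's code:

```python
import tensorflow as tf

# Assumed log-space zero. float("-inf") is consistent with the nan_to_zero
# calls below, which exist to patch up the -inf * 0 = nan cases.
LOG_0 = float("-inf")

# Assumed to be the raw banded diag-part op, publicly exposed as
# tf.raw_ops.MatrixDiagPartV2; k=(lo, hi) selects a band of diagonals.
def matrix_diag_part_v2(x, k, padding_value):
    return tf.raw_ops.MatrixDiagPartV2(input=x, k=k, padding_value=padding_value)

# Replaces the NaNs produced by multiplying LOG_0 by a zero mask.
def nan_to_zero(x):
    return tf.where(tf.math.is_nan(x), tf.zeros_like(x), x)
```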
def backward_dp(
    bp_diags,
    tp_diags,
    batch_size,
    input_max_len,
    target_max_len,
    label_length,
    logit_length,
    blank_sl,
):
    """
    :return: backward variable beta with shape batch_size x input_max_len x target_max_len
    """
    def next_state(x, mask_and_trans_probs):
        mask_s, blank_probs_s, truth_probs = mask_and_trans_probs

        # Paths continuing with a blank emission (advance along time).
        beta_b = tf.concat(
            [x[:, 1:] + blank_probs_s, LOG_0 * tf.ones(shape=[batch_size, 1])],
            axis=1)
        # Paths continuing with the correct target label.
        beta_t = tf.concat(
            [x[:, :-1] + truth_probs, LOG_0 * tf.ones(shape=[batch_size, 1])],
            axis=1)

        beta_next = tf.math.reduce_logsumexp(
            tf.stack([beta_b, beta_t], axis=0), axis=0)
        # Update only the entries still inside the valid length range; keep
        # the previous value elsewhere. nan_to_zero patches up LOG_0 * 0 = nan.
        masked_beta_next = nan_to_zero(
            beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero(
                x * tf.expand_dims((1.0 - mask_s), axis=1))
        return tf.reshape(masked_beta_next, shape=tf.shape(x))

    # Initial beta for batches.
    initial_beta_mask = tf.one_hot(logit_length - 1, depth=input_max_len + 1)
    initial_beta = tf.expand_dims(blank_sl,
                                  axis=1) * initial_beta_mask + nan_to_zero(
                                      LOG_0 * (1.0 - initial_beta_mask))

    # Mask for scan iterations.
    mask = tf.sequence_mask(
        logit_length + label_length - 1,
        input_max_len + target_max_len - 2,
        dtype=tf.dtypes.float32,
    )
    mask = tf.transpose(mask, perm=[1, 0])

    bwd = tf.scan(
        next_state,
        (mask, bp_diags[:-1, :, :], tp_diags),
        initializer=initial_beta,
        reverse=True,
    )

    beta = tf.transpose(tf.concat(
        [bwd, tf.expand_dims(initial_beta, axis=0)], axis=0),
                        perm=[1, 2, 0])[:, :-1, :]
    beta = matrix_diag_part_v2(beta,
                               k=(0, target_max_len - 1),
                               padding_value=LOG_0)
    beta = tf.transpose(tf.reverse(beta, axis=[1]), perm=[0, 2, 1])

    return beta
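
`blank_sl` is never constructed in this listing. Given that `initial_beta` places it at index `logit_length - 1` and `LOG_0` everywhere else, one plausible reading (an assumption) is the per-utterance blank log-probability at the final lattice node, for example:

```python
# Hypothetical construction of blank_sl: the blank log-prob at
# (logit_length - 1, label_length) for each batch entry.
batch_size = 2
bp = tf.math.log(tf.random.uniform([batch_size, 4, 3]))  # [batch, time, target + 1]
logit_length = tf.constant([4, 3])
label_length = tf.constant([2, 1])
blank_sl = tf.gather_nd(
    bp, tf.stack([tf.range(batch_size), logit_length - 1, label_length], axis=1))
print(blank_sl.shape)  # (2,)
```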
def extract_diagonals(log_probs):
    time_steps = tf.shape(log_probs)[1]  # T
    output_steps = tf.shape(log_probs)[2]  # U + 1
    # Reverse the label axis and left-pad it so that every anti-diagonal of
    # the [T, U + 1] lattice becomes a (super)diagonal of the padded tensor.
    reverse_log_probs = tf.reverse(log_probs, axis=[-1])
    paddings = [[0, 0], [0, 0], [time_steps - 1, 0]]
    padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings,
                                      'CONSTANT', constant_values=LOG_0)
    diagonals = matrix_diag_part_v2(padded_reverse_log_probs,
                                    k=(0, time_steps + output_steps - 2),
                                    padding_value=LOG_0)

    # [num_diagonals, batch, T], ordered from the (0, 0) corner outward.
    return tf.transpose(diagonals, perm=[1, 0, 2])
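
A shape-level sketch of how the three routines fit together, using the helper definitions sketched above. The toy sizes and the random stand-ins `bp`/`tp` for the blank and truth-label transition log-probs are illustrative assumptions:

```python
B, T, U = 2, 4, 3  # batch, time steps, target labels

bp = tf.math.log(tf.random.uniform([B, T, U + 1]))  # blank transition log-probs
tp = tf.math.log(tf.random.uniform([B, T, U]))      # truth-label transition log-probs

bp_diags = extract_diagonals(bp)  # [T + U, B, T]
tp_diags = extract_diagonals(tp)  # [T + U - 1, B, T]

alpha = forward_dp(bp_diags, tp_diags, B, T, U + 1)

label_length = tf.fill([B], U)  # toy: every target uses its full length
logit_length = tf.fill([B], T)  # toy: every input uses its full length
blank_sl = bp[:, -1, -1]        # blank log-prob at the final lattice node
beta = backward_dp(bp_diags, tp_diags, B, T, U + 1,
                   label_length, logit_length, blank_sl)

print(alpha.shape, beta.shape)  # (2, 4, 4) (2, 4, 4)
```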
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
    r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal `M` x `M` matrix; its layout
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last element
  of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix on the right-hand side of the multiplication. It has
  shape `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`; see the description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
    tensors have incorrect shapes.
  """
    if diagonals_format == 'compact':
        superdiag = diagonals[..., 0, :]
        maindiag = diagonals[..., 1, :]
        subdiag = diagonals[..., 2, :]
    elif diagonals_format == 'sequence':
        superdiag, maindiag, subdiag = diagonals
    elif diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected the last two dimensions of diagonals to be the same, '
                'got {} and {}'.format(m1, m2))

        maindiag = array_ops.matrix_diag_part(diagonals)
        superdiag = gen_array_ops.matrix_diag_part_v2(diagonals,
                                                      k=1,
                                                      padding_value=0.)
        superdiag = array_ops.concat([
            superdiag,
            array_ops.zeros_like(superdiag[..., 0])[..., array_ops.newaxis]
        ],
                                     axis=-1)
        subdiag = gen_array_ops.matrix_diag_part_v2(diagonals,
                                                    k=-1,
                                                    padding_value=0.)
        subdiag = array_ops.concat([
            array_ops.zeros_like(subdiag[..., 0])[..., array_ops.newaxis],
            subdiag
        ],
                                   axis=-1)
    else:
        raise ValueError('Unrecognized diagonals_format: %s' %
                         diagonals_format)

    # C++ backend requires matrices.
    # Converting 1-dimensional vectors to matrices with 1 row.
    superdiag = array_ops.expand_dims(superdiag, -2)
    maindiag = array_ops.expand_dims(maindiag, -2)
    subdiag = array_ops.expand_dims(subdiag, -2)

    return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs,
                                          name)
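
The same product as the docstring example, but in the default `compact` format; the row stacking matches the order described above (superdiagonal, main diagonal, subdiagonal):

```python
import tensorflow as tf

diagonals = tf.constant([[-1, -1,  0],   # superdiagonal (last element ignored)
                         [ 2,  2,  2],   # main diagonal
                         [ 0, -1, -1]],  # subdiagonal (first element ignored)
                        dtype=tf.float64)
rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs)  # diagonals_format='compact'
```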
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
    r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to convert a tensor into the compact format manually, use
  `tf.gather_nd`.
  An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or
  `[..., M, K]`. The latter allows K systems with the same left-hand side and K
  different right-hand sides to be solved simultaneously. If `transpose_rhs`
  is set to `True` the expected shape is `[..., M]` or `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow, if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider if it's possible to disable pivoting
  and have `K = 1`, or, alternatively, consider using CPU.

  On CPU, the solution is computed via Gaussian elimination, with or without
  partial pivoting depending on the `partial_pivoting` parameter. On GPU,
  Nvidia's cuSPARSE
  library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`; see the description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't
      known statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
    tensors have incorrect shapes.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.

  """
    if diagonals_format == 'compact':
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'sequence':
        if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
            raise ValueError(
                'Expected diagonals to be a sequence of length 3.')

        superdiag, maindiag, subdiag = diagonals
        if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])
                or not superdiag.shape[:-1].is_compatible_with(
                    maindiag.shape[:-1])):
            raise ValueError(
                'Tensors representing the three diagonals must have the same '
                'shape, except for the last dimension, got {}, {}, {}'.format(
                    subdiag.shape, maindiag.shape, superdiag.shape))

        m = tensor_shape.dimension_value(maindiag.shape[-1])

        def pad_if_necessary(t, name, last_dim_padding):
            n = tensor_shape.dimension_value(t.shape[-1])
            if not n or n == m:
                return t
            if n == m - 1:
                paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                            [last_dim_padding])
                return array_ops.pad(t, paddings)
            raise ValueError(
                'Expected {} to have length {} or {}, got {}.'.format(
                    name, m, m - 1, n))

        subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
        superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

        diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected the last two dimensions of diagonals to be the same, '
                'got {} and {}'.format(m1, m2))
        m = m1 or m2
        diagonals = gen_array_ops.matrix_diag_part_v2(diagonals,
                                                      k=(-1, 1),
                                                      padding_value=0.)
        # matrix_diag_part pads at the end. Because the compact format expects
        # the subdiagonal's padding at the front, we need to roll the
        # subdiagonal (the last of the three stacked tensors) by one.
        superdiag, d, subdiag = array_ops.unstack(diagonals, num=3, axis=-2)
        subdiag = manip_ops.roll(subdiag, shift=1, axis=-1)
        diagonals = array_ops.stack((superdiag, d, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    raise ValueError(
        'Unrecognized diagonals_format: {}'.format(diagonals_format))
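
A minimal end-to-end example in the default `compact` format, solving a small diagonally dominant system (for which `partial_pivoting=False` would also be safe, per the note above):

```python
import tensorflow as tf

# System:  [[4, 1, 0],       [[5],
#           [1, 4, 1],  x  =  [6],
#           [0, 1, 4]]        [5]]   ->  x = [1, 1, 1]
diagonals = tf.constant([[1., 1., 0.],   # superdiagonal (last element ignored)
                         [4., 4., 4.],   # main diagonal
                         [0., 1., 1.]],  # subdiagonal (first element ignored)
                        dtype=tf.float64)
rhs = tf.constant([5., 6., 5.], dtype=tf.float64)
x = tf.linalg.tridiagonal_solve(diagonals, rhs)
```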