def fill_lower_triangular(x, name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Note: This function is very slow; possibly 10x slower than zeroing out the
  upper-triangular portion of a full matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  Args:
    x: `Tensor` representing lower triangular elements.
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.
  """
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    ndims = x.get_shape().ndims
    if ndims is not None and x.get_shape()[-1].value is not None:
      d = x.get_shape()[-1].value
      # d = n^2/2 + n/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      d = array_ops.shape(x)[-1]  # Keep d integral for the reshape below.
      # d = n^2/2 + n/2 implies n is:
      n = math_ops.cast(
          0.5 * (math_ops.sqrt(1. + 8. * math_ops.cast(d, dtypes.float32))
                 - 1.),
          dtype=dtypes.int32)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(batch_shape)

    # Flatten batch dims.
    y = array_ops.reshape(x, [-1, d])

    # Prepend a zero to each row.
    y = array_ops.pad(y, paddings=[[0, 0], [1, 0]])
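    # The prepended zero guarantees that linear index 0 reads a zero;
    # make_tril_ids (below) maps every upper-triangular position to index 0,
    # so the gather fills those positions with zeros.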

    batch_ids = math_ops.range(m)

    def make_tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      cols = array_ops.reshape(array_ops.tile(math_ops.range(n), [n]), [n, n])
      rows = array_ops.tile(
          array_ops.expand_dims(math_ops.range(n), -1), [1, n])
      pred = math_ops.greater(cols, rows)
      tril_ids = array_ops.tile(array_ops.reshape(
          math_ops.cumsum(math_ops.range(n)), [n, 1]), [1, n]) + cols
      tril_ids = math_ops.select(pred,
                                 array_ops.zeros([n, n], dtype=dtypes.int32),
                                 tril_ids + 1)
      tril_ids = array_ops.reshape(tril_ids, [-1])
      return tril_ids
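    # For example, make_tril_ids(3) returns [1, 0, 0, 2, 3, 0, 4, 5, 6], where
    # each 0 selects the prepended zero, i.e., an upper-triangular position.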
    tril_ids = make_tril_ids(n)

    # Assemble the ids into pairs.
    idx = array_ops.pack([
        array_ops.tile(array_ops.expand_dims(batch_ids, -1), [1, n*n]),
        array_ops.tile([tril_ids], [m, 1])])
    idx = array_ops.transpose(idx, [1, 2, 0])
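    # idx now has shape [m, n*n, 2]; each entry is a (batch_id, tril_id) pair
    # consumed by gather_nd below.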

    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]]))

    y.set_shape(y.get_shape().merge_with(final_shape))

    return y
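

# A quick usage sketch for the function above (illustrative only; assumes the
# TF 0.x-era API used here, e.g. `array_ops.pack` and the 2-arg
# `array_ops.concat`):
#
#   x = tf.constant([[1., 2., 3.],
#                    [4., 5., 6.]])    # shape [2, 3], so n = 2
#   tril = fill_lower_triangular(x)    # shape [2, 2, 2]
#   # => [[[1., 0.], [2., 3.]],
#   #     [[4., 0.], [5., 6.]]]
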
def fill_lower_triangular(x,
                          validate_args=False,
                          name="fill_lower_triangular"):
  """Creates a (batch of) lower triangular matrix from a vector of inputs.

  If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1,
  b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))`.

  Although the non-batch complexity is O(n^2), large constants and sub-optimal
  vectorization mean this function is roughly 5x slower than zeroing out the
  upper triangular, i.e., `tf.matrix_band_part(X, -1, 0)`.  This function
  becomes competitive only when several matmul/cholesky/etc ops can be elided
  in constructing the input.  Example: wiring a fully connected layer as a
  covariance matrix; this function halves the size of the final layer and
  possibly reduces the overall network architecture complexity considerably.
  In most cases it is better to simply build a full matrix and zero out the
  upper triangular elements, e.g., `tril = tf.matrix_band_part(full, -1, 0)`,
  rather than directly constructing a lower triangular matrix.

  Example:

  ```python
  fill_lower_triangular([1, 2, 3, 4, 5, 6])
  # Returns: [[1, 0, 0],
  #           [2, 3, 0],
  #           [4, 5, 6]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `distribution_util_test.py`, function `_fill_lower_triangular`.

  Args:
    x: `Tensor` representing lower triangular elements.
    validate_args: `Boolean`, default `False`.  Whether to ensure the shape of
      `x` can be mapped to a lower triangular matrix (controls non-static checks
      only).
    name: `String`. The name to give this op.

  Returns:
    tril: `Tensor` with lower triangular elements filled from `x`.

  Raises:
    ValueError: if `x` has a static shape which cannot be mapped to a lower
      triangular matrix.
  """
  # TODO(jvdillon): Replace this code with dedicated op when it exists.
  with ops.name_scope(name, values=(x,)):
    x = ops.convert_to_tensor(x, name="x")
    if (x.get_shape().ndims is not None and
        x.get_shape()[-1].value is not None):
      d = x.get_shape()[-1].value
      # d = n(n+1)/2 implies n is:
      n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
      d_inferred = n * (n + 1) // 2
      if d != d_inferred:
        raise ValueError("Input cannot be mapped to a lower triangular; "
                         "n*(n+1)/2 = %d != %d" % (d_inferred, d))
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([n, n]))
    else:
      d = array_ops.shape(x)[-1]  # Keep d integral for the reshape below.
      # d = n(n+1)/2 implies n is:
      n = math_ops.cast(
          0.5 * (math_ops.sqrt(1. + 8. * math_ops.cast(d, dtypes.float32))
                 - 1.),
          dtype=dtypes.int32)
      if validate_args:
        is_valid_input_shape = check_ops.assert_equal(
            n * (n + 1) // 2, d,
            message="Input cannot be mapped to a lower triangular.")
        n = control_flow_ops.with_dependencies([is_valid_input_shape], n)
      final_shape = x.get_shape()[:-1].concatenate(
          tensor_shape.TensorShape([None, None]))

    def tril_ids(n):
      """Internal helper to create vector of linear indices into y."""
      # Build the ids statically when possible; 512 was chosen because
      # 512**2 int32 ids occupy exactly 1MiB.
      if not contrib_framework.is_tensor(n) and n <= 512:
        ids = np.arange(n**2, dtype=np.int32)
        rows = (ids // n).astype(np.int32)
        # We need to stop incrementing the index when we encounter
        # upper-triangular elements.  The idea here is to compute the
        # lower-right number of zeros then by "symmetry" subtract this from the
        # total number of zeros, n(n-1)/2.
        # Then we note that: n(n-1)/2 - (n-r)*(n-r-1)/2 = r(2n-r-1)/2
        offset = (rows * (2 * n - rows - 1) // 2).astype(np.int32)
        # We could also zero out when (rows < cols) == (rows < ids-n*rows).
        # mask = (ids <= (n + 1) * rows).astype(np.int32)
      else:
        ids = math_ops.range(n**2)
        rows = math_ops.cast(ids // n, dtype=dtypes.int32)
        offset = math_ops.cast(rows * (2 * n - rows - 1) // 2,
                               dtype=dtypes.int32)
      return ids - offset
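    # For example, tril_ids(3) is [0, 1, 2, 1, 2, 3, 3, 4, 5]: upper-triangular
    # slots re-read earlier elements and are zeroed afterwards by
    # matrix_band_part.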

    # Special-case non-batch case.
    if x.get_shape().ndims == 1:
      y = array_ops.gather(x, array_ops.reshape(tril_ids(n), [n, n]))
      y = array_ops.matrix_band_part(y, -1, 0)
      y.set_shape(y.get_shape().merge_with(final_shape))
      return y

    # Make ids for each batch dim.
    if (x.get_shape().ndims is not None and
        x.get_shape()[:-1].is_fully_defined()):
      batch_shape = np.asarray(x.get_shape()[:-1].as_list(), dtype=np.int32)
      m = np.prod(batch_shape).astype(np.int32)
    else:
      batch_shape = array_ops.shape(x)[:-1]
      m = math_ops.reduce_prod(batch_shape)
    batch_ids = math_ops.range(m)

    # Assemble the tril_ids into batch,tril_id pairs.
    idx = array_ops.stack([
        array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
        array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
    ])
    idx = array_ops.transpose(idx, [1, 2, 0])
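    # idx now has shape [m, n*n, 2]; each entry is a (batch_id, tril_id) pair
    # consumed by gather_nd below.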

    # Gather up, reshape, and return.
    y = array_ops.reshape(x, [-1, d])
    y = array_ops.gather_nd(y, idx)
    y = array_ops.reshape(y, array_ops.concat([batch_shape, [n, n]], 0))
    y = array_ops.matrix_band_part(y, -1, 0)
    y.set_shape(y.get_shape().merge_with(final_shape))
    return y
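

# For reference, a minimal pure-numpy sketch of the same vector-to-tril
# mapping, in the spirit of the `_fill_lower_triangular` test helper
# referenced in the docstring above. The name and body here are illustrative,
# not the actual test code; `np` and `math` are assumed imported as elsewhere
# in this file.
def _np_fill_lower_triangular(x):
  """Numpy-only reference: maps [..., d] to [..., n, n] lower triangular."""
  x = np.asarray(x)
  d = x.shape[-1]
  # Invert d = n(n+1)/2 via the quadratic formula.
  n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.))
  if n * (n + 1) // 2 != d:
    raise ValueError("Input cannot be mapped to a lower triangular.")
  rows, cols = np.tril_indices(n)
  y = np.zeros(x.shape[:-1] + (n, n), dtype=x.dtype)
  # Row-major fill of the lower-triangular positions, matching the docstring
  # example: [1, 2, 3, 4, 5, 6] -> [[1, 0, 0], [2, 3, 0], [4, 5, 6]].
  y[..., rows, cols] = x
  return y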