Ejemplo n.º 1
0
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  """Returns the `offset`-th diagonal of `a` taken over `axis1` and `axis2`.

  The two selected axes are moved to the end and the diagonal is extracted
  with `tf.linalg.diag_part`, so the diagonal values end up along the last
  axis of the result.

  Args:
    a: Array-like input; converted with `asarray`.
    offset: Diagonal offset, forwarded as `k` to `tf.linalg.diag_part`.
      Offsets that fall entirely outside the matrices yield an empty diagonal.
    axis1: Axis used as the matrix row axis.
    axis2: Axis used as the matrix column axis.

  Returns:
    The extracted diagonal(s), wrapped by `utils.tensor_to_ndarray`.
  """
  x = asarray(a).data

  static_rank = x.shape.rank
  if static_rank is not None and offset == 0:
    rows_trailing = axis1 == static_rank - 2 or axis1 == -2
    if rows_trailing and (axis2 == static_rank - 1 or axis2 == -1):
      # Fast path: the requested axes already are the trailing matrix axes.
      return utils.tensor_to_ndarray(tf.linalg.diag_part(x))

  # Put (axis1, axis2) last, which is where diag_part looks for the matrix.
  x = moveaxis(utils.tensor_to_ndarray(x), (axis1, axis2), (-2, -1)).data
  dims = tf.shape(x)
  num_rows = utils.getitem(dims, -2)
  num_cols = utils.getitem(dims, -1)
  out_of_range = utils.logical_or(
      utils.less_equal(offset, -1 * num_rows),
      utils.greater_equal(offset, num_cols),
  )

  def _empty_columns():
    # A [..., rows, 0] input has a zero-length diagonal for k=0.
    return (tf.zeros(tf.concat([dims[:-1], [0]], 0), dtype=x.dtype), 0)

  def _as_is():
    return (x, offset)

  # diag_part does not accept every offset, and cond runs shape inference on
  # both branches (diag_part's fails for out-of-bounds offsets), so such
  # offsets are redirected to an empty input with k=0 instead.
  x, k = utils.cond(out_of_range, _empty_columns, _as_is)
  return utils.tensor_to_ndarray(tf.linalg.diag_part(x, k=k))
Ejemplo n.º 2
0
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Builds a diagonal matrix from 1-D input or extracts a diagonal from 2-D.

  Args:
    v: Array-like input; must be 1-D or 2-D.
    k: Diagonal offset, forwarded as `k` to `tf.linalg.diag` /
      `tf.linalg.diag_part`.

  Returns:
    For 1-D `v`, a matrix with `v` on its `k`-th diagonal (an |k| x |k| zero
    matrix when `v` is empty). For 2-D `v`, its `k`-th diagonal, empty when
    `k` lies outside the matrix. Wrapped by `utils.tensor_to_ndarray`.

  Raises:
    An error from `with_rank_at_most` when the static rank exceeds 2; the
    `tf.debugging.Assert` check fails when the runtime rank is not 1 or 2.
  """
  t = asarray(v).data
  rank = tf.rank(t)

  # Static rank check: only triggers when the rank is known and too large.
  t.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a utils.Assert version that will fail during
  # tracing time if the shape is known.
  is_vector = tf.equal(rank, 1)
  tf.debugging.Assert(utils.logical_or(is_vector, tf.equal(rank, 2)), [rank])

  def _vector_to_matrix(vec, offset):
    # tf.linalg.diag is bypassed for empty input; the result is an
    # |offset| x |offset| zero matrix instead.
    return utils.cond(
        tf.equal(tf.size(vec), 0),
        lambda: tf.zeros([abs(offset), abs(offset)], dtype=vec.dtype),
        lambda: tf.linalg.diag(vec, k=offset))

  def _matrix_to_vector(mat, offset):
    dims = tf.shape(mat)
    too_low = utils.less_equal(offset, -1 * utils.getitem(dims, 0))
    too_high = utils.greater_equal(offset, utils.getitem(dims, 1))
    # Out-of-range offsets are redirected to an empty 0x0 matrix with k=0, so
    # diag_part never sees an offset it cannot handle.
    mat, offset = utils.cond(
        utils.logical_or(too_low, too_high),
        lambda: (tf.zeros([0, 0], dtype=mat.dtype), 0),
        lambda: (mat, offset))
    return tf.linalg.diag_part(mat, k=offset)

  result = utils.cond(
      is_vector,
      lambda: _vector_to_matrix(t, k),
      lambda: _matrix_to_vector(t, k))
  return utils.tensor_to_ndarray(result)
Ejemplo n.º 3
0
 def f(a, b):  # pylint: disable=missing-docstring
     """Multiplies or contracts `a` and `b` depending on their ranks.

     If either operand has rank 0, the result is the elementwise product
     `a * b`. Otherwise the last axis of `a` is contracted with the last axis
     of `b` when `b` has rank 1, and with the second-to-last axis of `b`
     otherwise.
     """
     a_rank = tf.rank(a)
     b_rank = tf.rank(b)
     # Use tf.equal instead of `==`. When tensor equality is disabled (TF1
     # graph mode or tf.compat.v1.disable_tensor_equality()),
     # Tensor.__eq__ compares by object identity, so `tf.rank(x) == 0` is
     # always False and scalars would be routed into tensordot.
     return utils.cond(
         utils.logical_or(tf.equal(a_rank, 0), tf.equal(b_rank, 0)),
         lambda: a * b,
         lambda: utils.cond(  # pylint: disable=g-long-lambda
             tf.equal(b_rank, 1),
             lambda: tf.tensordot(a, b, axes=[[-1], [-1]]),
             lambda: tf.tensordot(a, b, axes=[[-1], [-2]])))
Ejemplo n.º 4
0
 def _diag_part(v, k):
   """Returns the `k`-th diagonal of `v`, or an empty result if out of range.

   Assumes `v` is a matrix: dimensions 0 and 1 of its shape bound the valid
   offsets (TODO confirm against callers).
   """
   dims = tf.shape(v)
   below_all = utils.less_equal(k, -1 * utils.getitem(dims, 0))
   above_all = utils.greater_equal(k, utils.getitem(dims, 1))
   # Offsets past the matrix edges are redirected to an empty 0x0 matrix with
   # k=0, so diag_part is never asked for a diagonal that does not exist.
   matrix, offset = utils.cond(
       utils.logical_or(below_all, above_all),
       lambda: (tf.zeros([0, 0], dtype=v.dtype), 0),
       lambda: (v, k))
   return tf.linalg.diag_part(matrix, k=offset)
Ejemplo n.º 5
0
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
    """Returns the sum along a diagonal of `a`.

    The `offset`-th diagonal of each 2-D sub-array spanned by `axis1` and
    `axis2` is summed; the remaining axes are kept.

    Args:
        a: Array-like input, converted with `array_ops.asarray`.
        offset: Diagonal offset, passed as `k` to `tf.linalg.diag_part`.
            Offsets lying entirely outside the sub-arrays give a sum of zero.
        axis1: First axis of the 2-D sub-arrays.
        axis2: Second axis of the 2-D sub-arrays.
        dtype: Optional dtype forwarded to `array_ops.sum` for the reduction.

    Returns:
        The diagonal sums as an ndarray.
    """
    a = array_ops.asarray(a).data

    # Fast path: tf.linalg.trace covers offset 0 over the last two axes, but
    # it takes no dtype. Use it only when no dtype is requested; otherwise
    # the requested dtype would be silently ignored.
    if offset == 0 and dtype is None:
        static_shape = a.shape
        if static_shape.rank is not None:
            rank = len(static_shape)
            if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1
                                                       or axis2 == rank - 1):
                return utils.tensor_to_ndarray(tf.linalg.trace(a))

    a_rank = tf.rank(a)
    # Normalize negative axes (the axes become tensors when this happens).
    if axis1 < 0:
        axis1 += a_rank
    if axis2 < 0:
        axis2 += a_rank

    minaxis = tf.minimum(axis1, axis2)
    maxaxis = tf.maximum(axis1, axis2)

    # Move axes of interest to the end, in the order (axis1, axis2), so the
    # sign of `offset` keeps its meaning relative to axis1/axis2.
    range_rank = tf.range(a_rank)
    perm = tf.concat([
        range_rank[0:minaxis], range_rank[minaxis + 1:maxaxis],
        range_rank[maxaxis + 1:], [axis1, axis2]
    ],
                     axis=0)
    a = tf.transpose(a, perm)

    a_shape = tf.shape(a)

    # diag_part doesn't handle all possible k (aka offset), and cond runs shape
    # inference on both branches (diag_part's fails for out-of-bounds
    # offsets). Out-of-range offsets are therefore replaced by an all-zeros
    # input with k=0, whose diagonal sums to zero.
    a, offset = utils.cond(
        utils.logical_or(
            utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)),
            utils.greater_equal(offset, utils.getitem(a_shape, -1)),
        ), lambda: (tf.zeros_like(a), 0), lambda: (a, offset))

    a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
    return array_ops.sum(a, -1, dtype)
Ejemplo n.º 6
0
 def f(a, b):
     """Multiplies or contracts `a` and `b` over their last axes.

     If either operand has rank 0, the result is the elementwise product
     `a * b`. Otherwise the last axis of `a` is contracted with the last axis
     of `b`.
     """
     # Use tf.equal instead of `==`. When tensor equality is disabled (TF1
     # graph mode or tf.compat.v1.disable_tensor_equality()),
     # Tensor.__eq__ compares by object identity, so `tf.rank(x) == 0` is
     # always False and scalars would be routed into tensordot.
     return utils.cond(
         utils.logical_or(tf.equal(tf.rank(a), 0), tf.equal(tf.rank(b), 0)),
         lambda: a * b,
         lambda: tf.tensordot(a, b, axes=[[-1], [-1]]))