Example #1
0
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  """Extracts the `offset`-th diagonal of `a` over axes `axis1` and `axis2`.

  The two axes are moved to the end and `tf.linalg.diag_part` is applied, so
  the diagonal becomes the trailing dimension of the result. An `offset` that
  lies outside the matrix bounds produces an empty trailing dimension rather
  than an error.

  Args:
    a: array-like input, converted with `asarray`.
    offset: diagonal index, forwarded as `k` to `tf.linalg.diag_part`.
    axis1: first axis of the 2-D sub-arrays the diagonal is taken from.
    axis2: second axis of the 2-D sub-arrays the diagonal is taken from.

  Returns:
    The diagonal as an ndarray (via `utils.tensor_to_ndarray`).
  """
  tensor = asarray(a).data

  # Fast path: the main diagonal of the two trailing axes needs neither a
  # transpose nor a bounds check.
  rank = tensor.shape.rank
  if rank is not None and offset == 0:
    first_is_second_last = axis1 == rank - 2 or axis1 == -2
    if first_is_second_last and (axis2 == rank - 1 or axis2 == -1):
      return utils.tensor_to_ndarray(tf.linalg.diag_part(tensor))

  moved = moveaxis(utils.tensor_to_ndarray(tensor), (axis1, axis2),
                   (-2, -1)).data
  dims = tf.shape(moved)

  def _empty_operand():
    # A (..., rows, 0) zero tensor paired with k=0: its diagonal is empty.
    return (tf.zeros(tf.concat([dims[:-1], [0]], 0), dtype=moved.dtype), 0)

  # diag_part doesn't handle every possible k (offset), so an out-of-range
  # offset is swapped for an empty operand. This is done via cond on the inputs
  # because cond runs shape inference on both branches, and diag_part shape
  # inference fails when the offset is out of bounds.
  out_of_range = utils.logical_or(
      utils.less_equal(offset, -1 * utils.getitem(dims, -2)),
      utils.greater_equal(offset, utils.getitem(dims, -1)),
  )
  operand, k = utils.cond(out_of_range, _empty_operand,
                          lambda: (moved, offset))
  return utils.tensor_to_ndarray(tf.linalg.diag_part(operand, k=k))
Example #2
0
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Constructs a diagonal matrix or extracts a diagonal, like `np.diag`.

  A 1-D `v` is placed on the `k`-th diagonal of a square matrix; a 2-D `v` has
  its `k`-th diagonal extracted. Raises an error if input is not 1- or 2-d.

  Args:
    v: array-like input of rank 1 or 2, converted with `asarray`.
    k: diagonal index, forwarded to `tf.linalg.diag` / `tf.linalg.diag_part`.

  Returns:
    The result as an ndarray (via `utils.tensor_to_ndarray`).
  """
  tensor = asarray(v).data
  rank = tf.rank(tensor)

  # Static check against the known rank (if any).
  tensor.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a utils.Assert version that will fail during
  # tracing time if the shape is known.
  tf.debugging.Assert(
      utils.logical_or(tf.equal(rank, 1), tf.equal(rank, 2)), [rank])

  def _build_matrix(vec, index):
    # An empty vector becomes an |index| x |index| zero matrix.
    def _zeros():
      side = abs(index)
      return tf.zeros([side, side], dtype=vec.dtype)

    return utils.cond(
        tf.equal(tf.size(vec), 0), _zeros,
        lambda: tf.linalg.diag(vec, k=index))

  def _extract_diagonal(mat, index):
    # An out-of-range index is replaced by an empty 0x0 matrix with index 0,
    # giving an empty diagonal instead of a diag_part failure.
    mat_dims = tf.shape(mat)
    out_of_range = utils.logical_or(
        utils.less_equal(index, -1 * utils.getitem(mat_dims, 0)),
        utils.greater_equal(index, utils.getitem(mat_dims, 1)),
    )
    operand, operand_k = utils.cond(
        out_of_range,
        lambda: (tf.zeros([0, 0], dtype=mat.dtype), 0),
        lambda: (mat, index))
    return tf.linalg.diag_part(operand, k=operand_k)

  result = utils.cond(
      tf.equal(rank, 1),
      lambda: _build_matrix(tensor, k),
      lambda: _extract_diagonal(tensor, k))
  return utils.tensor_to_ndarray(result)
Example #3
0
 def f(a, b):  # pylint: disable=missing-docstring
     """Product of `a` and `b` following the case split of `numpy.dot`.

     A scalar operand gives an element-wise product. A 1-D `b` is contracted
     over its last axis; otherwise `b`'s second-to-last axis is used. In both
     contraction cases `a`'s last axis is the other side of the contraction.
     """
     def _with_vector_b():
         return tf.tensordot(a, b, axes=[[-1], [-1]])

     def _with_matrix_b():
         return tf.tensordot(a, b, axes=[[-1], [-2]])

     def _contract():
         return utils.cond(tf.rank(b) == 1, _with_vector_b, _with_matrix_b)

     has_scalar = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
     return utils.cond(has_scalar, lambda: a * b, _contract)
Example #4
0
 def _diag_part(v, k):
   """Returns the `k`-th diagonal of `v`, or an empty one if `k` is out of range.

   Reads dimensions 0 and 1 of `v`, so it presumably expects a 2-D input.
   """
   dims = tf.shape(v)
   too_low = utils.less_equal(k, -1 * utils.getitem(dims, 0))
   too_high = utils.greater_equal(k, utils.getitem(dims, 1))

   def _empty():
     # A 0x0 matrix with k=0 has an empty diagonal.
     return tf.zeros([0, 0], dtype=v.dtype), 0

   def _unchanged():
     return v, k

   operand, index = utils.cond(
       utils.logical_or(too_low, too_high), _empty, _unchanged)
   return tf.linalg.diag_part(operand, k=index)
Example #5
0
 def f(a, b):
     """Multiplies if either operand is a scalar, else contracts last axes.

     The contraction pairs the last axis of `a` with the last axis of `b`.
     """
     def _contract_last_axes():
         return tf.tensordot(a, b, axes=[[-1], [-1]])

     either_scalar = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
     return utils.cond(either_scalar, lambda: a * b, _contract_last_axes)