Example 1
0
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  """Extracts diagonals of `a` over the plane spanned by `axis1` and `axis2`.

  Mirrors the numpy.diagonal signature: `offset` selects which diagonal
  (positive is above the main diagonal, negative below). `a` is converted
  with `asarray` and the result is wrapped via `utils.tensor_to_ndarray`.
  """
  a = asarray(a).data

  # Fast path: when there is no offset and the two axes are already the
  # innermost pair, tf.linalg.diag_part handles the case directly.
  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return utils.tensor_to_ndarray(tf.linalg.diag_part(a))

  # General path: move the two axes of interest to the last two positions so
  # diag_part can operate on the innermost matrix.
  a = moveaxis(utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data

  a_shape = tf.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    # Replacement used when `offset` is out of bounds: keep the leading dims
    # but make the last dimension size 0, and pair it with offset 0 so the
    # subsequent diag_part is valid and yields an empty diagonal.
    return (tf.zeros(tf.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = utils.cond(
      utils.logical_or(
          utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)),
          utils.greater_equal(offset, utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
  return a
Example 2
0
 def _diag_part(v, k):
   """Returns the k-th diagonal of the matrix `v`.

   tf.linalg.diag_part's shape inference rejects an out-of-range `k`, so an
   out-of-bounds offset is first swapped for an empty matrix with k = 0,
   which yields an empty diagonal instead of an error.
   """

   def _empty():
     return tf.zeros([0, 0], dtype=v.dtype), 0

   def _passthrough():
     return v, k

   shape = tf.shape(v)
   out_of_bounds = utils.logical_or(
       utils.less_equal(k, -1 * utils.getitem(shape, 0)),
       utils.greater_equal(k, utils.getitem(shape, 1)))
   matrix, offset = utils.cond(out_of_bounds, _empty, _passthrough)
   return tf.linalg.diag_part(matrix, k=offset)
Example 3
0
 def f(a, b):  # pylint: disable=missing-docstring
   """Computes the cross product of vector stacks `a` and `b`.

   NOTE(review): depends on `axisa`, `axisb`, `axisc` and `axis` captured
   from the enclosing scope (not visible in this chunk) -- presumably the
   np.cross-style axis arguments; confirm against the outer definition.
   """
   # We can't assign to captured variable `axisa`, so make a new variable
   axis_a = axisa
   axis_b = axisb
   axis_c = axisc
   # A non-None `axis` overrides all three per-array axis arguments.
   if axis is not None:
     axis_a = axis
     axis_b = axis
     axis_c = axis
   # Normalize negative axes to non-negative positions. utils.add is used
   # because tf.rank returns a tensor, so the sum may be graph-valued.
   if axis_a < 0:
     axis_a = utils.add(axis_a, tf.rank(a))
   if axis_b < 0:
     axis_b = utils.add(axis_b, tf.rank(b))
   def maybe_move_axis_to_last(a, axis):
     # Transposes so that `axis` becomes the last dimension, unless it is
     # already last (cond keeps the no-op branch lazily evaluated).
     def move_axis_to_last(a, axis):
       return tf.transpose(
           a, tf.concat(
               [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]],
               axis=0))
     return utils.cond(
         axis == utils.subtract(tf.rank(a), 1),
         lambda: a,
         lambda: move_axis_to_last(a, axis))
   a = maybe_move_axis_to_last(a, axis_a)
   b = maybe_move_axis_to_last(b, axis_b)
   # Sizes of the (now trailing) vector dimension of each operand.
   a_dim = utils.getitem(tf.shape(a), -1)
   b_dim = utils.getitem(tf.shape(b), -1)
   def maybe_pad_0(a, size_of_last_dim):
     # Promotes 2-vectors to 3-vectors by appending a zero z-component,
     # since tf.linalg.cross operates on 3-vectors only.
     def pad_0(a):
       return tf.pad(a, tf.concat([tf.zeros([tf.rank(a) - 1, 2], tf.int32),
                                   tf.constant([[0, 1]], tf.int32)], axis=0))
     return utils.cond(size_of_last_dim == 2,
                       lambda: pad_0(a),
                       lambda: a)
   a = maybe_pad_0(a, a_dim)
   b = maybe_pad_0(b, b_dim)
   c = tf.linalg.cross(*utils.tf_broadcast(a, b))
   if axis_c < 0:
     axis_c = utils.add(axis_c, tf.rank(c))
   def move_last_to_axis(a, axis):
     # Inverse of move_axis_to_last: inserts the last dimension at `axis`.
     r = tf.rank(a)
     return tf.transpose(
         a, tf.concat(
             [tf.range(axis), [r - 1], tf.range(axis, r - 1)], axis=0))
   # If both inputs were 2-vectors the cross product is just the scalar
   # z-component; otherwise move the vector dimension back to axis_c when it
   # is not already last.
   c = utils.cond(
       (a_dim == 2) & (b_dim == 2),
       lambda: c[..., 2],
       lambda: utils.cond(  # pylint: disable=g-long-lambda
           axis_c == utils.subtract(tf.rank(c), 1),
           lambda: c,
           lambda: move_last_to_axis(c, axis_c)))
   return c
Example 4
0
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
    """Returns the sum along diagonals of `a`, like numpy.trace.

    The diagonal is taken over the plane spanned by `axis1`/`axis2` with the
    given `offset`, then summed via `array_ops.sum`, optionally using `dtype`
    for the accumulation/result.
    """
    a = array_ops.asarray(a).data

    # Fast path: zero offset over the innermost two axes maps directly onto
    # tf.linalg.trace (only usable when the static rank is known).
    if offset == 0:
        a_shape = a.shape
        if a_shape.rank is not None:
            rank = len(a_shape)
            if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1
                                                       or axis2 == rank - 1):
                return utils.tensor_to_ndarray(tf.linalg.trace(a))

    # Normalize negative axes. tf.rank returns a tensor, so after this the
    # axes may be graph-valued.
    a_rank = tf.rank(a)
    if axis1 < 0:
        axis1 += a_rank
    if axis2 < 0:
        axis2 += a_rank

    minaxis = tf.minimum(axis1, axis2)
    maxaxis = tf.maximum(axis1, axis2)

    # Move axes of interest to the end.
    # The three range slices enumerate all other axes (skipping minaxis and
    # maxaxis), then axis1/axis2 are appended as the last two dimensions.
    range_rank = tf.range(a_rank)
    perm = tf.concat([
        range_rank[0:minaxis], range_rank[minaxis + 1:maxaxis],
        range_rank[maxaxis + 1:], [axis1, axis2]
    ],
                     axis=0)
    a = tf.transpose(a, perm)

    a_shape = tf.shape(a)

    # All zeros since diag_part doesn't handle all possible k (aka offset).
    # Written this way since cond will run shape inference on both branches,
    # and diag_part shape inference will fail when offset is out of bounds.
    a, offset = utils.cond(
        utils.logical_or(
            utils.less_equal(offset, -1 * utils.getitem(a_shape, -2)),
            utils.greater_equal(offset, utils.getitem(a_shape, -1)),
        ), lambda: (tf.zeros_like(a), 0), lambda: (a, offset))

    # Extract the offset diagonal, then sum it along the last axis.
    a = utils.tensor_to_ndarray(tf.linalg.diag_part(a, k=offset))
    return array_ops.sum(a, -1, dtype)