コード例 #1
0
def diagonal(a, offset=0, axis1=0, axis2=1):
  """Returns the diagonal(s) of `a` taken over the axes `axis1` and `axis2`.

  The two selected axes are moved to the trailing positions and the diagonal
  is extracted with `tf.linalg.diag_part`. The remaining (leading) axes are
  kept as batch dimensions.

  Args:
    a: array-like input; converted with `asarray`.
    offset: which diagonal to take; forwarded as `k` to
      `tf.linalg.diag_part`. If it lies outside the matrix, the result has
      an empty trailing dimension.
    axis1: axis used as the row axis of each 2-D slice.
    axis2: axis used as the column axis of each 2-D slice.

  Returns:
    The diagonal(s), wrapped with `utils.tensor_to_ndarray`.
  """
  tensor = asarray(a).data
  rank = tensor.shape.rank

  # Fast path: main diagonal of the two trailing axes needs no transposition.
  if (rank is not None and offset == 0
      and (axis1 == rank - 2 or axis1 == -2)
      and (axis2 == rank - 1 or axis2 == -1)):
    return utils.tensor_to_ndarray(tf.linalg.diag_part(tensor))

  # General path: bring (axis1, axis2) to positions (-2, -1).
  tensor = moveaxis(
      utils.tensor_to_ndarray(tensor), (axis1, axis2), (-2, -1)).data
  shape = tf.shape(tensor)
  num_rows = utils.getitem(shape, -2)
  num_cols = utils.getitem(shape, -1)

  out_of_bounds = utils.logical_or(
      utils.less_equal(offset, -1 * num_rows),
      utils.greater_equal(offset, num_cols),
  )

  def _empty_columns():
    # Same leading dims but zero columns, paired with offset 0: its main
    # diagonal is empty, which is the answer for any out-of-range offset.
    zeros_shape = tf.concat([shape[:-1], [0]], 0)
    return (tf.zeros(zeros_shape, dtype=tensor.dtype), 0)

  def _unchanged():
    return (tensor, offset)

  # utils.cond may shape-infer both branches, and diag_part's shape inference
  # rejects an out-of-range offset, so the empty matrix is swapped in *before*
  # diag_part rather than branching around the diag_part call itself.
  tensor, offset = utils.cond(out_of_bounds, _empty_columns, _unchanged)
  return utils.tensor_to_ndarray(tf.linalg.diag_part(tensor, k=offset))
コード例 #2
0
 def _diag_part(v, k):
   """Returns the `k`-th diagonal of the matrix `v` via `tf.linalg.diag_part`.

   Rows are read from dimension 0 and columns from dimension 1 of `v`.
   When `k` lies outside the matrix, an empty result is produced instead of
   letting `diag_part` fail.
   """
   shape = tf.shape(v)
   num_rows = utils.getitem(shape, 0)
   num_cols = utils.getitem(shape, 1)
   out_of_range = utils.logical_or(
       utils.less_equal(k, -1 * num_rows),
       utils.greater_equal(k, num_cols),
   )

   def _empty():
     # A 0x0 matrix with k=0 has an empty main diagonal; used in place of an
     # out-of-range k, which diag_part's shape inference would reject.
     return tf.zeros([0, 0], dtype=v.dtype), 0

   def _as_given():
     return v, k

   v, k = utils.cond(out_of_range, _empty, _as_given)
   return tf.linalg.diag_part(v, k=k)
コード例 #3
0
ファイル: math_ops.py プロジェクト: zhaoqiuye/trax
    def f(a, b):
        """Computes the cross product of `a` and `b` along the chosen axes.

        Reads `axisa`, `axisb`, `axisc` and `axis` from the enclosing scope;
        a non-None `axis` replaces all three. Vector axes are moved to the
        last position, and length-2 vectors get a zero third component
        appended before calling `tf.linalg.cross` on the broadcast inputs.
        When both inputs have length 2, only component 2 of the result is
        returned. Otherwise the result's last dimension is moved to `axisc`.
        """
        # Captured variables cannot be rebound inside this closure, so the
        # effective axes are resolved into fresh locals.
        if axis is None:
            axis_a, axis_b, axis_c = axisa, axisb, axisc
        else:
            axis_a = axis_b = axis_c = axis
        if axis_a < 0:
            axis_a = utils.add(axis_a, tf.rank(a))
        if axis_b < 0:
            axis_b = utils.add(axis_b, tf.rank(b))

        def to_last(x, ax):
            """Transposes `x` so that its dimension `ax` becomes the last."""
            order = tf.concat(
                [tf.range(ax), tf.range(ax + 1, tf.rank(x)), [ax]], axis=0)
            return tf.transpose(x, order)

        def vectors_last(x, ax):
            """Like `to_last`, but skips the transpose if `ax` is already last."""
            return utils.cond(ax == utils.subtract(tf.rank(x), 1),
                              lambda: x, lambda: to_last(x, ax))

        def pad_to_3(x):
            """Appends one zero to the last dimension of `x`."""
            widths = tf.concat(
                [tf.zeros([tf.rank(x) - 1, 2], tf.int32),
                 tf.constant([[0, 1]], tf.int32)],
                axis=0)
            return tf.pad(x, widths)

        def pad_if_2(x, last_dim):
            """Applies `pad_to_3` only when the last dimension is 2."""
            return utils.cond(last_dim == 2, lambda: pad_to_3(x), lambda: x)

        def from_last(x, ax):
            """Transposes `x` so that its last dimension moves to index `ax`."""
            last = tf.rank(x) - 1
            order = tf.concat([tf.range(ax), [last], tf.range(ax, last)],
                              axis=0)
            return tf.transpose(x, order)

        a = vectors_last(a, axis_a)
        b = vectors_last(b, axis_b)
        a_dim = utils.getitem(tf.shape(a), -1)
        b_dim = utils.getitem(tf.shape(b), -1)
        c = tf.linalg.cross(
            *utils.tf_broadcast(pad_if_2(a, a_dim), pad_if_2(b, b_dim)))
        if axis_c < 0:
            axis_c = utils.add(axis_c, tf.rank(c))

        def restore_axis():
            return utils.cond(axis_c == utils.subtract(tf.rank(c), 1),
                              lambda: c, lambda: from_last(c, axis_c))

        # Two zero-padded length-2 inputs only produce a nonzero component
        # at index 2, so that single component is the result.
        return utils.cond((a_dim == 2) & (b_dim == 2),
                          lambda: c[..., 2], restore_axis)