Beispiel #1
0
def diagonal(a, offset=0, axis1=0, axis2=1):
    """Returns the `offset` diagonal of `a` taken over axes `axis1` and `axis2`.

    Args:
      a: array-like input; converted with `asarray` and unwrapped via `.data`.
      offset: diagonal index, passed to `array_ops.matrix_diag_part` as `k`.
      axis1: first axis of the 2-D slices whose diagonals are extracted.
      axis2: second axis of those slices.

    Returns:
      An ndarray (built with `np_utils.tensor_to_ndarray`) holding the
      diagonals. When `offset` is `<= -rows` or `>= cols` of the slices, the
      diagonal axis has size 0.
    """
    tensor = asarray(a).data
    static_rank = tensor.shape.rank

    # Fast path: a zero offset over the two trailing axes maps straight onto
    # matrix_diag_part with no transposition or bounds handling. The rank check
    # must come first so `static_rank - 2` is never evaluated on None.
    if (static_rank is not None and offset == 0
            and (axis1 == static_rank - 2 or axis1 == -2)
            and (axis2 == static_rank - 1 or axis2 == -1)):
        return np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(tensor))

    # Bring the requested axes to the end so the matrix op works on the
    # trailing two dimensions.
    moved = moveaxis(np_utils.tensor_to_ndarray(tensor), (axis1, axis2),
                     (-2, -1)).data
    dims = array_ops.shape(moved)

    def _empty_diagonal():
        # Same leading shape with a zero-width last axis, paired with k=0, so
        # the diagonal extracted below is empty.
        zero_width = array_ops.concat([dims[:-1], [0]], 0)
        return array_ops.zeros(zero_width, dtype=moved.dtype), 0

    # matrix_diag_part does not accept every possible `k`. It is applied after
    # the cond rather than inside its branches, because cond runs shape
    # inference on both branches and matrix_diag_part's shape inference fails
    # for an out-of-bounds offset.
    offset_out_of_range = np_utils.logical_or(
        np_utils.less_equal(offset, -1 * np_utils.getitem(dims, -2)),
        np_utils.greater_equal(offset, np_utils.getitem(dims, -1)),
    )
    matrix, k = np_utils.cond(offset_out_of_range, _empty_diagonal,
                              lambda: (moved, offset))
    return np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(matrix, k=k))
Beispiel #2
0
 def _diag_part(v, k):
     """Returns diagonal `k` of `v`, or an empty diagonal when `k` is out of range.

     Rows and columns are read from dimensions 0 and 1 of `v`'s shape
     (presumably `v` is a 2-D matrix -- confirm against the caller).
     """
     dims = array_ops.shape(v)
     num_rows = np_utils.getitem(dims, 0)
     num_cols = np_utils.getitem(dims, 1)

     def _empty():
         # A 0x0 matrix with k=0 has an empty main diagonal.
         return array_ops.zeros([0, 0], dtype=v.dtype), 0

     # Reroute offsets outside (-rows, cols) to the empty input instead of
     # handing them to matrix_diag_part.
     out_of_range = np_utils.logical_or(
         np_utils.less_equal(k, -1 * num_rows),
         np_utils.greater_equal(k, num_cols),
     )
     matrix, diag_k = np_utils.cond(out_of_range, _empty, lambda: (v, k))
     return array_ops.matrix_diag_part(matrix, k=diag_k)
Beispiel #3
0
    def f(a, b):
        """Computes the cross product of `a` and `b` along the resolved axes.

        Reads `axis`, `axisa`, `axisb` and `axisc` from the enclosing scope; a
        non-None `axis` overrides the other three. The selected axes are
        transposed to the end. A last axis of length 2 is padded to length 3
        before `math_ops.cross`. When both inputs have length 2, only component
        2 of the product is returned and `axisc` is not applied. Otherwise the
        product's last axis is moved to position `axisc`.
        """
        # The captured `axisa`/`axisb`/`axisc` cannot be rebound from this
        # nested scope, so resolve them into fresh locals instead.
        if axis is not None:
            axis_a = axis_b = axis_c = axis
        else:
            axis_a, axis_b, axis_c = axisa, axisb, axisc

        def _wrap(ax, t):
            # Convert a negative axis to a non-negative one using t's rank.
            if ax < 0:
                return np_utils.add(ax, array_ops.rank(t))
            return ax

        def _to_last(t, ax):
            # Transpose axis `ax` of `t` to the end, unless it is already last.
            def _transpose():
                perm = array_ops.concat(
                    [
                        math_ops.range(ax),
                        math_ops.range(ax + 1, array_ops.rank(t)),
                        [ax],
                    ],
                    axis=0)
                return array_ops.transpose(t, perm)

            is_last = ax == np_utils.subtract(array_ops.rank(t), 1)
            return np_utils.cond(is_last, lambda: t, _transpose)

        def _pad_len2(t, last_dim):
            # Grow a length-2 last axis to length 3 by appending one element
            # (array_ops.pad's default constant fill).
            def _pad():
                paddings = array_ops.concat(
                    [
                        array_ops.zeros([array_ops.rank(t) - 1, 2],
                                        dtypes.int32),
                        constant_op.constant([[0, 1]], dtypes.int32),
                    ],
                    axis=0)
                return array_ops.pad(t, paddings)

            return np_utils.cond(math_ops.equal(last_dim, 2), _pad, lambda: t)

        def _from_last(t, ax):
            # Transpose the last axis of `t` to position `ax`.
            rank = array_ops.rank(t)
            perm = array_ops.concat(
                [math_ops.range(ax), [rank - 1], math_ops.range(ax, rank - 1)],
                axis=0)
            return array_ops.transpose(t, perm)

        # Wrap the input axes against the original (untransposed) operands.
        axis_a = _wrap(axis_a, a)
        axis_b = _wrap(axis_b, b)
        a = _to_last(a, axis_a)
        b = _to_last(b, axis_b)
        a_dim = np_utils.getitem(array_ops.shape(a), -1)
        b_dim = np_utils.getitem(array_ops.shape(b), -1)
        a = _pad_len2(a, a_dim)
        b = _pad_len2(b, b_dim)
        c = math_ops.cross(*np_utils.tf_broadcast(a, b))
        axis_c = _wrap(axis_c, c)

        def _place_result():
            # Move the vector axis of the product to `axis_c` if needed.
            is_last = axis_c == np_utils.subtract(array_ops.rank(c), 1)
            return np_utils.cond(is_last, lambda: c,
                                 lambda: _from_last(c, axis_c))

        both_len2 = (a_dim == 2) & (b_dim == 2)
        return np_utils.cond(both_len2, lambda: c[..., 2], _place_result)