Esempio n. 1
0
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
    """Return the `offset` diagonal of the planes spanned by two axes of `a`.

    Args:
        a: Input accepted by `asarray`.
        offset: Diagonal index, passed to `matrix_diag_part` as `k`.
        axis1: First axis of the planes the diagonal is taken from.
        axis2: Second axis of the planes the diagonal is taken from.

    Returns:
        The `matrix_diag_part` result wrapped by `np_utils.tensor_to_ndarray`.
        When `offset` is out of bounds for the selected planes, the diagonal
        is taken from a zero tensor whose last dimension is 0.
    """
    tensor = asarray(a).data

    # Fast path: main diagonal of the trailing two axes, no axis shuffling.
    static_rank = tensor.shape.rank
    if static_rank is not None and offset == 0:
        axis1_is_penultimate = axis1 == static_rank - 2 or axis1 == -2
        if axis1_is_penultimate and (axis2 == static_rank - 1
                                     or axis2 == -1):
            return np_utils.tensor_to_ndarray(
                array_ops.matrix_diag_part(tensor))

    # General path: move the requested axes into the last two positions.
    moved = moveaxis(
        np_utils.tensor_to_ndarray(tensor), (axis1, axis2), (-2, -1)).data
    moved_shape = array_ops.shape(moved)

    def _empty_input():
        """Zero tensor with a 0-sized last dimension, paired with offset 0."""
        empty_shape = array_ops.concat([moved_shape[:-1], [0]], 0)
        return array_ops.zeros(empty_shape, dtype=moved.dtype), 0

    def _original_input():
        """The moved tensor and the caller's offset, unchanged."""
        return moved, offset

    # diag_part doesn't handle every possible k (aka offset), so an
    # out-of-bounds offset is swapped for an all-zeros input with k=0.
    # Done via cond because cond runs shape inference on both branches, and
    # diag_part shape inference fails when the offset is out of bounds.
    too_low = np_utils.less_equal(
        offset, -1 * np_utils.getitem(moved_shape, -2))
    too_high = np_utils.greater_equal(
        offset, np_utils.getitem(moved_shape, -1))
    diag_input, diag_offset = np_utils.cond(
        np_utils.logical_or(too_low, too_high), _empty_input,
        _original_input)

    return np_utils.tensor_to_ndarray(
        array_ops.matrix_diag_part(diag_input, k=diag_offset))
Esempio n. 2
0
def diag(v, k=0):  # pylint: disable=missing-docstring
    """Build a diagonal matrix from rank-1 input or extract one from rank-2.

    Raises an error if input is not 1- or 2-d.

    Args:
        v: Input accepted by `asarray`.
        k: Diagonal index, forwarded to `matrix_diag` / `matrix_diag_part`.

    Returns:
        The selected branch's result wrapped by `np_utils.tensor_to_ndarray`.
    """
    tensor = asarray(v).data
    dynamic_rank = array_ops.rank(tensor)

    # Return value is discarded; the call is made only for its validation
    # effect on the static shape (presumably raising above rank 2).
    tensor.shape.with_rank_at_most(2)

    # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
    # tracing time if the shape is known.
    rank_is_one_or_two = np_utils.logical_or(
        math_ops.equal(dynamic_rank, 1), math_ops.equal(dynamic_rank, 2))
    control_flow_ops.Assert(rank_is_one_or_two, [dynamic_rank])

    def _build_matrix():
        """Rank-1 case: zeros of shape [|k|, |k|] for empty input."""
        input_is_empty = math_ops.equal(array_ops.size(tensor), 0)
        return np_utils.cond(
            input_is_empty,
            lambda: array_ops.zeros([abs(k), abs(k)], dtype=tensor.dtype),
            lambda: array_ops.matrix_diag(tensor, k=k))

    def _extract_diagonal():
        """Rank-2 case: out-of-bounds k reads from an empty matrix at k=0."""
        dims = array_ops.shape(tensor)
        too_low = np_utils.less_equal(k, -1 * np_utils.getitem(dims, 0))
        too_high = np_utils.greater_equal(k, np_utils.getitem(dims, 1))
        source, index = np_utils.cond(
            np_utils.logical_or(too_low, too_high),
            lambda: (array_ops.zeros([0, 0], dtype=tensor.dtype), 0),
            lambda: (tensor, k))
        return array_ops.matrix_diag_part(source, k=index)

    selected = np_utils.cond(
        math_ops.equal(dynamic_rank, 1), _build_matrix, _extract_diagonal)
    return np_utils.tensor_to_ndarray(selected)
Esempio n. 3
0
 def f(a, b):  # pylint: disable=missing-docstring
     """Multiply or contract `a` and `b` depending on their ranks.

     If either operand has rank 0 the result is `a * b`. Otherwise the last
     axis of `a` is contracted with the last axis of `b` when `b` has rank 1,
     and with the second-to-last axis of `b` in all other cases.
     """
     either_is_scalar = np_utils.logical_or(
         math_ops.equal(array_ops.rank(a), 0),
         math_ops.equal(array_ops.rank(b), 0))

     def _scalar_product():
         return a * b

     def _contract_with_last_axis():
         return math_ops.tensordot(a, b, axes=[[-1], [-1]])

     def _contract_with_second_to_last_axis():
         return math_ops.tensordot(a, b, axes=[[-1], [-2]])

     def _contract():
         # The contracted axis of `b` depends on whether `b` has rank 1.
         b_is_vector = math_ops.equal(array_ops.rank(b), 1)
         return np_utils.cond(b_is_vector, _contract_with_last_axis,
                              _contract_with_second_to_last_axis)

     return np_utils.cond(either_is_scalar, _scalar_product, _contract)
Esempio n. 4
0
 def _diag_part(v, k):
     """Return the `k` diagonal of `v`, tolerating an out-of-bounds `k`.

     When `k` is at or below minus the first dimension of `v`, or at or above
     its second dimension, the diagonal is taken from a [0, 0] zero matrix
     with `k=0` instead of from `v`.
     """
     dims = array_ops.shape(v)
     too_low = np_utils.less_equal(k, -1 * np_utils.getitem(dims, 0))
     too_high = np_utils.greater_equal(k, np_utils.getitem(dims, 1))

     def _empty_matrix():
         return array_ops.zeros([0, 0], dtype=v.dtype), 0

     def _as_given():
         return v, k

     source, index = np_utils.cond(
         np_utils.logical_or(too_low, too_high), _empty_matrix, _as_given)
     return array_ops.matrix_diag_part(source, k=index)
Esempio n. 5
0
 def f(a, b):
     """Multiply `a` and `b` if either has rank 0, else contract last axes.

     In the non-scalar case the last axis of `a` is contracted with the last
     axis of `b` via `tensordot`.
     """
     a_is_scalar = math_ops.equal(array_ops.rank(a), 0)
     b_is_scalar = math_ops.equal(array_ops.rank(b), 0)

     def _scalar_product():
         return a * b

     def _last_axis_contraction():
         return math_ops.tensordot(a, b, axes=[[-1], [-1]])

     return np_utils.cond(
         np_utils.logical_or(a_is_scalar, b_is_scalar),
         _scalar_product, _last_axis_contraction)