Esempio n. 1
0
def _get_multiply_adjoint_sum_axes(oshape, ishape, mshape):
    """Return the aligned axes where the input had to be broadcast.

    ``ishape`` and ``mshape`` are first aligned with
    ``util._expand_shapes`` (presumably NumPy-style left-padding with 1s;
    confirm in ``util``). An axis index is reported when the input has
    size 1 there while the multiplier or the output does not. Judging by
    the name, the adjoint of the multiply sums over these axes; confirm
    against the caller.

    Args:
        oshape: Output shape of the multiply, in the same aligned
            (expanded) axis coordinates.
        ishape: Input shape.
        mshape: Multiplier shape.

    Returns:
        list of int: Ascending axis indices in the aligned coordinates.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    ndim = max(len(ishape), len(mshape))
    # ``range(ndim)`` supplies the axis index and also bounds how many
    # axes are scanned, together with the other zipped sequences.
    return [
        axis
        for axis, in_dim, mult_dim, out_dim in zip(
            range(ndim), ishape_exp, mshape_exp, oshape)
        if in_dim == 1 and not (mult_dim == 1 and out_dim == 1)
    ]
Esempio n. 2
0
def _get_multiply_oshape(ishape, mshape):
    """Return the output shape of an element-wise multiply with broadcasting.

    The two shapes are aligned with ``util._expand_shapes`` (presumably
    left-padding the shorter one with 1s; confirm in ``util``). They are
    then combined axis by axis. Sizes must be equal or one of them must
    be 1, and the output takes the larger size.

    Args:
        ishape: Input shape.
        mshape: Multiplier shape.

    Returns:
        list of int: The broadcast output shape.

    Raises:
        ValueError: If the shapes are not broadcast-compatible.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    oshape = []
    for i, m in zip(ishape_exp, mshape_exp):
        # Broadcast rule: equal sizes, or either side is a singleton.
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))

        oshape.append(max(i, m))

    return oshape
Esempio n. 3
0
def _get_matmul_oshape(ishape, mshape, adjoint):
    """Return the output shape of the batched matrix product ``mat @ input``.

    The matrix is applied from the left. Shapes are aligned with
    ``util._expand_shapes`` (presumably left-padding with 1s; confirm in
    ``util``). All axes except the last two broadcast element-wise. The
    last two axes follow matrix-product rules.

    Args:
        ishape: Input shape; its last two axes are the matrix dimensions.
        mshape: Matrix shape; its last two axes are the matrix dimensions.
        adjoint: If true, the matrix's last two axes are swapped before the
            check, i.e. the transposed matrix shape is used.

    Returns:
        list of int: Broadcast batch axes followed by
        ``[matrix rows, input columns]``.

    Raises:
        ValueError: If either aligned shape has fewer than two axes, the
            batch axes are not broadcast-compatible, or the inner
            dimensions disagree.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    # Both shapes need two trailing matrix axes. Without this guard the
    # [-2] indexing below fails with a bare IndexError instead of the
    # ValueError raised for every other invalid shape.
    if len(ishape_exp) < 2 or len(mshape_exp) < 2:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    if adjoint:
        mshape_exp[-1], mshape_exp[-2] = mshape_exp[-2], mshape_exp[-1]

    oshape = []
    # Leading (batch) axes broadcast like an element-wise operation.
    for i, m in zip(ishape_exp[:-2], mshape_exp[:-2]):
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))

        oshape.append(max(i, m))

    # Inner dimensions of ``mat @ input`` must agree.
    if mshape_exp[-1] != ishape_exp[-2]:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    oshape += [mshape_exp[-2], ishape_exp[-1]]

    return oshape
Esempio n. 4
0
def _get_right_matmul_oshape(ishape, mshape, adjoint):
    """Return the output shape of the batched matrix product ``input @ mat``.

    The matrix is applied from the right. Shapes are aligned with
    ``util._expand_shapes`` (presumably left-padding with 1s; confirm in
    ``util``). All axes except the last two broadcast element-wise. The
    last two axes follow matrix-product rules.

    Args:
        ishape: Input shape; its last two axes are the matrix dimensions.
        mshape: Matrix shape; its last two axes are the matrix dimensions.
        adjoint: If true, the matrix's last two axes are swapped before the
            check, i.e. the transposed matrix shape is used.

    Returns:
        list of int: Broadcast batch axes followed by
        ``[input rows, matrix columns]``.

    Raises:
        ValueError: If either aligned shape has fewer than two axes, the
            batch axes are not broadcast-compatible, or the inner
            dimensions disagree.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    # Both shapes need two trailing matrix axes. Without this guard the
    # [-2] indexing below fails with a bare IndexError instead of the
    # ValueError raised for every other invalid shape.
    if len(ishape_exp) < 2 or len(mshape_exp) < 2:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    if adjoint:
        mshape_exp[-1], mshape_exp[-2] = mshape_exp[-2], mshape_exp[-1]

    oshape = []
    # Leading (batch) axes broadcast like an element-wise operation.
    for i, m in zip(ishape_exp[:-2], mshape_exp[:-2]):
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))

        oshape.append(max(i, m))

    # Inner dimensions of ``input @ mat`` must agree.
    if ishape_exp[-1] != mshape_exp[-2]:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    oshape += [ishape_exp[-2], mshape_exp[-1]]

    return oshape