def _get_multiply_adjoint_sum_axes(oshape, ishape, mshape):
    """Return the axes on which a size-1 input dimension was broadcast.

    An axis qualifies when the (expanded) input size there is 1 while the
    multiplier size or the output size is not 1. The name suggests the
    adjoint of the multiply sums over these axes to undo the broadcast.

    Args:
        oshape: Output shape of the forward multiply.
        ishape: Input shape of the forward multiply.
        mshape: Shape of the multiplier.

    Returns:
        list of int: Qualifying axis indices, in increasing order, counted
        in the common-ndim coordinates produced by ``util._expand_shapes``.
    """
    in_exp, mult_exp = util._expand_shapes(ishape, mshape)
    ndim = max(len(ishape), len(mshape))
    return [
        axis
        for in_dim, mult_dim, out_dim, axis in zip(
            in_exp, mult_exp, oshape, range(ndim))
        if in_dim == 1 and not (mult_dim == 1 and out_dim == 1)
    ]
def _get_multiply_oshape(ishape, mshape):
    """Return the broadcast output shape of multiplying ``ishape`` by ``mshape``.

    First, both shapes are brought to a common number of dimensions by
    ``util._expand_shapes``. This presumably left-pads the shorter shape
    with 1s, NumPy style (confirm). Dimensions are then compared pairwise.
    A pair is compatible when the sizes are equal or either size is 1.

    Args:
        ishape: Input shape (sequence of ints).
        mshape: Multiplier shape (sequence of ints).

    Returns:
        list of int: The broadcast output shape.

    Raises:
        ValueError: If any pair of dimensions cannot be broadcast together.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    oshape = []
    for i, m in zip(ishape_exp, mshape_exp):
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))
        # A size-1 axis takes the other operand's size. ``max(i, m)`` would
        # be wrong for zero-size axes: broadcasting 1 against 0 yields 0.
        oshape.append(m if i == 1 else i)
    return oshape
def _get_matmul_oshape(ishape, mshape, adjoint):
    """Return the output shape of left-multiplying an input by a matrix.

    Computes the shape of ``mat @ input``, where ``mat`` has shape ``mshape``
    and ``input`` has shape ``ishape``. If ``adjoint`` is true, the last two
    axes of ``mshape`` are swapped first.

    The last two axes of each shape are the matrix axes. The leading axes
    are batch axes and must be broadcast-compatible: equal, or one of them
    is 1. Both shapes are first brought to a common number of dimensions by
    ``util._expand_shapes``, presumably by left-padding with 1s (confirm).

    Args:
        ishape: Shape of the right-hand operand (the input).
        mshape: Shape of the left-hand operand (the matrix).
        adjoint: If true, use the matrix with its last two axes swapped.

    Returns:
        list of int: The broadcast batch axes followed by
        ``[matrix rows, input columns]``.

    Raises:
        ValueError: If the shapes have fewer than two dimensions, the batch
            axes do not broadcast, or the inner matrix dimensions differ.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    if min(len(ishape_exp), len(mshape_exp)) < 2:
        # Without this guard the [-2] indexing below fails with a bare
        # IndexError instead of the function's documented ValueError.
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    # Work on a private list so the adjoint swap neither fails on a tuple
    # nor mutates a list that util._expand_shapes might share with a caller.
    mshape_exp = list(mshape_exp)
    if adjoint:
        mshape_exp[-1], mshape_exp[-2] = mshape_exp[-2], mshape_exp[-1]

    oshape = []
    for i, m in zip(ishape_exp[:-2], mshape_exp[:-2]):
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))
        # A size-1 batch axis takes the other size. max() would be wrong
        # when the other size is 0.
        oshape.append(m if i == 1 else i)

    # Inner dimensions: matrix columns must equal input rows.
    if mshape_exp[-1] != ishape_exp[-2]:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    oshape += [mshape_exp[-2], ishape_exp[-1]]
    return oshape
def _get_right_matmul_oshape(ishape, mshape, adjoint):
    """Return the output shape of right-multiplying an input by a matrix.

    Computes the shape of ``input @ mat``, where ``input`` has shape
    ``ishape`` and ``mat`` has shape ``mshape``. If ``adjoint`` is true, the
    last two axes of ``mshape`` are swapped first.

    The last two axes of each shape are the matrix axes. The leading axes
    are batch axes and must be broadcast-compatible: equal, or one of them
    is 1. Both shapes are first brought to a common number of dimensions by
    ``util._expand_shapes``, presumably by left-padding with 1s (confirm).

    Args:
        ishape: Shape of the left-hand operand (the input).
        mshape: Shape of the right-hand operand (the matrix).
        adjoint: If true, use the matrix with its last two axes swapped.

    Returns:
        list of int: The broadcast batch axes followed by
        ``[input rows, matrix columns]``.

    Raises:
        ValueError: If the shapes have fewer than two dimensions, the batch
            axes do not broadcast, or the inner matrix dimensions differ.
    """
    ishape_exp, mshape_exp = util._expand_shapes(ishape, mshape)
    if min(len(ishape_exp), len(mshape_exp)) < 2:
        # Without this guard the [-2] indexing below fails with a bare
        # IndexError instead of the function's documented ValueError.
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    # Work on a private list so the adjoint swap neither fails on a tuple
    # nor mutates a list that util._expand_shapes might share with a caller.
    mshape_exp = list(mshape_exp)
    if adjoint:
        mshape_exp[-1], mshape_exp[-2] = mshape_exp[-2], mshape_exp[-1]

    oshape = []
    for i, m in zip(ishape_exp[:-2], mshape_exp[:-2]):
        if not (i == m or i == 1 or m == 1):
            raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
                ishape=ishape, mshape=mshape))
        # A size-1 batch axis takes the other size. max() would be wrong
        # when the other size is 0.
        oshape.append(m if i == 1 else i)

    # Inner dimensions: input columns must equal matrix rows.
    if ishape_exp[-1] != mshape_exp[-2]:
        raise ValueError('Invalid shapes: {ishape}, {mshape}.'.format(
            ishape=ishape, mshape=mshape))

    oshape += [ishape_exp[-2], mshape_exp[-1]]
    return oshape