def testResolveBinaryExpressionTypeMVPromotesLeftSide(self):
    """Integer-matrix times float-vector promotes the matrix operand to float.

    The resolved expression's operand 0 (the left-hand matrix) is expected
    to come back as a 4x4 float matrix.
    """
    intMatrix = types.MatrixType(types.Integer(), 4, 4)
    floatVector = types.VectorType(types.Float(), 4)
    resolved = types.ResolveBinaryExpressionType(
        op.Operation.MUL, intMatrix, floatVector)
    promotedLeft = resolved.GetOperandType(0)
    assert promotedLeft == types.MatrixType(types.Float(), 4, 4)
def testResolveBinaryExpressionForMMMulCombinedSizeIsLarger(self):
    """A 4x2 by 2x4 matrix product resolves to a 4x4 matrix.

    Operand shapes:

        [a b]
        [c d] [1 2 3 4]
        [e f] [5 6 7 8]
        [g h]
    """
    lhs = types.MatrixType(types.Float(), 4, 2)
    rhs = types.MatrixType(types.Float(), 2, 4)
    resolved = types.ResolveBinaryExpressionType(op.Operation.MUL, lhs, rhs)
    expected = types.MatrixType(types.Float(), 4, 4)
    assert resolved.GetReturnType() == expected
def testResolveBinaryExpressionForMMMulCombinedSizeIsSmaller(self):
    """A 2x4 by 4x2 matrix product resolves to a 2x2 matrix.

    Operand shapes:

                  [1 2]
        [a b c d] [3 4]
        [e f g h] [5 6]
                  [7 8]
    """
    lhs = types.MatrixType(types.Float(), 2, 4)
    rhs = types.MatrixType(types.Float(), 4, 2)
    resolved = types.ResolveBinaryExpressionType(op.Operation.MUL, lhs, rhs)
    expected = types.MatrixType(types.Float(), 2, 2)
    assert resolved.GetReturnType() == expected
def testResolveBinaryExpressionWorksOnCompatibleSizes(self):
    """Multiplying a 4x2 float matrix by a float 2-vector yields a float 4-vector."""
    matrix = types.MatrixType(types.Float(), 4, 2)
    vector = types.VectorType(types.Float(), 2)
    resolved = types.ResolveBinaryExpressionType(op.Operation.MUL, matrix, vector)
    wanted = types.VectorType(types.Float(), 4)
    assert resolved.GetReturnType() == wanted
def testResolveBinaryExpressionTypeMV(self):
    """A 4x4 float matrix times a float 4-vector returns the vector's type."""
    square = types.MatrixType(types.Float(), 4, 4)
    vec4 = types.VectorType(types.Float(), 4)
    returned = types.ResolveBinaryExpressionType(
        op.Operation.MUL, square, vec4).GetReturnType()
    assert returned == vec4
def testResolveBinaryExpressionForMM(self):
    """ADD, SUB and MUL of two 4x4 float matrices each return that matrix type."""
    mt = types.MatrixType(types.Float(), 4, 4)
    # A tuple iterates in the written order (a set's iteration order is an
    # implementation detail), so a failure always surfaces on the same
    # operation; the assert message names which operation failed.
    for operation in (op.Operation.ADD, op.Operation.SUB, op.Operation.MUL):
        resultType = types.ResolveBinaryExpressionType(operation, mt, mt)
        assert resultType.GetReturnType() == mt, operation
def testResolveBinaryExpressionMVFailsForNonMultiply(self):
    """Matrix-vector ADD, SUB and DIV are rejected even when sizes are compatible.

    The operands are a 4x4 float matrix and a float 4-vector, the same shapes
    that resolve successfully under MUL. A size mismatch on its own already
    raises, so the sizes must agree for the expected failure to be caused
    only by the operation.
    """
    # Built outside pytest.raises so that a constructor error cannot satisfy
    # the expected-exception check.
    matrix = types.MatrixType(types.Float(), 4, 4)
    vector = types.VectorType(types.Float(), 4)
    invalidOperations = [op.Operation.ADD, op.Operation.SUB, op.Operation.DIV]
    for operation in invalidOperations:
        with pytest.raises(Exception):
            types.ResolveBinaryExpressionType(operation, matrix, vector)
def testResolveBinaryExpressionFailsForMMDiv(self):
    """Dividing one 4x4 float matrix by another is rejected."""
    square = types.MatrixType(types.Float(), 4, 4)
    with pytest.raises(Exception):
        types.ResolveBinaryExpressionType(op.Operation.DIV, square, square)
def testResolveBinaryExpressionMVFailsOnIncompatibleSizes(self):
    """Multiplying a 2x4 float matrix by a float 2-vector is rejected."""
    # Operands are built outside pytest.raises so that only the resolve call
    # can satisfy the expected failure, not an error in type construction.
    matrix = types.MatrixType(types.Float(), 2, 4)
    vector = types.VectorType(types.Float(), 2)
    with pytest.raises(Exception):
        types.ResolveBinaryExpressionType(op.Operation.MUL, matrix, vector)
def testResolveBinaryExpressionForMMMulOnIncompatibleSizes(self):
    """A 4x2 by 3x4 matrix product is rejected.

    The left operand has 2 columns but the right operand has 3 rows.
    """
    lhs = types.MatrixType(types.Float(), 4, 2)
    rhs = types.MatrixType(types.Float(), 3, 4)
    with pytest.raises(Exception):
        types.ResolveBinaryExpressionType(op.Operation.MUL, lhs, rhs)