示例#1
0
def test_CXX11CodePrinter():
    """Check log1p printing and the metadata exposed by CXX11CodePrinter."""
    rendered = CXX11CodePrinter().doprint(log1p(x))
    assert rendered == 'std::log1p(x)'

    printer = CXX11CodePrinter()

    # Language / standard identification.
    assert printer.language == 'C++'
    assert printer.standard == 'C++11'

    # These words must be listed as reserved by the C++11 printer ...
    for keyword in ('operator', 'noexcept'):
        assert keyword in printer.reserved_words
    # ... while 'concept' must not be.
    assert 'concept' not in printer.reserved_words
示例#2
0
def test_CXX11CodePrinter():
    """Verify the C++11 printer's log1p output, metadata and reserved words."""
    log1p_code = CXX11CodePrinter().doprint(log1p(x))
    assert log1p_code == "std::log1p(x)"

    printer = CXX11CodePrinter()

    # The printer identifies itself as C++ / C++11.
    assert printer.language == "C++"
    assert printer.standard == "C++11"

    # Words expected to be reserved by this printer.
    for reserved in ("operator", "noexcept"):
        assert reserved in printer.reserved_words
    # A word that must not be reserved by this printer.
    assert "concept" not in printer.reserved_words
示例#3
0
def test_subclass_print_method__ns():
    """Overriding ``_ns`` in a subclass changes the printed namespace prefix."""

    class MyPrinter(CXX11CodePrinter):
        # Namespace prefix expected in front of the printed function name.
        _ns = "my_library::"

    default_printer = CXX11CodePrinter()
    custom_printer = MyPrinter()

    default_code = default_printer.doprint(log1p(x))
    assert default_code == "std::log1p(x)"

    custom_code = custom_printer.doprint(log1p(x))
    assert custom_code == "my_library::log1p(x)"
示例#4
0
def GenerateFunction(function_name, expressions, expression_names,
                     write_rowwise):
    """Print a C++ function that evaluates ``expressions``.

    Common subexpressions are extracted with ``cse`` and emitted as
    ``const Scalar`` temporaries; each result is then written through a
    ``Scalar*`` output parameter named after its entry in
    ``expression_names``.

    Args:
        function_name: Name used in the printed function signature.
        expressions: Expressions to evaluate; their free symbols are passed
            to ``MakeInputParameterList`` to form the input parameters.
        expression_names: Output name for each expression, in the same order.
        write_rowwise: Per-expression flags. When true for a matrix
            expression, every row gets its own output parameter
            (``<name>_row_<i>``, or just ``<name>`` for a one-row matrix).

    Returns:
        None. The generated code is printed to standard output.
    """
    code_printer = CXX11CodePrinter()

    # Gather every free symbol used by any expression; these become inputs.
    free_symbols = set()
    for expression in expressions:
        free_symbols = free_symbols.union(expression.free_symbols)
    signature = MakeInputParameterList(free_symbols)

    temporaries, results = cse(expressions, numbered_symbols('term'))

    # Pieces of the function body; each piece ends with a newline.
    chunks = []
    for symbol, value in temporaries:
        chunks.append(
            'const Scalar ' + str(symbol) + ' = ' + ccode(value) + ';\n')
    if temporaries:
        chunks.append('\n')

    for result, name, rowwise in zip(results, expression_names,
                                     write_rowwise):
        if not isinstance(result, ImmutableDenseMatrix):
            # Scalar result: assign through the output pointer.
            chunks.append('*' + name + ' = ' + ccode(result) + ';\n')
            signature += ', Scalar* ' + name
            continue

        if not rowwise:
            # Whole matrix written to a single output parameter.
            target = MatrixSymbol(name, result.rows, result.cols)
            chunks.append(
                code_printer.doprint(result, assign_to=target) + '\n')
            signature += ', Scalar* ' + name
            continue

        # One output parameter per matrix row.
        single_row = result.rows == 1
        for row in range(result.rows):
            row_name = name if single_row else name + '_row_' + str(row)
            target = MatrixSymbol(row_name, 1, result.cols)
            chunks.append(
                code_printer.doprint(result[row, :], assign_to=target) + '\n')
            signature += ', Scalar* ' + row_name

    body = ''.join(chunks)

    print('')
    print('void ' + function_name + '(' + signature + ') {')
    # Every chunk ends with '\n', so the final split element is empty.
    for line in body.split('\n')[:-1]:
        print('  ' + line)
    print('}')
def GenerateFunction(function_name, expressions, expression_names,
                     write_rowwise):
    """Print an inline C++ function that evaluates ``expressions``.

    The expressions go through common-subexpression elimination (``cse``),
    the extracted subexpressions are simplified and substituted back, and
    ``cse`` is run a second time on the result. The final subexpressions are
    emitted as ``const Scalar`` temporaries and each result is written
    through a ``Scalar*`` output parameter. An ``// opcount = N`` comment
    line is printed before the function.

    Args:
        function_name: Name used in the printed function signature.
        expressions: Expressions to evaluate; their free symbols are passed
            to ``MakeInputParameterList`` to form the input parameters.
        expression_names: Output name for each expression, in the same order.
        write_rowwise: Per-expression flags. When true for a matrix
            expression, every row gets its own output parameter
            (``<name>_row_<i>``, or just ``<name>`` for a one-row matrix).

    Returns:
        None. The generated code is printed to standard output.
    """
    printer = CXX11CodePrinter()

    # Gather every free symbol used by any expression; these become inputs.
    function_parameters = set()
    for expression in expressions:
        function_parameters = function_parameters.union(
            expression.free_symbols)
    function_parameters = MakeInputParameterList(function_parameters)

    # NOTE: Passing optimizations='basic' to cse() mostly resulted in a
    # higher opcount, so the default is used.
    replacements, reduced_expressions = cse(
        expressions, numbered_symbols('term'))

    # Simplify the extracted subexpressions.
    replacements = [(symbol, simplify(value))
                    for symbol, value in replacements]
    # NOTE(review): simplify() receives the whole list here; confirm that it
    # simplifies each element rather than returning the list unchanged.
    reduced_expressions = simplify(reduced_expressions)

    # Substitute every subexpression back into the later subexpressions and
    # into the results. Entries are replaced in place, so each iteration sees
    # the substitutions made by the earlier ones.
    for i, (symbol, value) in enumerate(replacements):
        for k in range(i + 1, len(replacements)):
            later_symbol, later_value = replacements[k]
            if symbol in later_value.free_symbols:
                replacements[k] = (later_symbol,
                                   later_value.subs(symbol, value))
        for k, reduced_expression in enumerate(reduced_expressions):
            if symbol in reduced_expression.free_symbols:
                reduced_expressions[k] = reduced_expression.subs(
                    symbol, value)

    # Repeat (to maybe find better subexpressions).
    replacements, reduced_expressions = cse(reduced_expressions,
                                            numbered_symbols('term'))

    print('\n// opcount = ' +
          str(count_ops(replacements) + count_ops(reduced_expressions)))

    # Pieces of the function body; each piece ends with a newline.
    chunks = []
    for symbol, value in replacements:
        chunks.append(
            'const Scalar ' + str(symbol) + ' = ' + ccode(value) + ';\n')
    if replacements:
        chunks.append('\n')

    for reduced_expression, expression_name, write_expression_rowwise in zip(
            reduced_expressions, expression_names, write_rowwise):
        if not isinstance(reduced_expression, MatrixBase):
            # Scalar result: assign through the output pointer.
            chunks.append('*' + expression_name + ' = ' +
                          ccode(reduced_expression) + ';\n')
            function_parameters += ', Scalar* ' + expression_name
        elif write_expression_rowwise:
            # One output parameter per matrix row.
            for row in range(reduced_expression.rows):
                row_name = (expression_name if reduced_expression.rows == 1
                            else expression_name + '_row_' + str(row))
                result = MatrixSymbol(row_name, 1, reduced_expression.cols)
                chunks.append(printer.doprint(reduced_expression[row, :],
                                              assign_to=result) + '\n')
                function_parameters += ', Scalar* ' + row_name
        else:
            # Whole matrix written to a single output parameter.
            result = MatrixSymbol(expression_name, reduced_expression.rows,
                                  reduced_expression.cols)
            chunks.append(printer.doprint(reduced_expression,
                                          assign_to=result) + '\n')
            function_parameters += ', Scalar* ' + expression_name

    text = ''.join(chunks)

    print('inline void ' + function_name + '(' + function_parameters + ') {')
    # Every chunk ends with '\n', so the final split element is empty.
    for line in text.split('\n')[:-1]:
        print('  ' + line)
    print('}')